andygrove commented on code in PR #1452:
URL: https://github.com/apache/datafusion-comet/pull/1452#discussion_r1975565255


##########
native/core/src/execution/shuffle/shuffle_writer.rs:
##########
@@ -829,54 +781,53 @@ impl PartitionBuffer {
                 });
             self.num_active_rows += end - start;
             repart_timer.stop();
+            start = end;
 
             if self.num_active_rows >= self.batch_size {
-                let flush = self.flush(metrics);
-                if let Err(e) = flush {
-                    return AppendRowStatus::MemDiff(Err(e));
-                }
-                mem_diff += flush.unwrap();
-
-                let init = self.init_active_if_necessary(metrics);
-                if init.is_err() {
-                    return AppendRowStatus::StartIndex(end);
-                }
-                mem_diff += init.unwrap();
+                self.flush(metrics)?;
             }
-            start = end;
         }
-        AppendRowStatus::MemDiff(Ok(mem_diff))
+        Ok(AppendRowStatus::Appended)
     }
 
     /// Flush active data into frozen bytes. This can reduce memory usage because the frozen
     /// bytes are compressed.
-    fn flush(&mut self, metrics: &ShuffleRepartitionerMetrics) -> Result<isize> {
+    fn flush(&mut self, metrics: &ShuffleRepartitionerMetrics) -> Result<()> {
         if self.num_active_rows == 0 {
-            return Ok(0);
+            return Ok(());
         }
-        let mut mem_diff = 0isize;
 
         // active -> staging
         let active = std::mem::take(&mut self.active);
         let num_rows = self.num_active_rows;
         self.num_active_rows = 0;
 
-        let mut mempool_timer = metrics.mempool_time.timer();
-        self.reservation.try_shrink(self.active_slots_mem_size)?;
-        mempool_timer.stop();
-
         let mut repart_timer = metrics.repart_time.timer();
         let frozen_batch = make_batch(Arc::clone(&self.schema), active, num_rows)?;
         repart_timer.stop();
 
-        let frozen_capacity_old = self.frozen.capacity();
         let mut cursor = Cursor::new(&mut self.frozen);
         cursor.seek(SeekFrom::End(0))?;
-        self.shuffle_block_writer
-            .write_batch(&frozen_batch, &mut cursor, &metrics.encode_time)?;
+        let bytes_written = self.shuffle_block_writer.write_batch(
+            &frozen_batch,
+            &mut cursor,
+            &metrics.encode_time,
+        )?;
+
+        // we typically expect the frozen bytes to take up less memory than
+        // the builders due to compression but there could be edge cases where
+        // this is not the case
+        let mut mempool_timer = metrics.mempool_time.timer();
+        if self.active_slots_mem_size >= bytes_written {
+            self.reservation
+                .try_shrink(self.active_slots_mem_size - bytes_written)?;
+        } else {
+            self.reservation
+                .try_grow(bytes_written - self.active_slots_mem_size)?;
+        }
+        mempool_timer.stop();
 
-        mem_diff += (self.frozen.capacity() - frozen_capacity_old) as isize;
-        Ok(mem_diff)

Review Comment:
   There is no need to return the memory size, because memory accounting already happens in this method (via `try_shrink`/`try_grow` on the reservation).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to