andygrove commented on code in PR #1452: URL: https://github.com/apache/datafusion-comet/pull/1452#discussion_r1975565255
##########
native/core/src/execution/shuffle/shuffle_writer.rs:
##########
@@ -829,54 +781,53 @@ impl PartitionBuffer {
             });
             self.num_active_rows += end - start;
             repart_timer.stop();
+            start = end;
             if self.num_active_rows >= self.batch_size {
-                let flush = self.flush(metrics);
-                if let Err(e) = flush {
-                    return AppendRowStatus::MemDiff(Err(e));
-                }
-                mem_diff += flush.unwrap();
-
-                let init = self.init_active_if_necessary(metrics);
-                if init.is_err() {
-                    return AppendRowStatus::StartIndex(end);
-                }
-                mem_diff += init.unwrap();
+                self.flush(metrics)?;
             }
-            start = end;
         }
-        AppendRowStatus::MemDiff(Ok(mem_diff))
+        Ok(AppendRowStatus::Appended)
     }
     /// Flush active data into frozen bytes. This can reduce memory usage because the frozen
     /// bytes are compressed.
-    fn flush(&mut self, metrics: &ShuffleRepartitionerMetrics) -> Result<isize> {
+    fn flush(&mut self, metrics: &ShuffleRepartitionerMetrics) -> Result<()> {
         if self.num_active_rows == 0 {
-            return Ok(0);
+            return Ok(());
         }
-        let mut mem_diff = 0isize;
         // active -> staging
         let active = std::mem::take(&mut self.active);
         let num_rows = self.num_active_rows;
         self.num_active_rows = 0;
-        let mut mempool_timer = metrics.mempool_time.timer();
-        self.reservation.try_shrink(self.active_slots_mem_size)?;
-        mempool_timer.stop();
-
         let mut repart_timer = metrics.repart_time.timer();
         let frozen_batch = make_batch(Arc::clone(&self.schema), active, num_rows)?;
         repart_timer.stop();
-        let frozen_capacity_old = self.frozen.capacity();
         let mut cursor = Cursor::new(&mut self.frozen);
         cursor.seek(SeekFrom::End(0))?;
-        self.shuffle_block_writer
-            .write_batch(&frozen_batch, &mut cursor, &metrics.encode_time)?;
+        let bytes_written = self.shuffle_block_writer.write_batch(
+            &frozen_batch,
+            &mut cursor,
+            &metrics.encode_time,
+        )?;
+
+        // we typically expect the frozen bytes to take up less memory than
+        // the builders due to compression but there could be edge cases where
+        // this is not the case
+        let mut mempool_timer = metrics.mempool_time.timer();
+        if self.active_slots_mem_size >= bytes_written {
+            self.reservation
+                .try_shrink(self.active_slots_mem_size - bytes_written)?;
+        } else {
+            self.reservation
+                .try_grow(bytes_written - self.active_slots_mem_size)?;
+        }
+        mempool_timer.stop();
-        mem_diff += (self.frozen.capacity() - frozen_capacity_old) as isize;
-        Ok(mem_diff)

Review Comment:
No need to return memory size because memory accounting already happened in this method.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org