aviralgarg05 commented on code in PR #20494:
URL: https://github.com/apache/datafusion/pull/20494#discussion_r2971546726
##########
datafusion/physical-plan/src/topk/mod.rs:
##########
@@ -884,6 +916,163 @@ impl TopKHeap {
Ok(Some(scalar_values))
}
}
+const I32_OFFSET_LIMIT: i64 = i32::MAX as i64;
+
+fn split_indices_by_i32_offsets(
+ record_batches: &[&RecordBatch],
+ all_indices: &[(usize, usize)],
+ max_rows_per_batch: usize,
+ max_offset: i64,
+) -> Result<Vec<Range<usize>>> {
+ if all_indices.is_empty() {
+ return Ok(Vec::new());
+ }
+
+ let var_width_columns =
+ collect_var_width_columns(record_batches.first().ok_or_else(|| {
+ internal_datafusion_err!("Missing record batches for TopK
interleave")
+ })?);
+
+ if var_width_columns.is_empty() {
+ return Ok(split_indices_by_row_count(
+ all_indices.len(),
+ max_rows_per_batch,
+ ));
+ }
+
+ let mut ranges = Vec::new();
+ let mut start = 0;
+ let mut totals = vec![0_i64; var_width_columns.len()];
+
+ for (pos, (batch_pos, row_index)) in all_indices.iter().enumerate() {
+ if pos - start >= max_rows_per_batch {
+ ranges.push(start..pos);
+ start = pos;
+ totals.fill(0);
+ }
+
+ let batch = record_batches.get(*batch_pos).ok_or_else(|| {
+ internal_datafusion_err!("Invalid batch position in TopK indices")
+ })?;
+
+ let mut row_sizes = Vec::with_capacity(var_width_columns.len());
+ for column in &var_width_columns {
+ let array = batch.column(column.column_index);
+ let size = column.row_size(array.as_ref(), *row_index)?;
+ if size > max_offset {
+ return internal_err!(
+ "TopK row requires {size} offsets which exceeds i32::MAX"
+ );
+ }
+ row_sizes.push(size);
+ }
+
+ if totals
+ .iter()
+ .zip(row_sizes.iter())
+ .any(|(total, size)| total + size > max_offset)
+ {
+ ranges.push(start..pos);
+ start = pos;
+ totals.fill(0);
+ }
+
+ for (total, size) in totals.iter_mut().zip(row_sizes.iter()) {
+ *total += *size;
+ }
+ }
+
+ if start < all_indices.len() {
+ ranges.push(start..all_indices.len());
+ }
+
+ Ok(ranges)
+}
+
/// Partitions `0..total_rows` into consecutive ranges of at most
/// `max_rows_per_batch` rows each.
///
/// A limit of `0` is treated as `1`. When `total_rows` is `0`, the result
/// is empty.
fn split_indices_by_row_count(
    total_rows: usize,
    max_rows_per_batch: usize,
) -> Vec<Range<usize>> {
    let chunk_len = max_rows_per_batch.max(1);
    (0..total_rows)
        .step_by(chunk_len)
        .map(|chunk_start| chunk_start..(chunk_start + chunk_len).min(total_rows))
        .collect()
}
+
+fn collect_var_width_columns(batch: &RecordBatch) -> Vec<VarWidthColumn> {
+ batch
+ .columns()
+ .iter()
+ .enumerate()
+ .filter_map(|(index, array)| VarWidthColumn::new(index,
array.data_type()))
+ .collect()
+}
+
+struct VarWidthColumn {
+ column_index: usize,
+ kind: VarWidthKind,
+}
+
+impl VarWidthColumn {
+ fn new(column_index: usize, data_type: &DataType) -> Option<Self> {
+ let kind = match data_type {
+ DataType::Utf8 => VarWidthKind::Utf8,
+ DataType::Binary => VarWidthKind::Binary,
+ DataType::List(_) => VarWidthKind::List,
+ DataType::Map(_, _) => VarWidthKind::Map,
+ _ => return None,
+ };
+
+ Some(Self { column_index, kind })
+ }
+
+ fn row_size(&self, array: &dyn Array, row: usize) -> Result<i64> {
+ let size = match self.kind {
+ VarWidthKind::Utf8 => array
+ .as_any()
+ .downcast_ref::<StringArray>()
+ .ok_or_else(|| {
+ internal_datafusion_err!("Expected Utf8 array for TopK
interleave")
+ })?
+ .value_length(row) as i64,
+ VarWidthKind::Binary => array
+ .as_any()
+ .downcast_ref::<BinaryArray>()
+ .ok_or_else(|| {
+ internal_datafusion_err!("Expected Binary array for TopK
interleave")
+ })?
+ .value_length(row) as i64,
+ VarWidthKind::List => array
+ .as_any()
+ .downcast_ref::<ListArray>()
+ .ok_or_else(|| {
+ internal_datafusion_err!("Expected List array for TopK
interleave")
+ })?
+ .value_length(row) as i64,
+ VarWidthKind::Map => array
+ .as_any()
+ .downcast_ref::<MapArray>()
+ .ok_or_else(|| {
+ internal_datafusion_err!("Expected Map array for TopK
interleave")
+ })?
+ .value_length(row) as i64,
Review Comment:
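For `List`/`Map`, `value_length` returns the *element count* of that row's
list/map rather than a byte length, and that count is exactly what Arrow
accumulates in the i32 offset buffer for those types. The comparison threshold
is still `i32::MAX` for every kind; only the unit differs (bytes for
`Utf8`/`Binary`, child elements for `List`/`Map`).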
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]