This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new fd4775119 perf: refactor sum int with specialized implementations for each eval_mode (#3054)
fd4775119 is described below

commit fd477511957b291d505c8efec6f8360a43922992
Author: Andy Grove <[email protected]>
AuthorDate: Wed Feb 4 05:44:37 2026 -0700

    perf: refactor sum int with specialized implementations for each eval_mode (#3054)
---
 native/spark-expr/benches/aggregate.rs     | 182 +++++-
 native/spark-expr/src/agg_funcs/sum_int.rs | 851 ++++++++++++++++++++---------
 2 files changed, 759 insertions(+), 274 deletions(-)

diff --git a/native/spark-expr/benches/aggregate.rs b/native/spark-expr/benches/aggregate.rs
index 72628975b..47e2cf61c 100644
--- a/native/spark-expr/benches/aggregate.rs
+++ b/native/spark-expr/benches/aggregate.rs
@@ -15,8 +15,8 @@
 // specific language governing permissions and limitations
 // under the License.use arrow::array::{ArrayRef, BooleanBuilder, Int32Builder, RecordBatch, StringBuilder};
 
-use arrow::array::builder::{Decimal128Builder, StringBuilder};
-use arrow::array::{ArrayRef, RecordBatch};
+use arrow::array::builder::{Decimal128Builder, Int64Builder, StringBuilder};
+use arrow::array::{ArrayRef, Int64Array, RecordBatch};
 use arrow::datatypes::SchemaRef;
 use arrow::datatypes::{DataType, Field, Schema};
 use criterion::{criterion_group, criterion_main, Criterion};
@@ -25,14 +25,14 @@ use datafusion::datasource::source::DataSourceExec;
 use datafusion::execution::TaskContext;
 use datafusion::functions_aggregate::average::avg_udaf;
 use datafusion::functions_aggregate::sum::sum_udaf;
-use datafusion::logical_expr::AggregateUDF;
+use datafusion::logical_expr::function::AccumulatorArgs;
+use datafusion::logical_expr::{AggregateUDF, AggregateUDFImpl, EmitTo};
 use datafusion::physical_expr::aggregate::AggregateExprBuilder;
 use datafusion::physical_expr::expressions::Column;
 use datafusion::physical_expr::PhysicalExpr;
 use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode, PhysicalGroupBy};
 use datafusion::physical_plan::ExecutionPlan;
-use datafusion_comet_spark_expr::SumDecimal;
-use datafusion_comet_spark_expr::{AvgDecimal, EvalMode};
+use datafusion_comet_spark_expr::{AvgDecimal, EvalMode, SumDecimal, SumInteger};
 use futures::StreamExt;
 use std::hint::black_box;
 use std::sync::Arc;
@@ -111,6 +111,153 @@ fn criterion_benchmark(c: &mut Criterion) {
     });
 
     group.finish();
+
+    // SumInteger benchmarks
+    let mut group = c.benchmark_group("sum_integer");
+    let int_batch = create_int64_record_batch(num_rows);
+    let mut int_batches = Vec::new();
+    for _ in 0..10 {
+        int_batches.push(int_batch.clone());
+    }
+    let int_partitions = &[int_batches];
+    let int_c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
+    let int_c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));
+
+    group.bench_function("sum_int64_datafusion", |b| {
+        let datafusion_sum = sum_udaf();
+        b.to_async(&rt).iter(|| {
+            black_box(agg_test(
+                int_partitions,
+                int_c0.clone(),
+                int_c1.clone(),
+                datafusion_sum.clone(),
+                "sum",
+            ))
+        })
+    });
+
+    group.bench_function("sum_int64_comet_legacy", |b| {
+        let comet_sum = Arc::new(AggregateUDF::new_from_impl(
+            SumInteger::try_new(DataType::Int64, EvalMode::Legacy).unwrap(),
+        ));
+        b.to_async(&rt).iter(|| {
+            black_box(agg_test(
+                int_partitions,
+                int_c0.clone(),
+                int_c1.clone(),
+                comet_sum.clone(),
+                "sum",
+            ))
+        })
+    });
+
+    group.bench_function("sum_int64_comet_ansi", |b| {
+        let comet_sum = Arc::new(AggregateUDF::new_from_impl(
+            SumInteger::try_new(DataType::Int64, EvalMode::Ansi).unwrap(),
+        ));
+        b.to_async(&rt).iter(|| {
+            black_box(agg_test(
+                int_partitions,
+                int_c0.clone(),
+                int_c1.clone(),
+                comet_sum.clone(),
+                "sum",
+            ))
+        })
+    });
+
+    group.bench_function("sum_int64_comet_try", |b| {
+        let comet_sum = Arc::new(AggregateUDF::new_from_impl(
+            SumInteger::try_new(DataType::Int64, EvalMode::Try).unwrap(),
+        ));
+        b.to_async(&rt).iter(|| {
+            black_box(agg_test(
+                int_partitions,
+                int_c0.clone(),
+                int_c1.clone(),
+                comet_sum.clone(),
+                "sum",
+            ))
+        })
+    });
+
+    group.finish();
+
+    // Direct accumulator benchmarks (bypassing execution framework)
+    let mut group = c.benchmark_group("sum_integer_accumulator");
+    let int64_array: ArrayRef = Arc::new(Int64Array::from_iter_values(0..8192i64));
+    let arrays: Vec<ArrayRef> = vec![int64_array];
+
+    let return_field = Arc::new(Field::new("sum", DataType::Int64, true));
+    let schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]);
+    let expr_field = Arc::new(Field::new("c0", DataType::Int64, true));
+    let expr_fields: Vec<Arc<Field>> = vec![expr_field];
+
+    // Single-row Accumulator benchmarks
+    for (name, eval_mode) in [
+        ("row_legacy", EvalMode::Legacy),
+        ("row_ansi", EvalMode::Ansi),
+        ("row_try", EvalMode::Try),
+    ] {
+        let return_field = return_field.clone();
+        let expr_fields = expr_fields.clone();
+        group.bench_function(name, |b| {
+            let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap();
+            b.iter(|| {
+                let acc_args = AccumulatorArgs {
+                    return_field: return_field.clone(),
+                    schema: &schema,
+                    ignore_nulls: false,
+                    order_bys: &[],
+                    name: "sum",
+                    is_distinct: false,
+                    is_reversed: false,
+                    exprs: &[],
+                    expr_fields: &expr_fields,
+                };
+                let mut acc = udf.accumulator(acc_args).unwrap();
+                for _ in 0..10 {
+                    acc.update_batch(&arrays).unwrap();
+                }
+                black_box(acc.evaluate().unwrap())
+            })
+        });
+    }
+
+    // GroupsAccumulator benchmarks
+    let group_indices: Vec<usize> = (0..8192).map(|i| i % 1024).collect();
+    for (name, eval_mode) in [
+        ("groups_legacy", EvalMode::Legacy),
+        ("groups_ansi", EvalMode::Ansi),
+        ("groups_try", EvalMode::Try),
+    ] {
+        let return_field = return_field.clone();
+        let expr_fields = expr_fields.clone();
+        group.bench_function(name, |b| {
+            let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap();
+            b.iter(|| {
+                let acc_args = AccumulatorArgs {
+                    return_field: return_field.clone(),
+                    schema: &schema,
+                    ignore_nulls: false,
+                    order_bys: &[],
+                    name: "sum",
+                    is_distinct: false,
+                    is_reversed: false,
+                    exprs: &[],
+                    expr_fields: &expr_fields,
+                };
+                let mut acc = udf.create_groups_accumulator(acc_args).unwrap();
+                for _ in 0..10 {
+                    acc.update_batch(&arrays, &group_indices, None, 1024)
+                        .unwrap();
+                }
+                black_box(acc.evaluate(EmitTo::All).unwrap())
+            })
+        });
+    }
+
+    group.finish();
 }
 
 async fn agg_test(
@@ -187,6 +334,31 @@ fn create_record_batch(num_rows: usize) -> RecordBatch {
     RecordBatch::try_new(Arc::new(schema), columns).unwrap()
 }
 
+fn create_int64_record_batch(num_rows: usize) -> RecordBatch {
+    let mut int64_builder = Int64Builder::with_capacity(num_rows);
+    let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows * 32);
+    for i in 0..num_rows {
+        int64_builder.append_value(i as i64);
+        string_builder.append_value(format!("group_{}", i % 1024));
+    }
+    let int64_array = Arc::new(int64_builder.finish());
+    let string_array = Arc::new(string_builder.finish());
+
+    let mut fields = vec![];
+    let mut columns: Vec<ArrayRef> = vec![];
+
+    // string column for grouping
+    fields.push(Field::new("c0", DataType::Utf8, false));
+    columns.push(string_array);
+
+    // int64 column for summing
+    fields.push(Field::new("c1", DataType::Int64, false));
+    columns.push(int64_array);
+
+    let schema = Schema::new(fields);
+    RecordBatch::try_new(Arc::new(schema), columns).unwrap()
+}
+
 fn config() -> Criterion {
     Criterion::default()
         .measurement_time(Duration::from_millis(500))
diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs b/native/spark-expr/src/agg_funcs/sum_int.rs
index d226c5ede..2ea07c743 100644
--- a/native/spark-expr/src/agg_funcs/sum_int.rs
+++ b/native/spark-expr/src/agg_funcs/sum_int.rs
@@ -69,7 +69,11 @@ impl AggregateUDFImpl for SumInteger {
     }
 
     fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult<Box<dyn Accumulator>> {
-        Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode)))
+        match self.eval_mode {
+            EvalMode::Legacy => Ok(Box::new(SumIntegerAccumulatorLegacy::new())),
+            EvalMode::Ansi => Ok(Box::new(SumIntegerAccumulatorAnsi::new())),
+            EvalMode::Try => Ok(Box::new(SumIntegerAccumulatorTry::new())),
+        }
     }
 
     fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
@@ -91,7 +95,11 @@ impl AggregateUDFImpl for SumInteger {
         &self,
         _args: AccumulatorArgs,
     ) -> DFResult<Box<dyn GroupsAccumulator>> {
-        Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode)))
+        match self.eval_mode {
+            EvalMode::Legacy => Ok(Box::new(SumIntGroupsAccumulatorLegacy::new())),
+            EvalMode::Ansi => Ok(Box::new(SumIntGroupsAccumulatorAnsi::new())),
+            EvalMode::Try => Ok(Box::new(SumIntGroupsAccumulatorTry::new())),
+        }
     }
 
     fn reverse_expr(&self) -> ReversedUDAF {
@@ -100,39 +108,222 @@ impl AggregateUDFImpl for SumInteger {
 }
 
 #[derive(Debug)]
-struct SumIntegerAccumulator {
+struct SumIntegerAccumulatorLegacy {
     sum: Option<i64>,
-    eval_mode: EvalMode,
-    has_all_nulls: bool,
 }
 
-impl SumIntegerAccumulator {
-    fn new(eval_mode: EvalMode) -> Self {
-        if eval_mode == EvalMode::Try {
-            Self {
-                // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow)
-                sum: Some(0),
-                has_all_nulls: true,
-                eval_mode,
+impl SumIntegerAccumulatorLegacy {
+    fn new() -> Self {
+        Self { sum: None }
+    }
+}
+
+impl Accumulator for SumIntegerAccumulatorLegacy {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
+        fn update_sum<T>(int_array: &PrimitiveArray<T>, mut sum: i64) -> DFResult<i64>
+        where
+            T: ArrowPrimitiveType,
+        {
+            for i in 0..int_array.len() {
+                if !int_array.is_null(i) {
+                    let v = int_array.value(i).to_i64().ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "Failed to convert value {:?} to i64",
+                            int_array.value(i)
+                        ))
+                    })?;
+                    sum = v.add_wrapping(sum);
+                }
+            }
+            Ok(sum)
+        }
+
+        let values = &values[0];
+        if values.len() == values.null_count() {
+            return Ok(());
+        }
+
+        let running_sum = self.sum.unwrap_or(0);
+        let sum = match values.data_type() {
+            DataType::Int64 => update_sum(as_primitive_array::<Int64Type>(values), running_sum)?,
+            DataType::Int32 => update_sum(as_primitive_array::<Int32Type>(values), running_sum)?,
+            DataType::Int16 => update_sum(as_primitive_array::<Int16Type>(values), running_sum)?,
+            DataType::Int8 => update_sum(as_primitive_array::<Int8Type>(values), running_sum)?,
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "unsupported data type: {:?}",
+                    values.data_type()
+                )));
             }
+        };
+        self.sum = Some(sum);
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> DFResult<ScalarValue> {
+        Ok(ScalarValue::Int64(self.sum))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::Int64(self.sum)])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+        if states.len() != 1 {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid state while merging batch. Expected 1 element but found {}",
+                states.len()
+            )));
+        }
+
+        let that_sum_array = states[0].as_primitive::<Int64Type>();
+        let that_sum = if that_sum_array.is_null(0) {
+            None
         } else {
-            Self {
-                sum: None,
-                has_all_nulls: false,
-                eval_mode,
+            Some(that_sum_array.value(0))
+        };
+
+        if that_sum.is_none() {
+            return Ok(());
+        }
+        if self.sum.is_none() {
+            self.sum = that_sum;
+            return Ok(());
+        }
+
+        self.sum = Some(self.sum.unwrap().add_wrapping(that_sum.unwrap()));
+        Ok(())
+    }
+}
+
+#[derive(Debug)]
+struct SumIntegerAccumulatorAnsi {
+    sum: Option<i64>,
+}
+
+impl SumIntegerAccumulatorAnsi {
+    fn new() -> Self {
+        Self { sum: None }
+    }
+}
+
+impl Accumulator for SumIntegerAccumulatorAnsi {
+    fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
+        fn update_sum<T>(int_array: &PrimitiveArray<T>, mut sum: i64) -> DFResult<i64>
+        where
+            T: ArrowPrimitiveType,
+        {
+            for i in 0..int_array.len() {
+                if !int_array.is_null(i) {
+                    let v = int_array.value(i).to_i64().ok_or_else(|| {
+                        DataFusionError::Internal(format!(
+                            "Failed to convert value {:?} to i64",
+                            int_array.value(i)
+                        ))
+                    })?;
+                    sum = v
+                        .add_checked(sum)
+                        .map_err(|_| DataFusionError::from(arithmetic_overflow_error("integer")))?;
+                }
+            }
+            Ok(sum)
+        }
+
+        let values = &values[0];
+        if values.len() == values.null_count() {
+            return Ok(());
+        }
+
+        let running_sum = self.sum.unwrap_or(0);
+        let sum = match values.data_type() {
+            DataType::Int64 => update_sum(as_primitive_array::<Int64Type>(values), running_sum)?,
+            DataType::Int32 => update_sum(as_primitive_array::<Int32Type>(values), running_sum)?,
+            DataType::Int16 => update_sum(as_primitive_array::<Int16Type>(values), running_sum)?,
+            DataType::Int8 => update_sum(as_primitive_array::<Int8Type>(values), running_sum)?,
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "unsupported data type: {:?}",
+                    values.data_type()
+                )));
             }
+        };
+        self.sum = Some(sum);
+        Ok(())
+    }
+
+    fn evaluate(&mut self) -> DFResult<ScalarValue> {
+        Ok(ScalarValue::Int64(self.sum))
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+
+    fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
+        Ok(vec![ScalarValue::Int64(self.sum)])
+    }
+
+    fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+        if states.len() != 1 {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid state while merging batch. Expected 1 element but found {}",
+                states.len()
+            )));
+        }
+
+        let that_sum_array = states[0].as_primitive::<Int64Type>();
+        let that_sum = if that_sum_array.is_null(0) {
+            None
+        } else {
+            Some(that_sum_array.value(0))
+        };
+
+        if that_sum.is_none() {
+            return Ok(());
+        }
+        if self.sum.is_none() {
+            self.sum = that_sum;
+            return Ok(());
+        }
+
+        self.sum = Some(
+            self.sum
+                .unwrap()
+                .add_checked(that_sum.unwrap())
+                .map_err(|_| DataFusionError::from(arithmetic_overflow_error("integer")))?,
+        );
+        Ok(())
+    }
+}
+
+#[derive(Debug)]
+struct SumIntegerAccumulatorTry {
+    sum: Option<i64>,
+    has_all_nulls: bool,
+}
+
+impl SumIntegerAccumulatorTry {
+    fn new() -> Self {
+        Self {
+            // Try mode starts with 0 (because if this is init to None we cant say if it is none due to all nulls or due to an overflow)
+            sum: Some(0),
+            has_all_nulls: true,
         }
     }
+
+    fn overflowed(&self) -> bool {
+        !self.has_all_nulls && self.sum.is_none()
+    }
 }
 
-impl Accumulator for SumIntegerAccumulator {
+impl Accumulator for SumIntegerAccumulatorTry {
     fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
-        // accumulator internal to add sum and return null sum (and has_nulls false) if there is an overflow in Try Eval mode
-        fn update_sum_internal<T>(
-            int_array: &PrimitiveArray<T>,
-            eval_mode: EvalMode,
-            mut sum: i64,
-        ) -> Result<Option<i64>, DataFusionError>
+        /// Returns Ok(Some(sum)) on success, Ok(None) on overflow
+        fn update_sum<T>(int_array: &PrimitiveArray<T>, mut sum: i64) -> DFResult<Option<i64>>
         where
             T: ArrowPrimitiveType,
         {
@@ -144,72 +335,41 @@ impl Accumulator for SumIntegerAccumulator {
                             int_array.value(i)
                         ))
                     })?;
-                    match eval_mode {
-                        EvalMode::Legacy => {
-                            sum = v.add_wrapping(sum);
-                        }
-                        EvalMode::Ansi | EvalMode::Try => {
-                            match v.add_checked(sum) {
-                                Ok(v) => sum = v,
-                                Err(_e) => {
-                                    return if eval_mode == EvalMode::Ansi {
-                                        Err(DataFusionError::from(arithmetic_overflow_error(
-                                            "integer",
-                                        )))
-                                    } else {
-                                        Ok(None)
-                                    };
-                                }
-                            };
-                        }
+                    match v.add_checked(sum) {
+                        Ok(new_sum) => sum = new_sum,
+                        Err(_) => return Ok(None),
                     }
                 }
             }
             Ok(Some(sum))
         }
 
-        if self.eval_mode == EvalMode::Try && !self.has_all_nulls && self.sum.is_none() {
-            // we saw an overflow earlier (Try eval mode). Skip processing
+        // Skip if we already saw an overflow
+        if self.overflowed() {
             return Ok(());
         }
+
         let values = &values[0];
         if values.len() == values.null_count() {
-            Ok(())
-        } else {
-            // No nulls so there should be a non-null sum / null incase overflow in Try eval
-            let running_sum = self.sum.unwrap_or(0);
-            let sum = match values.data_type() {
-                DataType::Int64 => update_sum_internal(
-                    as_primitive_array::<Int64Type>(values),
-                    self.eval_mode,
-                    running_sum,
-                )?,
-                DataType::Int32 => update_sum_internal(
-                    as_primitive_array::<Int32Type>(values),
-                    self.eval_mode,
-                    running_sum,
-                )?,
-                DataType::Int16 => update_sum_internal(
-                    as_primitive_array::<Int16Type>(values),
-                    self.eval_mode,
-                    running_sum,
-                )?,
-                DataType::Int8 => update_sum_internal(
-                    as_primitive_array::<Int8Type>(values),
-                    self.eval_mode,
-                    running_sum,
-                )?,
-                _ => {
-                    return Err(DataFusionError::Internal(format!(
-                        "unsupported data type: {:?}",
-                        values.data_type()
-                    )));
-                }
-            };
-            self.sum = sum;
-            self.has_all_nulls = false;
-            Ok(())
+            return Ok(());
         }
+
+        let running_sum = self.sum.unwrap_or(0);
+        let sum = match values.data_type() {
+            DataType::Int64 => update_sum(as_primitive_array::<Int64Type>(values), running_sum)?,
+            DataType::Int32 => update_sum(as_primitive_array::<Int32Type>(values), running_sum)?,
+            DataType::Int16 => update_sum(as_primitive_array::<Int16Type>(values), running_sum)?,
+            DataType::Int8 => update_sum(as_primitive_array::<Int8Type>(values), running_sum)?,
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "unsupported data type: {:?}",
+                    values.data_type()
+                )));
+            }
+        };
+        self.sum = sum;
+        self.has_all_nulls = false;
+        Ok(())
     }
 
     fn evaluate(&mut self) -> DFResult<ScalarValue> {
@@ -225,26 +385,16 @@ impl Accumulator for SumIntegerAccumulator {
     }
 
     fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
-        if self.eval_mode == EvalMode::Try {
-            Ok(vec![
-                ScalarValue::Int64(self.sum),
-                ScalarValue::Boolean(Some(self.has_all_nulls)),
-            ])
-        } else {
-            Ok(vec![ScalarValue::Int64(self.sum)])
-        }
+        Ok(vec![
+            ScalarValue::Int64(self.sum),
+            ScalarValue::Boolean(Some(self.has_all_nulls)),
+        ])
     }
 
     fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
-        let expected_state_len = if self.eval_mode == EvalMode::Try {
-            2
-        } else {
-            1
-        };
-        if expected_state_len != states.len() {
+        if states.len() != 2 {
             return Err(DataFusionError::Internal(format!(
-                "Invalid state while merging batch. Expected {} elements but found {}",
-                expected_state_len,
+                "Invalid state while merging batch. Expected 2 elements but found {}",
                 states.len()
             )));
         }
@@ -255,94 +405,326 @@ impl Accumulator for SumIntegerAccumulator {
         } else {
             Some(that_sum_array.value(0))
         };
+        let that_has_all_nulls = states[1].as_boolean().value(0);
 
-        // Check for overflow for early termination
-        if self.eval_mode == EvalMode::Try {
-            let that_has_all_nulls = states[1].as_boolean().value(0);
-            let that_overflowed = !that_has_all_nulls && that_sum.is_none();
-            let this_overflowed = !self.has_all_nulls && self.sum.is_none();
-            if that_overflowed || this_overflowed {
+        let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+        if that_overflowed || self.overflowed() {
+            self.sum = None;
+            self.has_all_nulls = false;
+            return Ok(());
+        }
+
+        if that_has_all_nulls {
+            return Ok(());
+        }
+        if self.has_all_nulls {
+            self.sum = that_sum;
+            self.has_all_nulls = false;
+            return Ok(());
+        }
+
+        // Both sides have non-null values
+        match self.sum.unwrap().add_checked(that_sum.unwrap()) {
+            Ok(v) => self.sum = Some(v),
+            Err(_) => {
                 self.sum = None;
                 self.has_all_nulls = false;
-                return Ok(());
             }
-            if that_has_all_nulls {
-                return Ok(());
+        }
+        Ok(())
+    }
+}
+
+struct SumIntGroupsAccumulatorLegacy {
+    sums: Vec<Option<i64>>,
+}
+
+impl SumIntGroupsAccumulatorLegacy {
+    fn new() -> Self {
+        Self { sums: Vec::new() }
+    }
+}
+
+impl GroupsAccumulator for SumIntGroupsAccumulatorLegacy {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        fn update_groups_sum<T>(
+            int_array: &PrimitiveArray<T>,
+            group_indices: &[usize],
+            sums: &mut [Option<i64>],
+        ) -> DFResult<()>
+        where
+            T: ArrowPrimitiveType,
+            T::Native: ArrowNativeType,
+        {
+            for (i, &group_index) in group_indices.iter().enumerate() {
+                if !int_array.is_null(i) {
+                    let v = int_array.value(i).to_i64().ok_or_else(|| {
+                        DataFusionError::Internal("Failed to convert value to i64".to_string())
+                    })?;
+                    sums[group_index] = Some(sums[group_index].unwrap_or(0).add_wrapping(v));
+                }
             }
-            if self.has_all_nulls {
-                self.sum = that_sum;
-                self.has_all_nulls = false;
-                return Ok(());
+            Ok(())
+        }
+
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+        let values = &values[0];
+        self.sums.resize(total_num_groups, None);
+
+        match values.data_type() {
+            DataType::Int64 => update_groups_sum(
+                as_primitive_array::<Int64Type>(values),
+                group_indices,
+                &mut self.sums,
+            )?,
+            DataType::Int32 => update_groups_sum(
+                as_primitive_array::<Int32Type>(values),
+                group_indices,
+                &mut self.sums,
+            )?,
+            DataType::Int16 => update_groups_sum(
+                as_primitive_array::<Int16Type>(values),
+                group_indices,
+                &mut self.sums,
+            )?,
+            DataType::Int8 => update_groups_sum(
+                as_primitive_array::<Int8Type>(values),
+                group_indices,
+                &mut self.sums,
+            )?,
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "Unsupported data type for SumIntGroupsAccumulatorLegacy: {:?}",
+                    values.data_type()
+                )))
             }
-        } else {
-            if that_sum.is_none() {
-                return Ok(());
+        };
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
+        match emit_to {
+            EmitTo::All => {
+                let result = Arc::new(Int64Array::from(std::mem::take(&mut self.sums))) as ArrayRef;
+                Ok(result)
             }
-            if self.sum.is_none() {
-                self.sum = that_sum;
-                return Ok(());
+            EmitTo::First(n) => {
+                let result = Arc::new(Int64Array::from(self.sums.drain(..n).collect::<Vec<_>>()))
+                    as ArrayRef;
+                Ok(result)
             }
         }
+    }
 
-        // safe to unwrap (since we checked nulls above) but handling error just in case state is corrupt
-        let left = self.sum.ok_or_else(|| {
-            DataFusionError::Internal(
-                "Invalid state in merging batch. Current batch's sum is None".to_string(),
-            )
-        })?;
-        let right = that_sum.ok_or_else(|| {
-            DataFusionError::Internal(
-                "Invalid state in merging batch. Incoming sum is None".to_string(),
-            )
-        })?;
+    fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
+        let sums = emit_to.take_needed(&mut self.sums);
+        Ok(vec![Arc::new(Int64Array::from(sums))])
+    }
 
-        match self.eval_mode {
-            EvalMode::Legacy => {
-                self.sum = Some(left.add_wrapping(right));
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+        if values.len() != 1 {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid state while merging batch. Expected 1 element but found {}",
+                values.len()
+            )));
+        }
+        let that_sums = values[0].as_primitive::<Int64Type>();
+
+        self.sums.resize(total_num_groups, None);
+
+        for (idx, &group_index) in group_indices.iter().enumerate() {
+            if that_sums.is_null(idx) {
+                continue;
             }
-            EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) {
-                Ok(v) => self.sum = Some(v),
-                Err(_) => {
-                    if self.eval_mode == EvalMode::Ansi {
-                        return Err(DataFusionError::from(arithmetic_overflow_error("integer")));
-                    } else {
-                        self.sum = None;
-                        self.has_all_nulls = false;
-                    }
+            let that_sum = that_sums.value(idx);
+
+            if self.sums[group_index].is_none() {
+                self.sums[group_index] = Some(that_sum);
+            } else {
+                self.sums[group_index] =
+                    Some(self.sums[group_index].unwrap().add_wrapping(that_sum));
+            }
+        }
+        Ok(())
+    }
+
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self)
+    }
+}
+
+struct SumIntGroupsAccumulatorAnsi {
+    sums: Vec<Option<i64>>,
+}
+
+impl SumIntGroupsAccumulatorAnsi {
+    fn new() -> Self {
+        Self { sums: Vec::new() }
+    }
+}
+
+impl GroupsAccumulator for SumIntGroupsAccumulatorAnsi {
+    /// Adds every non-null input value to its group's running sum.
+    ///
+    /// Values are widened to `i64` and combined with `add_checked`; the first
+    /// overflow aborts the batch with `arithmetic_overflow_error("integer")`.
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        // Folds each non-null element of `int_array` into
+        // `sums[group_indices[i]]`; a group's first value starts from 0.
+        fn update_groups_sum<T>(
+            int_array: &PrimitiveArray<T>,
+            group_indices: &[usize],
+            sums: &mut [Option<i64>],
+        ) -> DFResult<()>
+        where
+            T: ArrowPrimitiveType,
+            T::Native: ArrowNativeType,
+        {
+            for (i, &group_index) in group_indices.iter().enumerate() {
+                if !int_array.is_null(i) {
+                    let v = int_array.value(i).to_i64().ok_or_else(|| {
+                        DataFusionError::Internal("Failed to convert value to i64".to_string())
+                    })?;
+                    sums[group_index] =
+                        Some(sums[group_index].unwrap_or(0).add_checked(v).map_err(|_| {
+                            DataFusionError::from(arithmetic_overflow_error("integer"))
+                        })?);
                 }
-            },
+            }
+            Ok(())
+        }
+
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+        let values = &values[0];
+        // Newly seen groups start as `None` (no non-null input yet).
+        self.sums.resize(total_num_groups, None);
+
+        match values.data_type() {
+            DataType::Int64 => update_groups_sum(
+                as_primitive_array::<Int64Type>(values),
+                group_indices,
+                &mut self.sums,
+            )?,
+            DataType::Int32 => update_groups_sum(
+                as_primitive_array::<Int32Type>(values),
+                group_indices,
+                &mut self.sums,
+            )?,
+            DataType::Int16 => update_groups_sum(
+                as_primitive_array::<Int16Type>(values),
+                group_indices,
+                &mut self.sums,
+            )?,
+            DataType::Int8 => update_groups_sum(
+                as_primitive_array::<Int8Type>(values),
+                group_indices,
+                &mut self.sums,
+            )?,
+            _ => {
+                return Err(DataFusionError::Internal(format!(
+                    "Unsupported data type for SumIntGroupsAccumulatorAnsi: {:?}",
+                    values.data_type()
+                )))
+            }
+        };
+        Ok(())
+    }
+
+    /// Emits the sums of the requested groups and forgets them; groups that
+    /// never saw a non-null value are emitted as null.
+    fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
+        // Same bookkeeping as `state`: take all groups, or only the first `n`.
+        let sums = emit_to.take_needed(&mut self.sums);
+        Ok(Arc::new(Int64Array::from(sums)) as ArrayRef)
+    }
+
+    /// Intermediate state is a single `Int64` column of partial sums.
+    fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
+        let sums = emit_to.take_needed(&mut self.sums);
+        Ok(vec![Arc::new(Int64Array::from(sums))])
+    }
+
+    /// Merges partial sums produced by `state`, returning the ANSI overflow
+    /// error if combining two partial sums overflows.
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> DFResult<()> {
+        debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+        if values.len() != 1 {
+            return Err(DataFusionError::Internal(format!(
+                "Invalid state while merging batch. Expected 1 element but found {}",
+                values.len()
+            )));
+        }
+        let that_sums = values[0].as_primitive::<Int64Type>();
+
+        self.sums.resize(total_num_groups, None);
+
+        for (idx, &group_index) in group_indices.iter().enumerate() {
+            // A null partial sum means that side saw only nulls: nothing to add.
+            if that_sums.is_null(idx) {
+                continue;
+            }
+            let that_sum = that_sums.value(idx);
+
+            let merged = match self.sums[group_index] {
+                None => that_sum,
+                Some(sum) => sum
+                    .add_checked(that_sum)
+                    .map_err(|_| DataFusionError::from(arithmetic_overflow_error("integer")))?,
+            };
+            self.sums[group_index] = Some(merged);
         }
         Ok(())
     }
+
+    /// Approximate memory used by this accumulator, in bytes: the struct
+    /// itself plus the heap buffer backing `sums` (sized by capacity).
+    fn size(&self) -> usize {
+        std::mem::size_of_val(self) + self.sums.capacity() * std::mem::size_of::<Option<i64>>()
+    }
 }
 
-struct SumIntGroupsAccumulator {
+/// Grouped integer `SUM` accumulator for TRY eval mode.
+///
+/// Per group, `sums` holds the running total and `has_all_nulls` records
+/// whether the group is still all-null. Groups start as `Some(0)` / `true`.
+/// A group that is no longer all-null but whose sum is `None` has overflowed
+/// and evaluates to null instead of raising an error.
+struct SumIntGroupsAccumulatorTry {
     sums: Vec<Option<i64>>,
     has_all_nulls: Vec<bool>,
-    eval_mode: EvalMode,
 }
 
-impl SumIntGroupsAccumulator {
-    fn new(eval_mode: EvalMode) -> Self {
+impl SumIntGroupsAccumulatorTry {
+    /// Creates an accumulator that is not yet tracking any groups.
+    fn new() -> Self {
         Self {
             sums: Vec::new(),
-            eval_mode,
             has_all_nulls: Vec::new(),
         }
     }
 
-    fn resize_helper(&mut self, total_num_groups: usize) {
-        if self.eval_mode == EvalMode::Try {
-            self.sums.resize(total_num_groups, Some(0));
-            self.has_all_nulls.resize(total_num_groups, true);
-        } else {
-            self.sums.resize(total_num_groups, None);
-            self.has_all_nulls.resize(total_num_groups, false);
-        }
+    /// Returns `true` if `group_index` has overflowed: it is no longer
+    /// all-null (`has_all_nulls` is false), yet its sum is `None`.
+    fn group_overflowed(&self, group_index: usize) -> bool {
+        !self.has_all_nulls[group_index] && self.sums[group_index].is_none()
     }
 }
 
-impl GroupsAccumulator for SumIntGroupsAccumulator {
+impl GroupsAccumulator for SumIntGroupsAccumulatorTry {
     fn update_batch(
         &mut self,
         values: &[ArrayRef],
@@ -350,12 +732,11 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
         opt_filter: Option<&BooleanArray>,
         total_num_groups: usize,
     ) -> DFResult<()> {
-        fn update_groups_sum_internal<T>(
+        fn update_groups_sum<T>(
             int_array: &PrimitiveArray<T>,
             group_indices: &[usize],
             sums: &mut [Option<i64>],
             has_all_nulls: &mut [bool],
-            eval_mode: EvalMode,
         ) -> DFResult<()>
         where
             T: ArrowPrimitiveType,
@@ -363,39 +744,18 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
         {
             for (i, &group_index) in group_indices.iter().enumerate() {
                 if !int_array.is_null(i) {
-                    // there is an overflow in prev group in try eval. Skip 
processing
-                    if eval_mode == EvalMode::Try
-                        && !has_all_nulls[group_index]
-                        && sums[group_index].is_none()
-                    {
+                    // Skip if this group already overflowed
+                    if !has_all_nulls[group_index] && 
sums[group_index].is_none() {
                         continue;
                     }
                     let v = int_array.value(i).to_i64().ok_or_else(|| {
                         DataFusionError::Internal("Failed to convert value to 
i64".to_string())
                     })?;
-                    match eval_mode {
-                        EvalMode::Legacy => {
-                            sums[group_index] =
-                                
Some(sums[group_index].unwrap_or(0).add_wrapping(v));
-                        }
-                        EvalMode::Ansi | EvalMode::Try => {
-                            match 
sums[group_index].unwrap_or(0).add_checked(v) {
-                                Ok(new_sum) => {
-                                    sums[group_index] = Some(new_sum);
-                                }
-                                Err(_) => {
-                                    if eval_mode == EvalMode::Ansi {
-                                        return Err(DataFusionError::from(
-                                            
arithmetic_overflow_error("integer"),
-                                        ));
-                                    } else {
-                                        sums[group_index] = None;
-                                    }
-                                }
-                            };
-                        }
-                    }
-                    has_all_nulls[group_index] = false
+                    match sums[group_index].unwrap_or(0).add_checked(v) {
+                        Ok(new_sum) => sums[group_index] = Some(new_sum),
+                        Err(_) => sums[group_index] = None,
+                    };
+                    has_all_nulls[group_index] = false;
                 }
             }
             Ok(())
@@ -403,40 +763,37 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
 
         debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
         let values = &values[0];
-        self.resize_helper(total_num_groups);
+        self.sums.resize(total_num_groups, Some(0));
+        self.has_all_nulls.resize(total_num_groups, true);
 
         match values.data_type() {
-            DataType::Int64 => update_groups_sum_internal(
+            DataType::Int64 => update_groups_sum(
                 as_primitive_array::<Int64Type>(values),
                 group_indices,
                 &mut self.sums,
                 &mut self.has_all_nulls,
-                self.eval_mode,
             )?,
-            DataType::Int32 => update_groups_sum_internal(
+            DataType::Int32 => update_groups_sum(
                 as_primitive_array::<Int32Type>(values),
                 group_indices,
                 &mut self.sums,
                 &mut self.has_all_nulls,
-                self.eval_mode,
             )?,
-            DataType::Int16 => update_groups_sum_internal(
+            DataType::Int16 => update_groups_sum(
                 as_primitive_array::<Int16Type>(values),
                 group_indices,
                 &mut self.sums,
                 &mut self.has_all_nulls,
-                self.eval_mode,
             )?,
-            DataType::Int8 => update_groups_sum_internal(
+            DataType::Int8 => update_groups_sum(
                 as_primitive_array::<Int8Type>(values),
                 group_indices,
                 &mut self.sums,
                 &mut self.has_all_nulls,
-                self.eval_mode,
             )?,
             _ => {
                 return Err(DataFusionError::Internal(format!(
-                    "Unsupported data type for SumIntGroupsAccumulator: {:?}",
+                    "Unsupported data type for SumIntGroupsAccumulatorTry: 
{:?}",
                     values.data_type()
                 )))
             }
@@ -453,7 +810,6 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
                         .zip(self.has_all_nulls.iter())
                         .map(|(&sum, &is_null)| if is_null { None } else { sum 
}),
                 )) as ArrayRef;
-
                 self.sums.clear();
                 self.has_all_nulls.clear();
                 Ok(result)
@@ -472,16 +828,11 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
 
     fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
         let sums = emit_to.take_needed(&mut self.sums);
-
-        if self.eval_mode == EvalMode::Try {
-            let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls);
-            Ok(vec![
-                Arc::new(Int64Array::from(sums)),
-                Arc::new(BooleanArray::from(has_all_nulls)),
-            ])
-        } else {
-            Ok(vec![Arc::new(Int64Array::from(sums))])
-        }
+        let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls);
+        Ok(vec![
+            Arc::new(Int64Array::from(sums)),
+            Arc::new(BooleanArray::from(has_all_nulls)),
+        ])
     }
 
     fn merge_batch(
@@ -493,27 +844,17 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
     ) -> DFResult<()> {
         debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
 
-        let expected_state_len = if self.eval_mode == EvalMode::Try {
-            2
-        } else {
-            1
-        };
-        if expected_state_len != values.len() {
+        if values.len() != 2 {
             return Err(DataFusionError::Internal(format!(
-                "Invalid state while merging batch. Expected {} elements but 
found {}",
-                expected_state_len,
+                "Invalid state while merging batch. Expected 2 elements but 
found {}",
                 values.len()
             )));
         }
         let that_sums = values[0].as_primitive::<Int64Type>();
+        let that_has_all_nulls_array = values[1].as_boolean();
 
-        self.resize_helper(total_num_groups);
-
-        let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try {
-            Some(values[1].as_boolean())
-        } else {
-            None
-        };
+        self.sums.resize(total_num_groups, Some(0));
+        self.has_all_nulls.resize(total_num_groups, true);
 
         for (idx, &group_index) in group_indices.iter().enumerate() {
             let that_sum = if that_sums.is_null(idx) {
@@ -521,62 +862,34 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
             } else {
                 Some(that_sums.value(idx))
             };
+            let that_has_all_nulls = that_has_all_nulls_array.value(idx);
 
-            if self.eval_mode == EvalMode::Try {
-                let that_has_all_nulls = 
that_sums_is_all_nulls.unwrap().value(idx);
-
-                let that_overflowed = !that_has_all_nulls && 
that_sum.is_none();
-                let this_overflowed =
-                    !self.has_all_nulls[group_index] && 
self.sums[group_index].is_none();
-
-                if that_overflowed || this_overflowed {
-                    self.sums[group_index] = None;
-                    self.has_all_nulls[group_index] = false;
-                    continue;
-                }
-
-                if that_has_all_nulls {
-                    continue;
-                }
+            let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+            if that_overflowed || self.group_overflowed(group_index) {
+                self.sums[group_index] = None;
+                self.has_all_nulls[group_index] = false;
+                continue;
+            }
 
-                if self.has_all_nulls[group_index] {
-                    self.sums[group_index] = that_sum;
-                    self.has_all_nulls[group_index] = false;
-                    continue;
-                }
-            } else {
-                if that_sum.is_none() {
-                    continue;
-                }
-                if self.sums[group_index].is_none() {
-                    self.sums[group_index] = that_sum;
-                    continue;
-                }
+            if that_has_all_nulls {
+                continue;
             }
 
-            // Both sides have non-null. Update sums now
-            let left = self.sums[group_index].unwrap();
-            let right = that_sum.unwrap();
+            if self.has_all_nulls[group_index] {
+                self.sums[group_index] = that_sum;
+                self.has_all_nulls[group_index] = false;
+                continue;
+            }
 
-            match self.eval_mode {
-                EvalMode::Legacy => {
-                    self.sums[group_index] = Some(left.add_wrapping(right));
-                }
-                EvalMode::Ansi | EvalMode::Try => {
-                    match left.add_checked(right) {
-                        Ok(v) => self.sums[group_index] = Some(v),
-                        Err(_) => {
-                            if self.eval_mode == EvalMode::Ansi {
-                                return 
Err(DataFusionError::from(arithmetic_overflow_error(
-                                    "integer",
-                                )));
-                            } else {
-                                // overflow. update flag accordingly
-                                self.sums[group_index] = None;
-                                self.has_all_nulls[group_index] = false;
-                            }
-                        }
-                    }
+            // Both sides have non-null values
+            match self.sums[group_index]
+                .unwrap()
+                .add_checked(that_sum.unwrap())
+            {
+                Ok(v) => self.sums[group_index] = Some(v),
+                Err(_) => {
+                    self.sums[group_index] = None;
+                    self.has_all_nulls[group_index] = false;
                 }
             }
         }


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to