This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new fd4775119 perf: refactor sum int with specialized implementations for
each eval_mode (#3054)
fd4775119 is described below
commit fd477511957b291d505c8efec6f8360a43922992
Author: Andy Grove <[email protected]>
AuthorDate: Wed Feb 4 05:44:37 2026 -0700
perf: refactor sum int with specialized implementations for each eval_mode
(#3054)
---
native/spark-expr/benches/aggregate.rs | 182 +++++-
native/spark-expr/src/agg_funcs/sum_int.rs | 851 ++++++++++++++++++++---------
2 files changed, 759 insertions(+), 274 deletions(-)
diff --git a/native/spark-expr/benches/aggregate.rs
b/native/spark-expr/benches/aggregate.rs
index 72628975b..47e2cf61c 100644
--- a/native/spark-expr/benches/aggregate.rs
+++ b/native/spark-expr/benches/aggregate.rs
@@ -15,8 +15,8 @@
// specific language governing permissions and limitations
// under the License.use arrow::array::{ArrayRef, BooleanBuilder,
Int32Builder, RecordBatch, StringBuilder};
-use arrow::array::builder::{Decimal128Builder, StringBuilder};
-use arrow::array::{ArrayRef, RecordBatch};
+use arrow::array::builder::{Decimal128Builder, Int64Builder, StringBuilder};
+use arrow::array::{ArrayRef, Int64Array, RecordBatch};
use arrow::datatypes::SchemaRef;
use arrow::datatypes::{DataType, Field, Schema};
use criterion::{criterion_group, criterion_main, Criterion};
@@ -25,14 +25,14 @@ use datafusion::datasource::source::DataSourceExec;
use datafusion::execution::TaskContext;
use datafusion::functions_aggregate::average::avg_udaf;
use datafusion::functions_aggregate::sum::sum_udaf;
-use datafusion::logical_expr::AggregateUDF;
+use datafusion::logical_expr::function::AccumulatorArgs;
+use datafusion::logical_expr::{AggregateUDF, AggregateUDFImpl, EmitTo};
use datafusion::physical_expr::aggregate::AggregateExprBuilder;
use datafusion::physical_expr::expressions::Column;
use datafusion::physical_expr::PhysicalExpr;
use datafusion::physical_plan::aggregates::{AggregateExec, AggregateMode,
PhysicalGroupBy};
use datafusion::physical_plan::ExecutionPlan;
-use datafusion_comet_spark_expr::SumDecimal;
-use datafusion_comet_spark_expr::{AvgDecimal, EvalMode};
+use datafusion_comet_spark_expr::{AvgDecimal, EvalMode, SumDecimal,
SumInteger};
use futures::StreamExt;
use std::hint::black_box;
use std::sync::Arc;
@@ -111,6 +111,153 @@ fn criterion_benchmark(c: &mut Criterion) {
});
group.finish();
+
+ // SumInteger benchmarks
+ let mut group = c.benchmark_group("sum_integer");
+ let int_batch = create_int64_record_batch(num_rows);
+ let mut int_batches = Vec::new();
+ for _ in 0..10 {
+ int_batches.push(int_batch.clone());
+ }
+ let int_partitions = &[int_batches];
+ let int_c0: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c0", 0));
+ let int_c1: Arc<dyn PhysicalExpr> = Arc::new(Column::new("c1", 1));
+
+ group.bench_function("sum_int64_datafusion", |b| {
+ let datafusion_sum = sum_udaf();
+ b.to_async(&rt).iter(|| {
+ black_box(agg_test(
+ int_partitions,
+ int_c0.clone(),
+ int_c1.clone(),
+ datafusion_sum.clone(),
+ "sum",
+ ))
+ })
+ });
+
+ group.bench_function("sum_int64_comet_legacy", |b| {
+ let comet_sum = Arc::new(AggregateUDF::new_from_impl(
+ SumInteger::try_new(DataType::Int64, EvalMode::Legacy).unwrap(),
+ ));
+ b.to_async(&rt).iter(|| {
+ black_box(agg_test(
+ int_partitions,
+ int_c0.clone(),
+ int_c1.clone(),
+ comet_sum.clone(),
+ "sum",
+ ))
+ })
+ });
+
+ group.bench_function("sum_int64_comet_ansi", |b| {
+ let comet_sum = Arc::new(AggregateUDF::new_from_impl(
+ SumInteger::try_new(DataType::Int64, EvalMode::Ansi).unwrap(),
+ ));
+ b.to_async(&rt).iter(|| {
+ black_box(agg_test(
+ int_partitions,
+ int_c0.clone(),
+ int_c1.clone(),
+ comet_sum.clone(),
+ "sum",
+ ))
+ })
+ });
+
+ group.bench_function("sum_int64_comet_try", |b| {
+ let comet_sum = Arc::new(AggregateUDF::new_from_impl(
+ SumInteger::try_new(DataType::Int64, EvalMode::Try).unwrap(),
+ ));
+ b.to_async(&rt).iter(|| {
+ black_box(agg_test(
+ int_partitions,
+ int_c0.clone(),
+ int_c1.clone(),
+ comet_sum.clone(),
+ "sum",
+ ))
+ })
+ });
+
+ group.finish();
+
+ // Direct accumulator benchmarks (bypassing execution framework)
+ let mut group = c.benchmark_group("sum_integer_accumulator");
+ let int64_array: ArrayRef =
Arc::new(Int64Array::from_iter_values(0..8192i64));
+ let arrays: Vec<ArrayRef> = vec![int64_array];
+
+ let return_field = Arc::new(Field::new("sum", DataType::Int64, true));
+ let schema = Schema::new(vec![Field::new("c0", DataType::Int64, true)]);
+ let expr_field = Arc::new(Field::new("c0", DataType::Int64, true));
+ let expr_fields: Vec<Arc<Field>> = vec![expr_field];
+
+ // Single-row Accumulator benchmarks
+ for (name, eval_mode) in [
+ ("row_legacy", EvalMode::Legacy),
+ ("row_ansi", EvalMode::Ansi),
+ ("row_try", EvalMode::Try),
+ ] {
+ let return_field = return_field.clone();
+ let expr_fields = expr_fields.clone();
+ group.bench_function(name, |b| {
+ let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap();
+ b.iter(|| {
+ let acc_args = AccumulatorArgs {
+ return_field: return_field.clone(),
+ schema: &schema,
+ ignore_nulls: false,
+ order_bys: &[],
+ name: "sum",
+ is_distinct: false,
+ is_reversed: false,
+ exprs: &[],
+ expr_fields: &expr_fields,
+ };
+ let mut acc = udf.accumulator(acc_args).unwrap();
+ for _ in 0..10 {
+ acc.update_batch(&arrays).unwrap();
+ }
+ black_box(acc.evaluate().unwrap())
+ })
+ });
+ }
+
+ // GroupsAccumulator benchmarks
+ let group_indices: Vec<usize> = (0..8192).map(|i| i % 1024).collect();
+ for (name, eval_mode) in [
+ ("groups_legacy", EvalMode::Legacy),
+ ("groups_ansi", EvalMode::Ansi),
+ ("groups_try", EvalMode::Try),
+ ] {
+ let return_field = return_field.clone();
+ let expr_fields = expr_fields.clone();
+ group.bench_function(name, |b| {
+ let udf = SumInteger::try_new(DataType::Int64, eval_mode).unwrap();
+ b.iter(|| {
+ let acc_args = AccumulatorArgs {
+ return_field: return_field.clone(),
+ schema: &schema,
+ ignore_nulls: false,
+ order_bys: &[],
+ name: "sum",
+ is_distinct: false,
+ is_reversed: false,
+ exprs: &[],
+ expr_fields: &expr_fields,
+ };
+ let mut acc = udf.create_groups_accumulator(acc_args).unwrap();
+ for _ in 0..10 {
+ acc.update_batch(&arrays, &group_indices, None, 1024)
+ .unwrap();
+ }
+ black_box(acc.evaluate(EmitTo::All).unwrap())
+ })
+ });
+ }
+
+ group.finish();
}
async fn agg_test(
@@ -187,6 +334,31 @@ fn create_record_batch(num_rows: usize) -> RecordBatch {
RecordBatch::try_new(Arc::new(schema), columns).unwrap()
}
+fn create_int64_record_batch(num_rows: usize) -> RecordBatch {
+ let mut int64_builder = Int64Builder::with_capacity(num_rows);
+ let mut string_builder = StringBuilder::with_capacity(num_rows, num_rows *
32);
+ for i in 0..num_rows {
+ int64_builder.append_value(i as i64);
+ string_builder.append_value(format!("group_{}", i % 1024));
+ }
+ let int64_array = Arc::new(int64_builder.finish());
+ let string_array = Arc::new(string_builder.finish());
+
+ let mut fields = vec![];
+ let mut columns: Vec<ArrayRef> = vec![];
+
+ // string column for grouping
+ fields.push(Field::new("c0", DataType::Utf8, false));
+ columns.push(string_array);
+
+ // int64 column for summing
+ fields.push(Field::new("c1", DataType::Int64, false));
+ columns.push(int64_array);
+
+ let schema = Schema::new(fields);
+ RecordBatch::try_new(Arc::new(schema), columns).unwrap()
+}
+
fn config() -> Criterion {
Criterion::default()
.measurement_time(Duration::from_millis(500))
diff --git a/native/spark-expr/src/agg_funcs/sum_int.rs
b/native/spark-expr/src/agg_funcs/sum_int.rs
index d226c5ede..2ea07c743 100644
--- a/native/spark-expr/src/agg_funcs/sum_int.rs
+++ b/native/spark-expr/src/agg_funcs/sum_int.rs
@@ -69,7 +69,11 @@ impl AggregateUDFImpl for SumInteger {
}
fn accumulator(&self, _acc_args: AccumulatorArgs) -> DFResult<Box<dyn
Accumulator>> {
- Ok(Box::new(SumIntegerAccumulator::new(self.eval_mode)))
+ match self.eval_mode {
+ EvalMode::Legacy =>
Ok(Box::new(SumIntegerAccumulatorLegacy::new())),
+ EvalMode::Ansi => Ok(Box::new(SumIntegerAccumulatorAnsi::new())),
+ EvalMode::Try => Ok(Box::new(SumIntegerAccumulatorTry::new())),
+ }
}
fn state_fields(&self, _args: StateFieldsArgs) -> DFResult<Vec<FieldRef>> {
@@ -91,7 +95,11 @@ impl AggregateUDFImpl for SumInteger {
&self,
_args: AccumulatorArgs,
) -> DFResult<Box<dyn GroupsAccumulator>> {
- Ok(Box::new(SumIntGroupsAccumulator::new(self.eval_mode)))
+ match self.eval_mode {
+ EvalMode::Legacy =>
Ok(Box::new(SumIntGroupsAccumulatorLegacy::new())),
+ EvalMode::Ansi => Ok(Box::new(SumIntGroupsAccumulatorAnsi::new())),
+ EvalMode::Try => Ok(Box::new(SumIntGroupsAccumulatorTry::new())),
+ }
}
fn reverse_expr(&self) -> ReversedUDAF {
@@ -100,39 +108,222 @@ impl AggregateUDFImpl for SumInteger {
}
#[derive(Debug)]
-struct SumIntegerAccumulator {
+struct SumIntegerAccumulatorLegacy {
sum: Option<i64>,
- eval_mode: EvalMode,
- has_all_nulls: bool,
}
-impl SumIntegerAccumulator {
- fn new(eval_mode: EvalMode) -> Self {
- if eval_mode == EvalMode::Try {
- Self {
- // Try mode starts with 0 (because if this is init to None we
cant say if it is none due to all nulls or due to an overflow)
- sum: Some(0),
- has_all_nulls: true,
- eval_mode,
+impl SumIntegerAccumulatorLegacy {
+ fn new() -> Self {
+ Self { sum: None }
+ }
+}
+
+impl Accumulator for SumIntegerAccumulatorLegacy {
+ fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
+ fn update_sum<T>(int_array: &PrimitiveArray<T>, mut sum: i64) ->
DFResult<i64>
+ where
+ T: ArrowPrimitiveType,
+ {
+ for i in 0..int_array.len() {
+ if !int_array.is_null(i) {
+ let v = int_array.value(i).to_i64().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "Failed to convert value {:?} to i64",
+ int_array.value(i)
+ ))
+ })?;
+ sum = v.add_wrapping(sum);
+ }
+ }
+ Ok(sum)
+ }
+
+ let values = &values[0];
+ if values.len() == values.null_count() {
+ return Ok(());
+ }
+
+ let running_sum = self.sum.unwrap_or(0);
+ let sum = match values.data_type() {
+ DataType::Int64 =>
update_sum(as_primitive_array::<Int64Type>(values), running_sum)?,
+ DataType::Int32 =>
update_sum(as_primitive_array::<Int32Type>(values), running_sum)?,
+ DataType::Int16 =>
update_sum(as_primitive_array::<Int16Type>(values), running_sum)?,
+ DataType::Int8 =>
update_sum(as_primitive_array::<Int8Type>(values), running_sum)?,
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "unsupported data type: {:?}",
+ values.data_type()
+ )));
}
+ };
+ self.sum = Some(sum);
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> DFResult<ScalarValue> {
+ Ok(ScalarValue::Int64(self.sum))
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+
+ fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::Int64(self.sum)])
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+ if states.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "Invalid state while merging batch. Expected 1 element but
found {}",
+ states.len()
+ )));
+ }
+
+ let that_sum_array = states[0].as_primitive::<Int64Type>();
+ let that_sum = if that_sum_array.is_null(0) {
+ None
} else {
- Self {
- sum: None,
- has_all_nulls: false,
- eval_mode,
+ Some(that_sum_array.value(0))
+ };
+
+ if that_sum.is_none() {
+ return Ok(());
+ }
+ if self.sum.is_none() {
+ self.sum = that_sum;
+ return Ok(());
+ }
+
+ self.sum = Some(self.sum.unwrap().add_wrapping(that_sum.unwrap()));
+ Ok(())
+ }
+}
+
+#[derive(Debug)]
+struct SumIntegerAccumulatorAnsi {
+ sum: Option<i64>,
+}
+
+impl SumIntegerAccumulatorAnsi {
+ fn new() -> Self {
+ Self { sum: None }
+ }
+}
+
+impl Accumulator for SumIntegerAccumulatorAnsi {
+ fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
+ fn update_sum<T>(int_array: &PrimitiveArray<T>, mut sum: i64) ->
DFResult<i64>
+ where
+ T: ArrowPrimitiveType,
+ {
+ for i in 0..int_array.len() {
+ if !int_array.is_null(i) {
+ let v = int_array.value(i).to_i64().ok_or_else(|| {
+ DataFusionError::Internal(format!(
+ "Failed to convert value {:?} to i64",
+ int_array.value(i)
+ ))
+ })?;
+ sum = v
+ .add_checked(sum)
+ .map_err(|_|
DataFusionError::from(arithmetic_overflow_error("integer")))?;
+ }
+ }
+ Ok(sum)
+ }
+
+ let values = &values[0];
+ if values.len() == values.null_count() {
+ return Ok(());
+ }
+
+ let running_sum = self.sum.unwrap_or(0);
+ let sum = match values.data_type() {
+ DataType::Int64 =>
update_sum(as_primitive_array::<Int64Type>(values), running_sum)?,
+ DataType::Int32 =>
update_sum(as_primitive_array::<Int32Type>(values), running_sum)?,
+ DataType::Int16 =>
update_sum(as_primitive_array::<Int16Type>(values), running_sum)?,
+ DataType::Int8 =>
update_sum(as_primitive_array::<Int8Type>(values), running_sum)?,
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "unsupported data type: {:?}",
+ values.data_type()
+ )));
}
+ };
+ self.sum = Some(sum);
+ Ok(())
+ }
+
+ fn evaluate(&mut self) -> DFResult<ScalarValue> {
+ Ok(ScalarValue::Int64(self.sum))
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+
+ fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
+ Ok(vec![ScalarValue::Int64(self.sum)])
+ }
+
+ fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
+ if states.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "Invalid state while merging batch. Expected 1 element but
found {}",
+ states.len()
+ )));
+ }
+
+ let that_sum_array = states[0].as_primitive::<Int64Type>();
+ let that_sum = if that_sum_array.is_null(0) {
+ None
+ } else {
+ Some(that_sum_array.value(0))
+ };
+
+ if that_sum.is_none() {
+ return Ok(());
+ }
+ if self.sum.is_none() {
+ self.sum = that_sum;
+ return Ok(());
+ }
+
+ self.sum = Some(
+ self.sum
+ .unwrap()
+ .add_checked(that_sum.unwrap())
+ .map_err(|_|
DataFusionError::from(arithmetic_overflow_error("integer")))?,
+ );
+ Ok(())
+ }
+}
+
+#[derive(Debug)]
+struct SumIntegerAccumulatorTry {
+ sum: Option<i64>,
+ has_all_nulls: bool,
+}
+
+impl SumIntegerAccumulatorTry {
+ fn new() -> Self {
+ Self {
+ // Try mode starts with 0 (because if this is init to None we cant
say if it is none due to all nulls or due to an overflow)
+ sum: Some(0),
+ has_all_nulls: true,
}
}
+
+ fn overflowed(&self) -> bool {
+ !self.has_all_nulls && self.sum.is_none()
+ }
}
-impl Accumulator for SumIntegerAccumulator {
+impl Accumulator for SumIntegerAccumulatorTry {
fn update_batch(&mut self, values: &[ArrayRef]) -> DFResult<()> {
- // accumulator internal to add sum and return null sum (and has_nulls
false) if there is an overflow in Try Eval mode
- fn update_sum_internal<T>(
- int_array: &PrimitiveArray<T>,
- eval_mode: EvalMode,
- mut sum: i64,
- ) -> Result<Option<i64>, DataFusionError>
+ /// Returns Ok(Some(sum)) on success, Ok(None) on overflow
+ fn update_sum<T>(int_array: &PrimitiveArray<T>, mut sum: i64) ->
DFResult<Option<i64>>
where
T: ArrowPrimitiveType,
{
@@ -144,72 +335,41 @@ impl Accumulator for SumIntegerAccumulator {
int_array.value(i)
))
})?;
- match eval_mode {
- EvalMode::Legacy => {
- sum = v.add_wrapping(sum);
- }
- EvalMode::Ansi | EvalMode::Try => {
- match v.add_checked(sum) {
- Ok(v) => sum = v,
- Err(_e) => {
- return if eval_mode == EvalMode::Ansi {
-
Err(DataFusionError::from(arithmetic_overflow_error(
- "integer",
- )))
- } else {
- Ok(None)
- };
- }
- };
- }
+ match v.add_checked(sum) {
+ Ok(new_sum) => sum = new_sum,
+ Err(_) => return Ok(None),
}
}
}
Ok(Some(sum))
}
- if self.eval_mode == EvalMode::Try && !self.has_all_nulls &&
self.sum.is_none() {
- // we saw an overflow earlier (Try eval mode). Skip processing
+ // Skip if we already saw an overflow
+ if self.overflowed() {
return Ok(());
}
+
let values = &values[0];
if values.len() == values.null_count() {
- Ok(())
- } else {
- // No nulls so there should be a non-null sum / null incase
overflow in Try eval
- let running_sum = self.sum.unwrap_or(0);
- let sum = match values.data_type() {
- DataType::Int64 => update_sum_internal(
- as_primitive_array::<Int64Type>(values),
- self.eval_mode,
- running_sum,
- )?,
- DataType::Int32 => update_sum_internal(
- as_primitive_array::<Int32Type>(values),
- self.eval_mode,
- running_sum,
- )?,
- DataType::Int16 => update_sum_internal(
- as_primitive_array::<Int16Type>(values),
- self.eval_mode,
- running_sum,
- )?,
- DataType::Int8 => update_sum_internal(
- as_primitive_array::<Int8Type>(values),
- self.eval_mode,
- running_sum,
- )?,
- _ => {
- return Err(DataFusionError::Internal(format!(
- "unsupported data type: {:?}",
- values.data_type()
- )));
- }
- };
- self.sum = sum;
- self.has_all_nulls = false;
- Ok(())
+ return Ok(());
}
+
+ let running_sum = self.sum.unwrap_or(0);
+ let sum = match values.data_type() {
+ DataType::Int64 =>
update_sum(as_primitive_array::<Int64Type>(values), running_sum)?,
+ DataType::Int32 =>
update_sum(as_primitive_array::<Int32Type>(values), running_sum)?,
+ DataType::Int16 =>
update_sum(as_primitive_array::<Int16Type>(values), running_sum)?,
+ DataType::Int8 =>
update_sum(as_primitive_array::<Int8Type>(values), running_sum)?,
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "unsupported data type: {:?}",
+ values.data_type()
+ )));
+ }
+ };
+ self.sum = sum;
+ self.has_all_nulls = false;
+ Ok(())
}
fn evaluate(&mut self) -> DFResult<ScalarValue> {
@@ -225,26 +385,16 @@ impl Accumulator for SumIntegerAccumulator {
}
fn state(&mut self) -> DFResult<Vec<ScalarValue>> {
- if self.eval_mode == EvalMode::Try {
- Ok(vec![
- ScalarValue::Int64(self.sum),
- ScalarValue::Boolean(Some(self.has_all_nulls)),
- ])
- } else {
- Ok(vec![ScalarValue::Int64(self.sum)])
- }
+ Ok(vec![
+ ScalarValue::Int64(self.sum),
+ ScalarValue::Boolean(Some(self.has_all_nulls)),
+ ])
}
fn merge_batch(&mut self, states: &[ArrayRef]) -> DFResult<()> {
- let expected_state_len = if self.eval_mode == EvalMode::Try {
- 2
- } else {
- 1
- };
- if expected_state_len != states.len() {
+ if states.len() != 2 {
return Err(DataFusionError::Internal(format!(
- "Invalid state while merging batch. Expected {} elements but
found {}",
- expected_state_len,
+ "Invalid state while merging batch. Expected 2 elements but
found {}",
states.len()
)));
}
@@ -255,94 +405,326 @@ impl Accumulator for SumIntegerAccumulator {
} else {
Some(that_sum_array.value(0))
};
+ let that_has_all_nulls = states[1].as_boolean().value(0);
- // Check for overflow for early termination
- if self.eval_mode == EvalMode::Try {
- let that_has_all_nulls = states[1].as_boolean().value(0);
- let that_overflowed = !that_has_all_nulls && that_sum.is_none();
- let this_overflowed = !self.has_all_nulls && self.sum.is_none();
- if that_overflowed || this_overflowed {
+ let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+ if that_overflowed || self.overflowed() {
+ self.sum = None;
+ self.has_all_nulls = false;
+ return Ok(());
+ }
+
+ if that_has_all_nulls {
+ return Ok(());
+ }
+ if self.has_all_nulls {
+ self.sum = that_sum;
+ self.has_all_nulls = false;
+ return Ok(());
+ }
+
+ // Both sides have non-null values
+ match self.sum.unwrap().add_checked(that_sum.unwrap()) {
+ Ok(v) => self.sum = Some(v),
+ Err(_) => {
self.sum = None;
self.has_all_nulls = false;
- return Ok(());
}
- if that_has_all_nulls {
- return Ok(());
+ }
+ Ok(())
+ }
+}
+
+struct SumIntGroupsAccumulatorLegacy {
+ sums: Vec<Option<i64>>,
+}
+
+impl SumIntGroupsAccumulatorLegacy {
+ fn new() -> Self {
+ Self { sums: Vec::new() }
+ }
+}
+
+impl GroupsAccumulator for SumIntGroupsAccumulatorLegacy {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> DFResult<()> {
+ fn update_groups_sum<T>(
+ int_array: &PrimitiveArray<T>,
+ group_indices: &[usize],
+ sums: &mut [Option<i64>],
+ ) -> DFResult<()>
+ where
+ T: ArrowPrimitiveType,
+ T::Native: ArrowNativeType,
+ {
+ for (i, &group_index) in group_indices.iter().enumerate() {
+ if !int_array.is_null(i) {
+ let v = int_array.value(i).to_i64().ok_or_else(|| {
+ DataFusionError::Internal("Failed to convert value to
i64".to_string())
+ })?;
+ sums[group_index] =
Some(sums[group_index].unwrap_or(0).add_wrapping(v));
+ }
}
- if self.has_all_nulls {
- self.sum = that_sum;
- self.has_all_nulls = false;
- return Ok(());
+ Ok(())
+ }
+
+ debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+ let values = &values[0];
+ self.sums.resize(total_num_groups, None);
+
+ match values.data_type() {
+ DataType::Int64 => update_groups_sum(
+ as_primitive_array::<Int64Type>(values),
+ group_indices,
+ &mut self.sums,
+ )?,
+ DataType::Int32 => update_groups_sum(
+ as_primitive_array::<Int32Type>(values),
+ group_indices,
+ &mut self.sums,
+ )?,
+ DataType::Int16 => update_groups_sum(
+ as_primitive_array::<Int16Type>(values),
+ group_indices,
+ &mut self.sums,
+ )?,
+ DataType::Int8 => update_groups_sum(
+ as_primitive_array::<Int8Type>(values),
+ group_indices,
+ &mut self.sums,
+ )?,
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "Unsupported data type for SumIntGroupsAccumulatorLegacy:
{:?}",
+ values.data_type()
+ )))
}
- } else {
- if that_sum.is_none() {
- return Ok(());
+ };
+ Ok(())
+ }
+
+ fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
+ match emit_to {
+ EmitTo::All => {
+ let result = Arc::new(Int64Array::from(std::mem::take(&mut
self.sums))) as ArrayRef;
+ Ok(result)
}
- if self.sum.is_none() {
- self.sum = that_sum;
- return Ok(());
+ EmitTo::First(n) => {
+ let result =
Arc::new(Int64Array::from(self.sums.drain(..n).collect::<Vec<_>>()))
+ as ArrayRef;
+ Ok(result)
}
}
+ }
- // safe to unwrap (since we checked nulls above) but handling error
just in case state is corrupt
- let left = self.sum.ok_or_else(|| {
- DataFusionError::Internal(
- "Invalid state in merging batch. Current batch's sum is
None".to_string(),
- )
- })?;
- let right = that_sum.ok_or_else(|| {
- DataFusionError::Internal(
- "Invalid state in merging batch. Incoming sum is
None".to_string(),
- )
- })?;
+ fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
+ let sums = emit_to.take_needed(&mut self.sums);
+ Ok(vec![Arc::new(Int64Array::from(sums))])
+ }
- match self.eval_mode {
- EvalMode::Legacy => {
- self.sum = Some(left.add_wrapping(right));
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> DFResult<()> {
+ debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+ if values.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "Invalid state while merging batch. Expected 1 element but
found {}",
+ values.len()
+ )));
+ }
+ let that_sums = values[0].as_primitive::<Int64Type>();
+
+ self.sums.resize(total_num_groups, None);
+
+ for (idx, &group_index) in group_indices.iter().enumerate() {
+ if that_sums.is_null(idx) {
+ continue;
}
- EvalMode::Ansi | EvalMode::Try => match left.add_checked(right) {
- Ok(v) => self.sum = Some(v),
- Err(_) => {
- if self.eval_mode == EvalMode::Ansi {
- return
Err(DataFusionError::from(arithmetic_overflow_error("integer")));
- } else {
- self.sum = None;
- self.has_all_nulls = false;
- }
+ let that_sum = that_sums.value(idx);
+
+ if self.sums[group_index].is_none() {
+ self.sums[group_index] = Some(that_sum);
+ } else {
+ self.sums[group_index] =
+
Some(self.sums[group_index].unwrap().add_wrapping(that_sum));
+ }
+ }
+ Ok(())
+ }
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
+}
+
+struct SumIntGroupsAccumulatorAnsi {
+ sums: Vec<Option<i64>>,
+}
+
+impl SumIntGroupsAccumulatorAnsi {
+ fn new() -> Self {
+ Self { sums: Vec::new() }
+ }
+}
+
+impl GroupsAccumulator for SumIntGroupsAccumulatorAnsi {
+ fn update_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> DFResult<()> {
+ fn update_groups_sum<T>(
+ int_array: &PrimitiveArray<T>,
+ group_indices: &[usize],
+ sums: &mut [Option<i64>],
+ ) -> DFResult<()>
+ where
+ T: ArrowPrimitiveType,
+ T::Native: ArrowNativeType,
+ {
+ for (i, &group_index) in group_indices.iter().enumerate() {
+ if !int_array.is_null(i) {
+ let v = int_array.value(i).to_i64().ok_or_else(|| {
+ DataFusionError::Internal("Failed to convert value to
i64".to_string())
+ })?;
+ sums[group_index] =
+
Some(sums[group_index].unwrap_or(0).add_checked(v).map_err(|_| {
+
DataFusionError::from(arithmetic_overflow_error("integer"))
+ })?);
}
- },
+ }
+ Ok(())
+ }
+
+ debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+ let values = &values[0];
+ self.sums.resize(total_num_groups, None);
+
+ match values.data_type() {
+ DataType::Int64 => update_groups_sum(
+ as_primitive_array::<Int64Type>(values),
+ group_indices,
+ &mut self.sums,
+ )?,
+ DataType::Int32 => update_groups_sum(
+ as_primitive_array::<Int32Type>(values),
+ group_indices,
+ &mut self.sums,
+ )?,
+ DataType::Int16 => update_groups_sum(
+ as_primitive_array::<Int16Type>(values),
+ group_indices,
+ &mut self.sums,
+ )?,
+ DataType::Int8 => update_groups_sum(
+ as_primitive_array::<Int8Type>(values),
+ group_indices,
+ &mut self.sums,
+ )?,
+ _ => {
+ return Err(DataFusionError::Internal(format!(
+ "Unsupported data type for SumIntGroupsAccumulatorAnsi:
{:?}",
+ values.data_type()
+ )))
+ }
+ };
+ Ok(())
+ }
+
+ fn evaluate(&mut self, emit_to: EmitTo) -> DFResult<ArrayRef> {
+ match emit_to {
+ EmitTo::All => {
+ let result = Arc::new(Int64Array::from(std::mem::take(&mut
self.sums))) as ArrayRef;
+ Ok(result)
+ }
+ EmitTo::First(n) => {
+ let result =
Arc::new(Int64Array::from(self.sums.drain(..n).collect::<Vec<_>>()))
+ as ArrayRef;
+ Ok(result)
+ }
+ }
+ }
+
+ fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
+ let sums = emit_to.take_needed(&mut self.sums);
+ Ok(vec![Arc::new(Int64Array::from(sums))])
+ }
+
+ fn merge_batch(
+ &mut self,
+ values: &[ArrayRef],
+ group_indices: &[usize],
+ opt_filter: Option<&BooleanArray>,
+ total_num_groups: usize,
+ ) -> DFResult<()> {
+ debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
+
+ if values.len() != 1 {
+ return Err(DataFusionError::Internal(format!(
+ "Invalid state while merging batch. Expected 1 element but
found {}",
+ values.len()
+ )));
+ }
+ let that_sums = values[0].as_primitive::<Int64Type>();
+
+ self.sums.resize(total_num_groups, None);
+
+ for (idx, &group_index) in group_indices.iter().enumerate() {
+ if that_sums.is_null(idx) {
+ continue;
+ }
+ let that_sum = that_sums.value(idx);
+
+ if self.sums[group_index].is_none() {
+ self.sums[group_index] = Some(that_sum);
+ } else {
+ self.sums[group_index] = Some(
+ self.sums[group_index]
+ .unwrap()
+ .add_checked(that_sum)
+ .map_err(|_|
DataFusionError::from(arithmetic_overflow_error("integer")))?,
+ );
+ }
}
Ok(())
}
+
+ fn size(&self) -> usize {
+ std::mem::size_of_val(self)
+ }
}
-struct SumIntGroupsAccumulator {
+struct SumIntGroupsAccumulatorTry {
sums: Vec<Option<i64>>,
has_all_nulls: Vec<bool>,
- eval_mode: EvalMode,
}
-impl SumIntGroupsAccumulator {
- fn new(eval_mode: EvalMode) -> Self {
+impl SumIntGroupsAccumulatorTry {
+ fn new() -> Self {
Self {
sums: Vec::new(),
- eval_mode,
has_all_nulls: Vec::new(),
}
}
- fn resize_helper(&mut self, total_num_groups: usize) {
- if self.eval_mode == EvalMode::Try {
- self.sums.resize(total_num_groups, Some(0));
- self.has_all_nulls.resize(total_num_groups, true);
- } else {
- self.sums.resize(total_num_groups, None);
- self.has_all_nulls.resize(total_num_groups, false);
- }
+ fn group_overflowed(&self, group_index: usize) -> bool {
+ !self.has_all_nulls[group_index] && self.sums[group_index].is_none()
}
}
-impl GroupsAccumulator for SumIntGroupsAccumulator {
+impl GroupsAccumulator for SumIntGroupsAccumulatorTry {
fn update_batch(
&mut self,
values: &[ArrayRef],
@@ -350,12 +732,11 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
opt_filter: Option<&BooleanArray>,
total_num_groups: usize,
) -> DFResult<()> {
- fn update_groups_sum_internal<T>(
+ fn update_groups_sum<T>(
int_array: &PrimitiveArray<T>,
group_indices: &[usize],
sums: &mut [Option<i64>],
has_all_nulls: &mut [bool],
- eval_mode: EvalMode,
) -> DFResult<()>
where
T: ArrowPrimitiveType,
@@ -363,39 +744,18 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
{
for (i, &group_index) in group_indices.iter().enumerate() {
if !int_array.is_null(i) {
- // there is an overflow in prev group in try eval. Skip
processing
- if eval_mode == EvalMode::Try
- && !has_all_nulls[group_index]
- && sums[group_index].is_none()
- {
+ // Skip if this group already overflowed
+ if !has_all_nulls[group_index] &&
sums[group_index].is_none() {
continue;
}
let v = int_array.value(i).to_i64().ok_or_else(|| {
DataFusionError::Internal("Failed to convert value to
i64".to_string())
})?;
- match eval_mode {
- EvalMode::Legacy => {
- sums[group_index] =
-
Some(sums[group_index].unwrap_or(0).add_wrapping(v));
- }
- EvalMode::Ansi | EvalMode::Try => {
- match
sums[group_index].unwrap_or(0).add_checked(v) {
- Ok(new_sum) => {
- sums[group_index] = Some(new_sum);
- }
- Err(_) => {
- if eval_mode == EvalMode::Ansi {
- return Err(DataFusionError::from(
-
arithmetic_overflow_error("integer"),
- ));
- } else {
- sums[group_index] = None;
- }
- }
- };
- }
- }
- has_all_nulls[group_index] = false
+ match sums[group_index].unwrap_or(0).add_checked(v) {
+ Ok(new_sum) => sums[group_index] = Some(new_sum),
+ Err(_) => sums[group_index] = None,
+ };
+ has_all_nulls[group_index] = false;
}
}
Ok(())
@@ -403,40 +763,37 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
let values = &values[0];
- self.resize_helper(total_num_groups);
+ self.sums.resize(total_num_groups, Some(0));
+ self.has_all_nulls.resize(total_num_groups, true);
match values.data_type() {
- DataType::Int64 => update_groups_sum_internal(
+ DataType::Int64 => update_groups_sum(
as_primitive_array::<Int64Type>(values),
group_indices,
&mut self.sums,
&mut self.has_all_nulls,
- self.eval_mode,
)?,
- DataType::Int32 => update_groups_sum_internal(
+ DataType::Int32 => update_groups_sum(
as_primitive_array::<Int32Type>(values),
group_indices,
&mut self.sums,
&mut self.has_all_nulls,
- self.eval_mode,
)?,
- DataType::Int16 => update_groups_sum_internal(
+ DataType::Int16 => update_groups_sum(
as_primitive_array::<Int16Type>(values),
group_indices,
&mut self.sums,
&mut self.has_all_nulls,
- self.eval_mode,
)?,
- DataType::Int8 => update_groups_sum_internal(
+ DataType::Int8 => update_groups_sum(
as_primitive_array::<Int8Type>(values),
group_indices,
&mut self.sums,
&mut self.has_all_nulls,
- self.eval_mode,
)?,
_ => {
return Err(DataFusionError::Internal(format!(
- "Unsupported data type for SumIntGroupsAccumulator: {:?}",
+ "Unsupported data type for SumIntGroupsAccumulatorTry:
{:?}",
values.data_type()
)))
}
@@ -453,7 +810,6 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
.zip(self.has_all_nulls.iter())
.map(|(&sum, &is_null)| if is_null { None } else { sum
}),
)) as ArrayRef;
-
self.sums.clear();
self.has_all_nulls.clear();
Ok(result)
@@ -472,16 +828,11 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
fn state(&mut self, emit_to: EmitTo) -> DFResult<Vec<ArrayRef>> {
let sums = emit_to.take_needed(&mut self.sums);
-
- if self.eval_mode == EvalMode::Try {
- let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls);
- Ok(vec![
- Arc::new(Int64Array::from(sums)),
- Arc::new(BooleanArray::from(has_all_nulls)),
- ])
- } else {
- Ok(vec![Arc::new(Int64Array::from(sums))])
- }
+ let has_all_nulls = emit_to.take_needed(&mut self.has_all_nulls);
+ Ok(vec![
+ Arc::new(Int64Array::from(sums)),
+ Arc::new(BooleanArray::from(has_all_nulls)),
+ ])
}
fn merge_batch(
@@ -493,27 +844,17 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
) -> DFResult<()> {
debug_assert!(opt_filter.is_none(), "opt_filter is not supported yet");
- let expected_state_len = if self.eval_mode == EvalMode::Try {
- 2
- } else {
- 1
- };
- if expected_state_len != values.len() {
+ if values.len() != 2 {
return Err(DataFusionError::Internal(format!(
- "Invalid state while merging batch. Expected {} elements but
found {}",
- expected_state_len,
+ "Invalid state while merging batch. Expected 2 elements but
found {}",
values.len()
)));
}
let that_sums = values[0].as_primitive::<Int64Type>();
+ let that_has_all_nulls_array = values[1].as_boolean();
- self.resize_helper(total_num_groups);
-
- let that_sums_is_all_nulls = if self.eval_mode == EvalMode::Try {
- Some(values[1].as_boolean())
- } else {
- None
- };
+ self.sums.resize(total_num_groups, Some(0));
+ self.has_all_nulls.resize(total_num_groups, true);
for (idx, &group_index) in group_indices.iter().enumerate() {
let that_sum = if that_sums.is_null(idx) {
@@ -521,62 +862,34 @@ impl GroupsAccumulator for SumIntGroupsAccumulator {
} else {
Some(that_sums.value(idx))
};
+ let that_has_all_nulls = that_has_all_nulls_array.value(idx);
- if self.eval_mode == EvalMode::Try {
- let that_has_all_nulls =
that_sums_is_all_nulls.unwrap().value(idx);
-
- let that_overflowed = !that_has_all_nulls &&
that_sum.is_none();
- let this_overflowed =
- !self.has_all_nulls[group_index] &&
self.sums[group_index].is_none();
-
- if that_overflowed || this_overflowed {
- self.sums[group_index] = None;
- self.has_all_nulls[group_index] = false;
- continue;
- }
-
- if that_has_all_nulls {
- continue;
- }
+ let that_overflowed = !that_has_all_nulls && that_sum.is_none();
+ if that_overflowed || self.group_overflowed(group_index) {
+ self.sums[group_index] = None;
+ self.has_all_nulls[group_index] = false;
+ continue;
+ }
- if self.has_all_nulls[group_index] {
- self.sums[group_index] = that_sum;
- self.has_all_nulls[group_index] = false;
- continue;
- }
- } else {
- if that_sum.is_none() {
- continue;
- }
- if self.sums[group_index].is_none() {
- self.sums[group_index] = that_sum;
- continue;
- }
+ if that_has_all_nulls {
+ continue;
}
- // Both sides have non-null. Update sums now
- let left = self.sums[group_index].unwrap();
- let right = that_sum.unwrap();
+ if self.has_all_nulls[group_index] {
+ self.sums[group_index] = that_sum;
+ self.has_all_nulls[group_index] = false;
+ continue;
+ }
- match self.eval_mode {
- EvalMode::Legacy => {
- self.sums[group_index] = Some(left.add_wrapping(right));
- }
- EvalMode::Ansi | EvalMode::Try => {
- match left.add_checked(right) {
- Ok(v) => self.sums[group_index] = Some(v),
- Err(_) => {
- if self.eval_mode == EvalMode::Ansi {
- return
Err(DataFusionError::from(arithmetic_overflow_error(
- "integer",
- )));
- } else {
- // overflow. update flag accordingly
- self.sums[group_index] = None;
- self.has_all_nulls[group_index] = false;
- }
- }
- }
+ // Both sides have non-null values
+ match self.sums[group_index]
+ .unwrap()
+ .add_checked(that_sum.unwrap())
+ {
+ Ok(v) => self.sums[group_index] = Some(v),
+ Err(_) => {
+ self.sums[group_index] = None;
+ self.has_all_nulls[group_index] = false;
}
}
}
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]