This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 1624d63070 perf: Add support for `GroupsAccumulator` to `string_agg` 
(#21154)
1624d63070 is described below

commit 1624d63070c05dde448a14727bd20764fc7887ad
Author: Neil Conway <[email protected]>
AuthorDate: Thu Mar 26 16:51:15 2026 -0400

    perf: Add support for `GroupsAccumulator` to `string_agg` (#21154)
    
    ## Which issue does this PR close?
    
    - Closes #17789.
    
    ## Rationale for this change
    
    `string_agg` previously didn't support the `GroupsAccumulator` API;
    adding support for it can significantly improve performance,
    particularly when there are many groups.
    
    Benchmarks (M4 Max):
    
    - string_agg_query_group_by_few_groups (~10): 645 µs → 564 µs, -11%
    - string_agg_query_group_by_mid_groups (~1,000): 2,692 µs → 871 µs, -68%
    - string_agg_query_group_by_many_groups (~65,000): 16,606 µs → 1,147 µs,
    -93%
    
    ## What changes are included in this PR?
    
    * Add end-to-end benchmark for `string_agg`
    * Implement `GroupsAccumulator` API for `string_agg`
    * Add unit tests
    * Minor code cleanup for existing `string_agg` code paths
    
    ## Are these changes tested?
    
    Yes.
    
    ## Are there any user-facing changes?
    
    No, other than a change to an error message string.
    
    ---------
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/benches/aggregate_query_sql.rs   |  33 ++
 datafusion/functions-aggregate/src/string_agg.rs | 405 ++++++++++++++++++++---
 datafusion/sqllogictest/test_files/aggregate.slt |   2 +-
 3 files changed, 384 insertions(+), 56 deletions(-)

diff --git a/datafusion/core/benches/aggregate_query_sql.rs 
b/datafusion/core/benches/aggregate_query_sql.rs
index 402ac9c717..d7e24aceba 100644
--- a/datafusion/core/benches/aggregate_query_sql.rs
+++ b/datafusion/core/benches/aggregate_query_sql.rs
@@ -295,6 +295,39 @@ fn criterion_benchmark(c: &mut Criterion) {
             )
         })
     });
+
+    c.bench_function("string_agg_query_group_by_few_groups", |b| {
+        b.iter(|| {
+            query(
+                ctx.clone(),
+                &rt,
+                "SELECT u64_narrow, string_agg(utf8, ',') \
+                 FROM t GROUP BY u64_narrow",
+            )
+        })
+    });
+
+    c.bench_function("string_agg_query_group_by_mid_groups", |b| {
+        b.iter(|| {
+            query(
+                ctx.clone(),
+                &rt,
+                "SELECT u64_mid, string_agg(utf8, ',') \
+                 FROM t GROUP BY u64_mid",
+            )
+        })
+    });
+
+    c.bench_function("string_agg_query_group_by_many_groups", |b| {
+        b.iter(|| {
+            query(
+                ctx.clone(),
+                &rt,
+                "SELECT u64_wide, string_agg(utf8, ',') \
+                 FROM t GROUP BY u64_wide",
+            )
+        })
+    });
 }
 
 criterion_group!(benches, criterion_benchmark);
diff --git a/datafusion/functions-aggregate/src/string_agg.rs 
b/datafusion/functions-aggregate/src/string_agg.rs
index 6f1a37302f..ea3914b1e3 100644
--- a/datafusion/functions-aggregate/src/string_agg.rs
+++ b/datafusion/functions-aggregate/src/string_agg.rs
@@ -20,23 +20,24 @@
 use std::any::Any;
 use std::hash::Hash;
 use std::mem::size_of_val;
+use std::sync::Arc;
 
 use crate::array_agg::ArrayAgg;
 
-use arrow::array::ArrayRef;
+use arrow::array::{ArrayRef, AsArray, BooleanArray, LargeStringArray};
 use arrow::datatypes::{DataType, Field, FieldRef};
-use datafusion_common::cast::{
-    as_generic_string_array, as_string_array, as_string_view_array,
-};
+use datafusion_common::cast::{as_generic_string_array, as_string_view_array};
 use datafusion_common::{
     Result, ScalarValue, internal_datafusion_err, internal_err, not_impl_err,
 };
 use datafusion_expr::function::AccumulatorArgs;
 use datafusion_expr::utils::format_state_name;
 use datafusion_expr::{
-    Accumulator, AggregateUDFImpl, Documentation, Signature, TypeSignature, 
Volatility,
+    Accumulator, AggregateUDFImpl, Documentation, EmitTo, GroupsAccumulator, 
Signature,
+    TypeSignature, Volatility,
 };
 use datafusion_functions_aggregate_common::accumulator::StateFieldsArgs;
+use 
datafusion_functions_aggregate_common::aggregate::groups_accumulator::nulls::apply_filter_as_nulls;
 use datafusion_macros::user_doc;
 use datafusion_physical_expr::expressions::Literal;
 
@@ -117,6 +118,27 @@ impl StringAgg {
             array_agg: Default::default(),
         }
     }
+
+    /// Extract the delimiter string from the second argument expression.
+    fn extract_delimiter(args: &AccumulatorArgs) -> Result<String> {
+        let Some(lit) = args.exprs[1].as_any().downcast_ref::<Literal>() else {
+            return not_impl_err!("string_agg delimiter must be a string 
literal");
+        };
+
+        if lit.value().is_null() {
+            return Ok(String::new());
+        }
+
+        match lit.value().try_as_str() {
+            Some(s) => Ok(s.unwrap_or("").to_string()),
+            None => {
+                not_impl_err!(
+                    "string_agg not supported for delimiter \"{}\"",
+                    lit.value()
+                )
+            }
+        }
+    }
 }
 
 impl Default for StringAgg {
@@ -125,8 +147,10 @@ impl Default for StringAgg {
     }
 }
 
-/// If there is no `distinct` and `order by` required by the `string_agg` 
call, a
-/// more efficient accumulator `SimpleStringAggAccumulator` will be used.
+/// Three accumulation strategies depending on query shape:
+/// - No DISTINCT / ORDER BY with GROUP BY: `StringAggGroupsAccumulator`
+/// - No DISTINCT / ORDER BY without GROUP BY: `SimpleStringAggAccumulator`
+/// - With DISTINCT or ORDER BY: `StringAggAccumulator` (delegates to 
`ArrayAgg`)
 impl AggregateUDFImpl for StringAgg {
     fn as_any(&self) -> &dyn Any {
         self
@@ -145,11 +169,7 @@ impl AggregateUDFImpl for StringAgg {
     }
 
     fn state_fields(&self, args: StateFieldsArgs) -> Result<Vec<FieldRef>> {
-        // See comments in `impl AggregateUDFImpl ...` for more detail
-        let no_order_no_distinct =
-            (args.ordering_fields.is_empty()) && (!args.is_distinct);
-        if no_order_no_distinct {
-            // Case `SimpleStringAggAccumulator`
+        if !args.is_distinct && args.ordering_fields.is_empty() {
             Ok(vec![
                 Field::new(
                     format_state_name(args.name, "string_agg"),
@@ -159,40 +179,16 @@ impl AggregateUDFImpl for StringAgg {
                 .into(),
             ])
         } else {
-            // Case `StringAggAccumulator`
             self.array_agg.state_fields(args)
         }
     }
 
     fn accumulator(&self, acc_args: AccumulatorArgs) -> Result<Box<dyn 
Accumulator>> {
-        let Some(lit) = acc_args.exprs[1].as_any().downcast_ref::<Literal>() 
else {
-            return not_impl_err!(
-                "The second argument of the string_agg function must be a 
string literal"
-            );
-        };
-
-        let delimiter = if lit.value().is_null() {
-            // If the second argument (the delimiter that joins strings) is 
NULL, join
-            // on an empty string. (e.g. [a, b, c] => "abc").
-            ""
-        } else if let Some(lit_string) = lit.value().try_as_str() {
-            lit_string.unwrap_or("")
-        } else {
-            return not_impl_err!(
-                "StringAgg not supported for delimiter \"{}\"",
-                lit.value()
-            );
-        };
-
-        // See comments in `impl AggregateUDFImpl ...` for more detail
-        let no_order_no_distinct =
-            acc_args.order_bys.is_empty() && (!acc_args.is_distinct);
+        let delimiter = Self::extract_delimiter(&acc_args)?;
 
-        if no_order_no_distinct {
-            // simple case (more efficient)
-            Ok(Box::new(SimpleStringAggAccumulator::new(delimiter)))
+        if !acc_args.is_distinct && acc_args.order_bys.is_empty() {
+            Ok(Box::new(SimpleStringAggAccumulator::new(&delimiter)))
         } else {
-            // general case
             let array_agg_acc = self.array_agg.accumulator(AccumulatorArgs {
                 return_field: Field::new(
                     "f",
@@ -215,7 +211,7 @@ impl AggregateUDFImpl for StringAgg {
 
             Ok(Box::new(StringAggAccumulator::new(
                 array_agg_acc,
-                delimiter,
+                &delimiter,
             )))
         }
     }
@@ -224,6 +220,18 @@ impl AggregateUDFImpl for StringAgg {
         datafusion_expr::ReversedUDAF::Reversed(string_agg_udaf())
     }
 
+    fn groups_accumulator_supported(&self, args: AccumulatorArgs) -> bool {
+        !args.is_distinct && args.order_bys.is_empty()
+    }
+
+    fn create_groups_accumulator(
+        &self,
+        args: AccumulatorArgs,
+    ) -> Result<Box<dyn GroupsAccumulator>> {
+        let delimiter = Self::extract_delimiter(&args)?;
+        Ok(Box::new(StringAggGroupsAccumulator::new(delimiter)))
+    }
+
     fn documentation(&self) -> Option<&Documentation> {
         self.doc()
     }
@@ -315,10 +323,136 @@ fn filter_index<T: Clone>(values: &[T], index: usize) -> 
Vec<T> {
         .collect::<Vec<_>>()
 }
 
-/// StringAgg accumulator for the simple case (no order or distinct specified)
-/// This accumulator is more efficient than `StringAggAccumulator`
-/// because it accumulates the string directly,
-/// whereas `StringAggAccumulator` uses `ArrayAggAccumulator`.
+/// GroupsAccumulator for `string_agg` without DISTINCT or ORDER BY.
+#[derive(Debug)]
+struct StringAggGroupsAccumulator {
+    /// The delimiter placed between concatenated values.
+    delimiter: String,
+    /// Accumulated string per group. `None` means no values have been seen
+    /// (the group's output will be NULL).
+    /// A potential improvement is to avoid this String allocation
+    /// See <https://github.com/apache/datafusion/issues/21156>
+    values: Vec<Option<String>>,
+    /// Running total of string data bytes across all groups.
+    total_data_bytes: usize,
+}
+
+impl StringAggGroupsAccumulator {
+    fn new(delimiter: String) -> Self {
+        Self {
+            delimiter,
+            values: Vec::new(),
+            total_data_bytes: 0,
+        }
+    }
+
+    fn append_batch<'a>(
+        &mut self,
+        iter: impl Iterator<Item = Option<&'a str>>,
+        group_indices: &[usize],
+    ) {
+        for (opt_value, &group_idx) in iter.zip(group_indices.iter()) {
+            if let Some(value) = opt_value {
+                match &mut self.values[group_idx] {
+                    Some(existing) => {
+                        let added = self.delimiter.len() + value.len();
+                        existing.reserve(added);
+                        existing.push_str(&self.delimiter);
+                        existing.push_str(value);
+                        self.total_data_bytes += added;
+                    }
+                    slot @ None => {
+                        *slot = Some(value.to_string());
+                        self.total_data_bytes += value.len();
+                    }
+                }
+            }
+        }
+    }
+}
+
+impl GroupsAccumulator for StringAggGroupsAccumulator {
+    fn update_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        self.values.resize(total_num_groups, None);
+        let array = apply_filter_as_nulls(&values[0], opt_filter)?;
+        match array.data_type() {
+            DataType::Utf8 => {
+                self.append_batch(array.as_string::<i32>().iter(), 
group_indices)
+            }
+            DataType::LargeUtf8 => {
+                self.append_batch(array.as_string::<i64>().iter(), 
group_indices)
+            }
+            DataType::Utf8View => {
+                self.append_batch(array.as_string_view().iter(), group_indices)
+            }
+            other => {
+                return internal_err!("string_agg unexpected data type: 
{other}");
+            }
+        }
+        Ok(())
+    }
+
+    fn evaluate(&mut self, emit_to: EmitTo) -> Result<ArrayRef> {
+        let to_emit = emit_to.take_needed(&mut self.values);
+        let emitted_bytes: usize = to_emit
+            .iter()
+            .filter_map(|opt| opt.as_ref().map(|s| s.len()))
+            .sum();
+        self.total_data_bytes -= emitted_bytes;
+
+        let result: ArrayRef = Arc::new(LargeStringArray::from(to_emit));
+        Ok(result)
+    }
+
+    fn state(&mut self, emit_to: EmitTo) -> Result<Vec<ArrayRef>> {
+        self.evaluate(emit_to).map(|arr| vec![arr])
+    }
+
+    fn merge_batch(
+        &mut self,
+        values: &[ArrayRef],
+        group_indices: &[usize],
+        opt_filter: Option<&BooleanArray>,
+        total_num_groups: usize,
+    ) -> Result<()> {
+        // State is always LargeUtf8, which update_batch already handles.
+        self.update_batch(values, group_indices, opt_filter, total_num_groups)
+    }
+
+    fn convert_to_state(
+        &self,
+        values: &[ArrayRef],
+        opt_filter: Option<&BooleanArray>,
+    ) -> Result<Vec<ArrayRef>> {
+        let input = apply_filter_as_nulls(&values[0], opt_filter)?;
+        let result = if input.data_type() == &DataType::LargeUtf8 {
+            input
+        } else {
+            arrow::compute::cast(&input, &DataType::LargeUtf8)?
+        };
+        Ok(vec![result])
+    }
+
+    fn supports_convert_to_state(&self) -> bool {
+        true
+    }
+
+    fn size(&self) -> usize {
+        self.total_data_bytes
+            + self.values.capacity() * size_of::<Option<String>>()
+            + self.delimiter.capacity()
+            + size_of_val(self)
+    }
+}
+
+/// Per-row accumulator for `string_agg` without DISTINCT or ORDER BY.  Used 
for
+/// non-grouped aggregation; grouped queries use 
[`StringAggGroupsAccumulator`].
 #[derive(Debug)]
 pub(crate) struct SimpleStringAggAccumulator {
     delimiter: String,
@@ -331,7 +465,7 @@ impl SimpleStringAggAccumulator {
     pub fn new(delimiter: &str) -> Self {
         Self {
             delimiter: delimiter.to_string(),
-            accumulated_string: "".to_string(),
+            accumulated_string: String::new(),
             has_value: false,
         }
     }
@@ -361,18 +495,11 @@ impl Accumulator for SimpleStringAggAccumulator {
         })?;
 
         match string_arr.data_type() {
-            DataType::Utf8 => {
-                let array = as_string_array(string_arr)?;
-                self.append_strings(array.iter());
-            }
+            DataType::Utf8 => 
self.append_strings(string_arr.as_string::<i32>().iter()),
             DataType::LargeUtf8 => {
-                let array = as_generic_string_array::<i64>(string_arr)?;
-                self.append_strings(array.iter());
-            }
-            DataType::Utf8View => {
-                let array = as_string_view_array(string_arr)?;
-                self.append_strings(array.iter());
+                self.append_strings(string_arr.as_string::<i64>().iter())
             }
+            DataType::Utf8View => 
self.append_strings(string_arr.as_string_view().iter()),
             other => {
                 return internal_err!(
                     "Planner should ensure string_agg first argument is 
Utf8-like, found {other}"
@@ -662,4 +789,172 @@ mod tests {
         acc1.merge_batch(&intermediate_state)?;
         Ok(acc1)
     }
+
+    // ---------------------------------------------------------------
+    // Tests for StringAggGroupsAccumulator
+    // ---------------------------------------------------------------
+
+    fn make_groups_acc(delimiter: &str) -> StringAggGroupsAccumulator {
+        StringAggGroupsAccumulator::new(delimiter.to_string())
+    }
+
+    /// Helper: evaluate and downcast to LargeStringArray
+    fn evaluate_groups(
+        acc: &mut StringAggGroupsAccumulator,
+        emit_to: EmitTo,
+    ) -> Vec<Option<String>> {
+        let result = acc.evaluate(emit_to).unwrap();
+        let arr = result.as_any().downcast_ref::<LargeStringArray>().unwrap();
+        arr.iter().map(|v| v.map(|s| s.to_string())).collect()
+    }
+
+    #[test]
+    fn groups_basic() -> Result<()> {
+        let mut acc = make_groups_acc(",");
+
+        // 6 rows, 3 groups: group 0 gets "a","d"; group 1 gets "b","e"; group 
2 gets "c","f"
+        let values: ArrayRef =
+            Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e", 
"f"]));
+        let group_indices = vec![0, 1, 2, 0, 1, 2];
+        acc.update_batch(&[values], &group_indices, None, 3)?;
+
+        let result = evaluate_groups(&mut acc, EmitTo::All);
+        assert_eq!(
+            result,
+            vec![
+                Some("a,d".to_string()),
+                Some("b,e".to_string()),
+                Some("c,f".to_string()),
+            ]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn groups_with_nulls() -> Result<()> {
+        let mut acc = make_groups_acc("|");
+
+        // Group 0: "a", NULL, "c" → "a|c"
+        // Group 1: NULL, "b"     → "b"
+        // Group 2: NULL only     → NULL
+        let values: ArrayRef = Arc::new(LargeStringArray::from(vec![
+            Some("a"),
+            None,
+            Some("c"),
+            None,
+            Some("b"),
+            None,
+        ]));
+        let group_indices = vec![0, 1, 0, 2, 1, 2];
+        acc.update_batch(&[values], &group_indices, None, 3)?;
+
+        let result = evaluate_groups(&mut acc, EmitTo::All);
+        assert_eq!(
+            result,
+            vec![Some("a|c".to_string()), Some("b".to_string()), None,]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn groups_with_filter() -> Result<()> {
+        let mut acc = make_groups_acc(",");
+
+        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", "b", 
"c", "d"]));
+        let group_indices = vec![0, 0, 1, 1];
+        // Filter: only rows 0 and 3 are included
+        let filter = BooleanArray::from(vec![true, false, false, true]);
+        acc.update_batch(&[values], &group_indices, Some(&filter), 2)?;
+
+        let result = evaluate_groups(&mut acc, EmitTo::All);
+        assert_eq!(result, vec![Some("a".to_string()), Some("d".to_string())]);
+        Ok(())
+    }
+
+    #[test]
+    fn groups_emit_first() -> Result<()> {
+        let mut acc = make_groups_acc(",");
+
+        let values: ArrayRef =
+            Arc::new(LargeStringArray::from(vec!["a", "b", "c", "d", "e", 
"f"]));
+        let group_indices = vec![0, 1, 2, 0, 1, 2];
+        acc.update_batch(&[values], &group_indices, None, 3)?;
+
+        // Emit only the first 2 groups
+        let result = evaluate_groups(&mut acc, EmitTo::First(2));
+        assert_eq!(
+            result,
+            vec![Some("a,d".to_string()), Some("b,e".to_string())]
+        );
+
+        // Group 2 (now shifted to index 0) should still be intact
+        let result = evaluate_groups(&mut acc, EmitTo::All);
+        assert_eq!(result, vec![Some("c,f".to_string())]);
+        Ok(())
+    }
+
+    #[test]
+    fn groups_merge_batch() -> Result<()> {
+        let mut acc = make_groups_acc(",");
+
+        // First batch: group 0 = "a", group 1 = "b"
+        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", 
"b"]));
+        acc.update_batch(&[values], &[0, 1], None, 2)?;
+
+        // Simulate a second accumulator's state (LargeUtf8 partial strings)
+        let partial_state: ArrayRef = 
Arc::new(LargeStringArray::from(vec!["c,d", "e"]));
+        acc.merge_batch(&[partial_state], &[0, 1], None, 2)?;
+
+        let result = evaluate_groups(&mut acc, EmitTo::All);
+        assert_eq!(
+            result,
+            vec![Some("a,c,d".to_string()), Some("b,e".to_string())]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn groups_empty_groups() -> Result<()> {
+        let mut acc = make_groups_acc(",");
+
+        // 4 groups total, but only groups 0 and 2 receive values
+        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", 
"b"]));
+        acc.update_batch(&[values], &[0, 2], None, 4)?;
+
+        let result = evaluate_groups(&mut acc, EmitTo::All);
+        assert_eq!(
+            result,
+            vec![
+                Some("a".to_string()),
+                None, // group 1: never received a value
+                Some("b".to_string()),
+                None, // group 3: never received a value
+            ]
+        );
+        Ok(())
+    }
+
+    #[test]
+    fn groups_multiple_batches() -> Result<()> {
+        let mut acc = make_groups_acc("|");
+
+        // Batch 1: 2 groups
+        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["a", 
"b"]));
+        acc.update_batch(&[values], &[0, 1], None, 2)?;
+
+        // Batch 2: same groups, plus a new group
+        let values: ArrayRef = Arc::new(LargeStringArray::from(vec!["c", "d", 
"e"]));
+        acc.update_batch(&[values], &[0, 1, 2], None, 3)?;
+
+        let result = evaluate_groups(&mut acc, EmitTo::All);
+        assert_eq!(
+            result,
+            vec![
+                Some("a|c".to_string()),
+                Some("b|d".to_string()),
+                Some("e".to_string()),
+            ]
+        );
+        Ok(())
+    }
 }
diff --git a/datafusion/sqllogictest/test_files/aggregate.slt 
b/datafusion/sqllogictest/test_files/aggregate.slt
index 1f2a81d334..e42ebd4ce7 100644
--- a/datafusion/sqllogictest/test_files/aggregate.slt
+++ b/datafusion/sqllogictest/test_files/aggregate.slt
@@ -6991,7 +6991,7 @@ SELECT STRING_AGG(DISTINCT x,'|' ORDER BY x) FROM strings
 ----
 a|b|i|j|p|x|y|z
 
-query error This feature is not implemented: The second argument of the 
string_agg function must be a string literal
+query error This feature is not implemented: string_agg delimiter must be a 
string literal
 SELECT STRING_AGG(DISTINCT x,y) FROM strings
 
 query error Execution error: In an aggregate with DISTINCT, ORDER BY 
expressions must appear in argument list


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to