This is an automated email from the ASF dual-hosted git repository.

github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/main by this push:
     new 848cd6387a Eliminate deterministic group by keys with deterministic 
transformations (#20706)
848cd6387a is described below

commit 848cd6387af5c8d138d5cf4ab5299b3659d0605a
Author: Daniël Heres <[email protected]>
AuthorDate: Thu Mar 5 11:49:46 2026 +0100

    Eliminate deterministic group by keys with deterministic transformations 
(#20706)
    
    ## Which issue does this PR close?
    
    
    - Helps with #18489
    
    ## Rationale for this change
    
    Make queries go faster like this randomly selected one:
    
    ```
    SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) 
AS c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" 
- 3
    ```
    
    ## What changes are included in this PR?
    
    
    ## Are these changes tested?
    
    
    ## Are there any user-facing changes?
---
 .../optimizer/src/eliminate_group_by_constant.rs   | 111 +++++++++++++++++----
 datafusion/sqllogictest/test_files/clickbench.slt  |  17 ++--
 2 files changed, 99 insertions(+), 29 deletions(-)

diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs 
b/datafusion/optimizer/src/eliminate_group_by_constant.rs
index e93edc6240..6f5ca59e31 100644
--- a/datafusion/optimizer/src/eliminate_group_by_constant.rs
+++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs
@@ -15,10 +15,13 @@
 // specific language governing permissions and limitations
 // under the License.
 
-//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY` 
clause
+//! [`EliminateGroupByConstant`] removes constant and functionally redundant
+//! expressions from `GROUP BY` clause
 use crate::optimizer::ApplyOrder;
 use crate::{OptimizerConfig, OptimizerRule};
 
+use std::collections::HashSet;
+
 use datafusion_common::Result;
 use datafusion_common::tree_node::Transformed;
 use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder, 
Volatility};
@@ -47,25 +50,30 @@ impl OptimizerRule for EliminateGroupByConstant {
     ) -> Result<Transformed<LogicalPlan>> {
         match plan {
             LogicalPlan::Aggregate(aggregate) => {
-                let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>) 
= aggregate
+                // Collect bare column references in GROUP BY
+                let group_by_columns: HashSet<&datafusion_common::Column> = 
aggregate
                     .group_expr
                     .iter()
-                    .partition(|expr| is_constant_expression(expr));
-
-                // If no constant expressions found (nothing to optimize) or
-                // constant expression is the only expression in aggregate,
-                // optimization is skipped
-                if const_group_expr.is_empty()
-                    || (!const_group_expr.is_empty()
-                        && nonconst_group_expr.is_empty()
-                        && aggregate.aggr_expr.is_empty())
+                    .filter_map(|expr| match expr {
+                        Expr::Column(c) => Some(c),
+                        _ => None,
+                    })
+                    .collect();
+
+                let (redundant, required): (Vec<_>, Vec<_>) = aggregate
+                    .group_expr
+                    .iter()
+                    .partition(|expr| is_redundant_group_expr(expr, 
&group_by_columns));
+
+                if redundant.is_empty()
+                    || (required.is_empty() && aggregate.aggr_expr.is_empty())
                 {
                     return 
Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
                 }
 
                 let simplified_aggregate = 
LogicalPlan::Aggregate(Aggregate::try_new(
                     aggregate.input,
-                    nonconst_group_expr.into_iter().cloned().collect(),
+                    required.into_iter().cloned().collect(),
                     aggregate.aggr_expr.clone(),
                 )?);
 
@@ -91,23 +99,47 @@ impl OptimizerRule for EliminateGroupByConstant {
     }
 }
 
-/// Checks if expression is constant, and can be eliminated from group by.
-///
-/// Intended to be used only within this rule, helper function, which heavily
-/// relies on `SimplifyExpressions` result.
-fn is_constant_expression(expr: &Expr) -> bool {
+/// Checks if a GROUP BY expression is redundant (can be removed without
+/// changing grouping semantics). An expression is redundant if it is a
+/// deterministic function of constants and columns already present as bare
+/// column references in the GROUP BY.
+fn is_redundant_group_expr(
+    expr: &Expr,
+    group_by_columns: &HashSet<&datafusion_common::Column>,
+) -> bool {
+    // Bare column references are never redundant - they define the grouping
+    if matches!(expr, Expr::Column(_)) {
+        return false;
+    }
+    is_deterministic_of(expr, group_by_columns)
+}
+
+/// Returns true if `expr` is a deterministic expression whose only column
+/// references are contained in `known_columns`.
+fn is_deterministic_of(
+    expr: &Expr,
+    known_columns: &HashSet<&datafusion_common::Column>,
+) -> bool {
     match expr {
-        Expr::Alias(e) => is_constant_expression(&e.expr),
+        Expr::Alias(e) => is_deterministic_of(&e.expr, known_columns),
+        Expr::Column(c) => known_columns.contains(c),
+        Expr::Literal(_, _) => true,
         Expr::BinaryExpr(e) => {
-            is_constant_expression(&e.left) && is_constant_expression(&e.right)
+            is_deterministic_of(&e.left, known_columns)
+                && is_deterministic_of(&e.right, known_columns)
         }
-        Expr::Literal(_, _) => true,
         Expr::ScalarFunction(e) => {
             matches!(
                 e.func.signature().volatility,
                 Volatility::Immutable | Volatility::Stable
-            ) && e.args.iter().all(is_constant_expression)
+            ) && e
+                .args
+                .iter()
+                .all(|arg| is_deterministic_of(arg, known_columns))
         }
+        Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns),
+        Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns),
+        Expr::Negative(e) => is_deterministic_of(e, known_columns),
         _ => false,
     }
 }
@@ -268,6 +300,43 @@ mod tests {
         ")
     }
 
+    #[test]
+    fn test_eliminate_deterministic_expr_of_group_by_column() -> Result<()> {
+        let scan = test_table_scan()?;
+        // GROUP BY a, a - 1, a - 2, a - 3  ->  GROUP BY a
+        let plan = LogicalPlanBuilder::from(scan)
+            .aggregate(
+                vec![
+                    col("a"),
+                    col("a") - lit(1u32),
+                    col("a") - lit(2u32),
+                    col("a") - lit(3u32),
+                ],
+                vec![count(col("c"))],
+            )?
+            .build()?;
+
+        assert_optimized_plan_equal!(plan, @r"
+        Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a - 
UInt32(3), count(test.c)
+          Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
+            TableScan: test
+        ")
+    }
+
+    #[test]
+    fn test_no_eliminate_independent_columns() -> Result<()> {
+        // GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by 
column)
+        let scan = test_table_scan()?;
+        let plan = LogicalPlanBuilder::from(scan)
+            .aggregate(vec![col("a"), col("b") - lit(1u32)], 
vec![count(col("c"))])?
+            .build()?;
+
+        assert_optimized_plan_equal!(plan, @r"
+        Aggregate: groupBy=[[test.a, test.b - UInt32(1)]], 
aggr=[[count(test.c)]]
+          TableScan: test
+        ")
+    }
+
     #[test]
     fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> {
         let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt 
b/datafusion/sqllogictest/test_files/clickbench.slt
index 10059664ad..dd558a4f36 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -959,19 +959,20 @@ EXPLAIN SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 
2, "ClientIP" - 3, COUNT
 ----
 logical_plan
 01)Sort: c DESC NULLS FIRST, fetch=10
-02)--Projection: hits.ClientIP, hits.ClientIP - Int64(1), hits.ClientIP - 
Int64(2), hits.ClientIP - Int64(3), count(Int64(1)) AS count(*) AS c
-03)----Aggregate: groupBy=[[hits.ClientIP, __common_expr_1 AS hits.ClientIP - 
Int64(1), __common_expr_1 AS hits.ClientIP - Int64(2), __common_expr_1 AS 
hits.ClientIP - Int64(3)]], aggr=[[count(Int64(1))]]
-04)------Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1, 
hits.ClientIP
+02)--Projection: hits.ClientIP, __common_expr_1 - Int64(1) AS hits.ClientIP - 
Int64(1), __common_expr_1 - Int64(2) AS hits.ClientIP - Int64(2), 
__common_expr_1 - Int64(3) AS hits.ClientIP - Int64(3), count(Int64(1)) AS c
+03)----Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1, 
hits.ClientIP, count(Int64(1))
+04)------Aggregate: groupBy=[[hits.ClientIP]], aggr=[[count(Int64(1))]]
 05)--------SubqueryAlias: hits
 06)----------TableScan: hits_raw projection=[ClientIP]
 physical_plan
 01)SortPreservingMergeExec: [c@4 DESC], fetch=10
 02)--SortExec: TopK(fetch=10), expr=[c@4 DESC], preserve_partitioning=[true]
-03)----ProjectionExec: expr=[ClientIP@0 as ClientIP, hits.ClientIP - 
Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - Int64(2)@2 as 
hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as hits.ClientIP - 
Int64(3), count(Int64(1))@4 as c]
-04)------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP, 
hits.ClientIP - Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - 
Int64(2)@2 as hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as 
hits.ClientIP - Int64(3)], aggr=[count(Int64(1))]
-05)--------RepartitionExec: partitioning=Hash([ClientIP@0, hits.ClientIP - 
Int64(1)@1, hits.ClientIP - Int64(2)@2, hits.ClientIP - Int64(3)@3], 4), 
input_partitions=1
-06)----------AggregateExec: mode=Partial, gby=[ClientIP@1 as ClientIP, 
__common_expr_1@0 - 1 as hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as 
hits.ClientIP - Int64(2), __common_expr_1@0 - 3 as hits.ClientIP - Int64(3)], 
aggr=[count(Int64(1))]
-07)------------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, 
projection=[CAST(ClientIP@7 AS Int64) as __common_expr_1, ClientIP], 
file_type=parquet
+03)----ProjectionExec: expr=[ClientIP@1 as ClientIP, __common_expr_1@0 - 1 as 
hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as hits.ClientIP - Int64(2), 
__common_expr_1@0 - 3 as hits.ClientIP - Int64(3), count(Int64(1))@2 as c]
+04)------ProjectionExec: expr=[CAST(ClientIP@0 AS Int64) as __common_expr_1, 
ClientIP@0 as ClientIP, count(Int64(1))@1 as count(Int64(1))]
+05)--------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP], 
aggr=[count(Int64(1))]
+06)----------RepartitionExec: partitioning=Hash([ClientIP@0], 4), 
input_partitions=1
+07)------------AggregateExec: mode=Partial, gby=[ClientIP@0 as ClientIP], 
aggr=[count(Int64(1))]
+08)--------------DataSourceExec: file_groups={1 group: 
[[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]}, 
projection=[ClientIP], file_type=parquet
 
 query IIIII rowsort
 SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS 
c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3 
ORDER BY c DESC LIMIT 10;


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to