This is an automated email from the ASF dual-hosted git repository.
github-bot pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/main by this push:
new 848cd6387a Eliminate deterministic group by keys with deterministic
transformations (#20706)
848cd6387a is described below
commit 848cd6387af5c8d138d5cf4ab5299b3659d0605a
Author: Daniƫl Heres <[email protected]>
AuthorDate: Thu Mar 5 11:49:46 2026 +0100
Eliminate deterministic group by keys with deterministic transformations
(#20706)
## Which issue does this PR close?
- Helps with #18489
## Rationale for this change
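Speed up queries that group by a column together with deterministic
transformations of that same column, such as this ClickBench query:
```
SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS c
FROM hits
GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3
```
Grouping by `"ClientIP"` alone produces the same groups, because every other
key is uniquely determined by `"ClientIP"`. The redundant keys can therefore be
computed after the aggregation instead of being hashed as grouping keys.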
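## What changes are included in this PR?
- `EliminateGroupByConstant` now removes every `GROUP BY` expression that is a
  deterministic (immutable or stable) function of literals and of columns that
  already appear as bare column references in the same `GROUP BY`. Previously,
  only constant expressions were removed.
  - Bare column references are never removed.
  - The removed expressions are recomputed in a projection above the
    simplified aggregate.
- `CAST`, `TRY_CAST`, and unary negation are now recognized when deciding
  whether an expression can be removed.
## Are these changes tested?
Yes.
- Two unit tests were added: `test_eliminate_deterministic_expr_of_group_by_column`
  and `test_no_eliminate_independent_columns`.
- The expected logical and physical plans for the affected ClickBench query in
  `clickbench.slt` were updated.
## Are there any user-facing changes?
Query results should be unchanged. `EXPLAIN` output changes for affected
queries, which now aggregate on fewer grouping keys.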
---
.../optimizer/src/eliminate_group_by_constant.rs | 111 +++++++++++++++++----
datafusion/sqllogictest/test_files/clickbench.slt | 17 ++--
2 files changed, 99 insertions(+), 29 deletions(-)
diff --git a/datafusion/optimizer/src/eliminate_group_by_constant.rs
b/datafusion/optimizer/src/eliminate_group_by_constant.rs
index e93edc6240..6f5ca59e31 100644
--- a/datafusion/optimizer/src/eliminate_group_by_constant.rs
+++ b/datafusion/optimizer/src/eliminate_group_by_constant.rs
@@ -15,10 +15,13 @@
// specific language governing permissions and limitations
// under the License.
-//! [`EliminateGroupByConstant`] removes constant expressions from `GROUP BY`
clause
+//! [`EliminateGroupByConstant`] removes constant and functionally redundant
+//! expressions from `GROUP BY` clause
use crate::optimizer::ApplyOrder;
use crate::{OptimizerConfig, OptimizerRule};
+use std::collections::HashSet;
+
use datafusion_common::Result;
use datafusion_common::tree_node::Transformed;
use datafusion_expr::{Aggregate, Expr, LogicalPlan, LogicalPlanBuilder,
Volatility};
@@ -47,25 +50,30 @@ impl OptimizerRule for EliminateGroupByConstant {
) -> Result<Transformed<LogicalPlan>> {
match plan {
LogicalPlan::Aggregate(aggregate) => {
- let (const_group_expr, nonconst_group_expr): (Vec<_>, Vec<_>)
= aggregate
+ // Collect bare column references in GROUP BY
+ let group_by_columns: HashSet<&datafusion_common::Column> =
aggregate
.group_expr
.iter()
- .partition(|expr| is_constant_expression(expr));
-
- // If no constant expressions found (nothing to optimize) or
- // constant expression is the only expression in aggregate,
- // optimization is skipped
- if const_group_expr.is_empty()
- || (!const_group_expr.is_empty()
- && nonconst_group_expr.is_empty()
- && aggregate.aggr_expr.is_empty())
+ .filter_map(|expr| match expr {
+ Expr::Column(c) => Some(c),
+ _ => None,
+ })
+ .collect();
+
+ let (redundant, required): (Vec<_>, Vec<_>) = aggregate
+ .group_expr
+ .iter()
+ .partition(|expr| is_redundant_group_expr(expr,
&group_by_columns));
+
+ if redundant.is_empty()
+ || (required.is_empty() && aggregate.aggr_expr.is_empty())
{
return
Ok(Transformed::no(LogicalPlan::Aggregate(aggregate)));
}
let simplified_aggregate =
LogicalPlan::Aggregate(Aggregate::try_new(
aggregate.input,
- nonconst_group_expr.into_iter().cloned().collect(),
+ required.into_iter().cloned().collect(),
aggregate.aggr_expr.clone(),
)?);
@@ -91,23 +99,47 @@ impl OptimizerRule for EliminateGroupByConstant {
}
}
-/// Checks if expression is constant, and can be eliminated from group by.
-///
-/// Intended to be used only within this rule, helper function, which heavily
-/// relies on `SimplifyExpressions` result.
-fn is_constant_expression(expr: &Expr) -> bool {
+/// Checks if a GROUP BY expression is redundant (can be removed without
+/// changing grouping semantics). An expression is redundant if it is a
+/// deterministic function of constants and columns already present as bare
+/// column references in the GROUP BY.
+fn is_redundant_group_expr(
+ expr: &Expr,
+ group_by_columns: &HashSet<&datafusion_common::Column>,
+) -> bool {
+ // Bare column references are never redundant - they define the grouping
+ if matches!(expr, Expr::Column(_)) {
+ return false;
+ }
+ is_deterministic_of(expr, group_by_columns)
+}
+
+/// Returns true if `expr` is a deterministic expression whose only column
+/// references are contained in `known_columns`.
+fn is_deterministic_of(
+ expr: &Expr,
+ known_columns: &HashSet<&datafusion_common::Column>,
+) -> bool {
match expr {
- Expr::Alias(e) => is_constant_expression(&e.expr),
+ Expr::Alias(e) => is_deterministic_of(&e.expr, known_columns),
+ Expr::Column(c) => known_columns.contains(c),
+ Expr::Literal(_, _) => true,
Expr::BinaryExpr(e) => {
- is_constant_expression(&e.left) && is_constant_expression(&e.right)
+ is_deterministic_of(&e.left, known_columns)
+ && is_deterministic_of(&e.right, known_columns)
}
- Expr::Literal(_, _) => true,
Expr::ScalarFunction(e) => {
matches!(
e.func.signature().volatility,
Volatility::Immutable | Volatility::Stable
- ) && e.args.iter().all(is_constant_expression)
+ ) && e
+ .args
+ .iter()
+ .all(|arg| is_deterministic_of(arg, known_columns))
}
+ Expr::Cast(e) => is_deterministic_of(&e.expr, known_columns),
+ Expr::TryCast(e) => is_deterministic_of(&e.expr, known_columns),
+ Expr::Negative(e) => is_deterministic_of(e, known_columns),
_ => false,
}
}
@@ -268,6 +300,43 @@ mod tests {
")
}
+ #[test]
+ fn test_eliminate_deterministic_expr_of_group_by_column() -> Result<()> {
+ let scan = test_table_scan()?;
+ // GROUP BY a, a - 1, a - 2, a - 3 -> GROUP BY a
+ let plan = LogicalPlanBuilder::from(scan)
+ .aggregate(
+ vec![
+ col("a"),
+ col("a") - lit(1u32),
+ col("a") - lit(2u32),
+ col("a") - lit(3u32),
+ ],
+ vec![count(col("c"))],
+ )?
+ .build()?;
+
+ assert_optimized_plan_equal!(plan, @r"
+ Projection: test.a, test.a - UInt32(1), test.a - UInt32(2), test.a -
UInt32(3), count(test.c)
+ Aggregate: groupBy=[[test.a]], aggr=[[count(test.c)]]
+ TableScan: test
+ ")
+ }
+
+ #[test]
+ fn test_no_eliminate_independent_columns() -> Result<()> {
+ // GROUP BY a, b - 1 should NOT eliminate b - 1 (b is not a group by
column)
+ let scan = test_table_scan()?;
+ let plan = LogicalPlanBuilder::from(scan)
+ .aggregate(vec![col("a"), col("b") - lit(1u32)],
vec![count(col("c"))])?
+ .build()?;
+
+ assert_optimized_plan_equal!(plan, @r"
+ Aggregate: groupBy=[[test.a, test.b - UInt32(1)]],
aggr=[[count(test.c)]]
+ TableScan: test
+ ")
+ }
+
#[test]
fn test_no_op_volatile_scalar_fn_with_constant_arg() -> Result<()> {
let udf = ScalarUDF::new_from_impl(ScalarUDFMock::new_with_volatility(
diff --git a/datafusion/sqllogictest/test_files/clickbench.slt
b/datafusion/sqllogictest/test_files/clickbench.slt
index 10059664ad..dd558a4f36 100644
--- a/datafusion/sqllogictest/test_files/clickbench.slt
+++ b/datafusion/sqllogictest/test_files/clickbench.slt
@@ -959,19 +959,20 @@ EXPLAIN SELECT "ClientIP", "ClientIP" - 1, "ClientIP" -
2, "ClientIP" - 3, COUNT
----
logical_plan
01)Sort: c DESC NULLS FIRST, fetch=10
-02)--Projection: hits.ClientIP, hits.ClientIP - Int64(1), hits.ClientIP -
Int64(2), hits.ClientIP - Int64(3), count(Int64(1)) AS count(*) AS c
-03)----Aggregate: groupBy=[[hits.ClientIP, __common_expr_1 AS hits.ClientIP -
Int64(1), __common_expr_1 AS hits.ClientIP - Int64(2), __common_expr_1 AS
hits.ClientIP - Int64(3)]], aggr=[[count(Int64(1))]]
-04)------Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1,
hits.ClientIP
+02)--Projection: hits.ClientIP, __common_expr_1 - Int64(1) AS hits.ClientIP -
Int64(1), __common_expr_1 - Int64(2) AS hits.ClientIP - Int64(2),
__common_expr_1 - Int64(3) AS hits.ClientIP - Int64(3), count(Int64(1)) AS c
+03)----Projection: CAST(hits.ClientIP AS Int64) AS __common_expr_1,
hits.ClientIP, count(Int64(1))
+04)------Aggregate: groupBy=[[hits.ClientIP]], aggr=[[count(Int64(1))]]
05)--------SubqueryAlias: hits
06)----------TableScan: hits_raw projection=[ClientIP]
physical_plan
01)SortPreservingMergeExec: [c@4 DESC], fetch=10
02)--SortExec: TopK(fetch=10), expr=[c@4 DESC], preserve_partitioning=[true]
-03)----ProjectionExec: expr=[ClientIP@0 as ClientIP, hits.ClientIP -
Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP - Int64(2)@2 as
hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as hits.ClientIP -
Int64(3), count(Int64(1))@4 as c]
-04)------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP,
hits.ClientIP - Int64(1)@1 as hits.ClientIP - Int64(1), hits.ClientIP -
Int64(2)@2 as hits.ClientIP - Int64(2), hits.ClientIP - Int64(3)@3 as
hits.ClientIP - Int64(3)], aggr=[count(Int64(1))]
-05)--------RepartitionExec: partitioning=Hash([ClientIP@0, hits.ClientIP -
Int64(1)@1, hits.ClientIP - Int64(2)@2, hits.ClientIP - Int64(3)@3], 4),
input_partitions=1
-06)----------AggregateExec: mode=Partial, gby=[ClientIP@1 as ClientIP,
__common_expr_1@0 - 1 as hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as
hits.ClientIP - Int64(2), __common_expr_1@0 - 3 as hits.ClientIP - Int64(3)],
aggr=[count(Int64(1))]
-07)------------DataSourceExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]},
projection=[CAST(ClientIP@7 AS Int64) as __common_expr_1, ClientIP],
file_type=parquet
+03)----ProjectionExec: expr=[ClientIP@1 as ClientIP, __common_expr_1@0 - 1 as
hits.ClientIP - Int64(1), __common_expr_1@0 - 2 as hits.ClientIP - Int64(2),
__common_expr_1@0 - 3 as hits.ClientIP - Int64(3), count(Int64(1))@2 as c]
+04)------ProjectionExec: expr=[CAST(ClientIP@0 AS Int64) as __common_expr_1,
ClientIP@0 as ClientIP, count(Int64(1))@1 as count(Int64(1))]
+05)--------AggregateExec: mode=FinalPartitioned, gby=[ClientIP@0 as ClientIP],
aggr=[count(Int64(1))]
+06)----------RepartitionExec: partitioning=Hash([ClientIP@0], 4),
input_partitions=1
+07)------------AggregateExec: mode=Partial, gby=[ClientIP@0 as ClientIP],
aggr=[count(Int64(1))]
+08)--------------DataSourceExec: file_groups={1 group:
[[WORKSPACE_ROOT/datafusion/core/tests/data/clickbench_hits_10.parquet]]},
projection=[ClientIP], file_type=parquet
query IIIII rowsort
SELECT "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3, COUNT(*) AS
c FROM hits GROUP BY "ClientIP", "ClientIP" - 1, "ClientIP" - 2, "ClientIP" - 3
ORDER BY c DESC LIMIT 10;
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]