This is an automated email from the ASF dual-hosted git repository.
englefly pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new 6cf39e82c04 [enhance](nereids) expand support for eliminate agg by unique (#48317)
6cf39e82c04 is described below
commit 6cf39e82c04aaa359bc98a4b3f86ee770399a1f3
Author: feiniaofeiafei <[email protected]>
AuthorDate: Thu May 29 15:24:22 2025 +0800
[enhance](nereids) expand support for eliminate agg by unique (#48317)
### What problem does this PR solve?
Issue Number: close #xxx
Related PR: #xxx
Problem Summary:
This PR extends the EliminateGroupBy rule, which removes an aggregation when
the group-by key is unique (and not null). Previously the rule only applied
when every aggregate function's argument was a plain slot reference. It now
also handles arguments that are arbitrary expressions, e.g. max(a+1).
Additionally, it adds support for count(*) and for the avg, sum0, median,
any_value, percentile, stddev, stddev_samp, variance, variance_samp, min_by,
max_by, and avg_weighted functions.
e.g.
select a,max(a+1) from t group by a;
->
select a,a+1 from t;
select a,count(1) from t group by a;
->
select a,1 from t;
select a,avg(b) from t group by a;
->
select a,cast(b as double) from t;
---
.../nereids/rules/rewrite/EliminateGroupBy.java | 145 +++++++++++++++------
.../nereids/trees/expressions/literal/Literal.java | 22 ++++
.../doris/nereids/properties/UniqueTest.java | 2 +-
.../rules/rewrite/EliminateGroupByTest.java | 47 +++++++
.../eliminate_gby_key/eliminate_group_by.out | Bin 0 -> 3494 bytes
.../eliminate_gby_key/eliminate_group_by.groovy | 54 ++++++++
6 files changed, 228 insertions(+), 42 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
index 9325607dd70..4bf1c04791e 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupBy.java
@@ -25,13 +25,31 @@ import org.apache.doris.nereids.trees.expressions.IsNull;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
import org.apache.doris.nereids.trees.expressions.Slot;
import
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AnyValue;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Avg;
+import org.apache.doris.nereids.trees.expressions.functions.agg.AvgWeighted;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Max;
+import org.apache.doris.nereids.trees.expressions.functions.agg.MaxBy;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Median;
import org.apache.doris.nereids.trees.expressions.functions.agg.Min;
+import org.apache.doris.nereids.trees.expressions.functions.agg.MinBy;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Percentile;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Stddev;
+import org.apache.doris.nereids.trees.expressions.functions.agg.StddevSamp;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Sum0;
+import org.apache.doris.nereids.trees.expressions.functions.agg.Variance;
+import org.apache.doris.nereids.trees.expressions.functions.agg.VarianceSamp;
+import org.apache.doris.nereids.trees.expressions.functions.scalar.Coalesce;
import org.apache.doris.nereids.trees.expressions.functions.scalar.If;
import org.apache.doris.nereids.trees.expressions.literal.BigIntLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.DoubleLiteral;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
+import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.types.DoubleType;
import org.apache.doris.nereids.util.ExpressionUtils;
import org.apache.doris.nereids.util.PlanUtils;
import org.apache.doris.nereids.util.TypeCoercionUtils;
@@ -46,54 +64,99 @@ import java.util.List;
* Eliminate GroupBy.
*/
public class EliminateGroupBy extends OneRewriteRuleFactory {
+ private static final ImmutableSet<Class<? extends Expression>>
supportedBasicFunctions
+ = ImmutableSet.of(Sum.class, Avg.class, Min.class, Max.class,
Median.class, AnyValue.class);
+ private static final ImmutableSet<Class<? extends Expression>>
supportedTwoArgsFunctions
+ = ImmutableSet.of(MinBy.class, MaxBy.class, AvgWeighted.class,
Percentile.class);
+ private static final ImmutableSet<Class<? extends Expression>>
supportedDevLikeFunctions
+ = ImmutableSet.of(Stddev.class, StddevSamp.class, Variance.class,
VarianceSamp.class);
+ private static final ImmutableSet<Class<? extends Expression>>
supportedFunctionSum0
+ = ImmutableSet.of(Sum0.class);
+ private static final ImmutableSet<Class<? extends Expression>>
allFunctionsExceptCount
+ = ImmutableSet.<Class<? extends Expression>>builder()
+ .addAll(supportedBasicFunctions)
+ .addAll(supportedTwoArgsFunctions)
+ .addAll(supportedDevLikeFunctions)
+ .addAll(supportedFunctionSum0)
+ .build();
@Override
public Rule build() {
return logicalAggregate()
.when(agg ->
ExpressionUtils.allMatch(agg.getGroupByExpressions(), Slot.class::isInstance))
- .then(agg -> {
- List<Expression> groupByExpressions =
agg.getGroupByExpressions();
- Builder<Slot> groupBySlots
- =
ImmutableSet.builderWithExpectedSize(groupByExpressions.size());
- for (Expression groupByExpression : groupByExpressions) {
- groupBySlots.add((Slot) groupByExpression);
- }
- Plan child = agg.child();
- boolean unique = child.getLogicalProperties()
- .getTrait()
- .isUniqueAndNotNull(groupBySlots.build());
- if (!unique) {
- return null;
- }
- for (AggregateFunction f : agg.getAggregateFunctions()) {
- if (!((f instanceof Sum || f instanceof Count || f
instanceof Min || f instanceof Max)
- && (f.arity() == 1 && f.child(0) instanceof
Slot))) {
- return null;
- }
- }
- List<NamedExpression> outputExpressions =
agg.getOutputExpressions();
+ .then(this::rewrite).toRule(RuleType.ELIMINATE_GROUP_BY);
+ }
- ImmutableList.Builder<NamedExpression> newOutput
- =
ImmutableList.builderWithExpectedSize(outputExpressions.size());
+ private Plan rewrite(LogicalAggregate<Plan> agg) {
+ List<Expression> groupByExpressions = agg.getGroupByExpressions();
+ Builder<Slot> groupBySlots
+ =
ImmutableSet.builderWithExpectedSize(groupByExpressions.size());
+ for (Expression groupByExpression : groupByExpressions) {
+ groupBySlots.add((Slot) groupByExpression);
+ }
+ Plan child = agg.child();
+ boolean unique = child.getLogicalProperties()
+ .getTrait()
+ .isUniqueAndNotNull(groupBySlots.build());
+ if (!unique) {
+ return null;
+ }
+ for (AggregateFunction f : agg.getAggregateFunctions()) {
+ if (!canRewrite(f)) {
+ return null;
+ }
+ }
+ List<NamedExpression> outputExpressions = agg.getOutputExpressions();
- for (NamedExpression ne : outputExpressions) {
- if (ne instanceof Alias && ne.child(0) instanceof
AggregateFunction) {
- AggregateFunction f = (AggregateFunction)
ne.child(0);
- if (f instanceof Sum || f instanceof Min || f
instanceof Max) {
- newOutput.add(new Alias(ne.getExprId(),
TypeCoercionUtils
- .castIfNotSameType(f.child(0),
f.getDataType()), ne.getName()));
- } else if (f instanceof Count) {
- newOutput.add((NamedExpression)
ne.withChildren(
- new If(new IsNull(f.child(0)), new
BigIntLiteral(0),
- new BigIntLiteral(1))));
- } else {
- throw new IllegalStateException("Unexpected
aggregate function: " + f);
- }
- } else {
- newOutput.add(ne);
- }
+ ImmutableList.Builder<NamedExpression> newOutput
+ =
ImmutableList.builderWithExpectedSize(outputExpressions.size());
+
+ for (NamedExpression ne : outputExpressions) {
+ if (ne instanceof Alias && ne.child(0) instanceof
AggregateFunction) {
+ AggregateFunction f = (AggregateFunction) ne.child(0);
+ if (supportedBasicFunctions.contains(f.getClass())) {
+ newOutput.add(new Alias(ne.getExprId(), TypeCoercionUtils
+ .castIfNotSameType(f.child(0), f.getDataType()),
ne.getName()));
+ } else if (f instanceof Count) {
+ if (((Count) f).isStar()) {
+ newOutput.add((NamedExpression)
ne.withChildren(TypeCoercionUtils
+ .castIfNotSameType(new BigIntLiteral(1),
f.getDataType())));
+ } else {
+ newOutput.add((NamedExpression) ne.withChildren(
+ new If(new IsNull(f.child(0)), new
BigIntLiteral(0),
+ new BigIntLiteral(1))));
}
- return PlanUtils.projectOrSelf(newOutput.build(), child);
- }).toRule(RuleType.ELIMINATE_GROUP_BY);
+ } else if (f instanceof Sum0) {
+ Coalesce coalesce = new Coalesce(f.child(0),
+ Literal.convertToTypedLiteral(0,
f.child(0).getDataType()));
+ newOutput.add((NamedExpression) ne.withChildren(
+ TypeCoercionUtils.castIfNotSameType(coalesce,
f.getDataType())));
+ } else if (supportedTwoArgsFunctions.contains(f.getClass())) {
+ If ifFunc = new If(new IsNull(f.child(1)), new
NullLiteral(f.child(0).getDataType()),
+ f.child(0));
+ newOutput.add((NamedExpression) ne.withChildren(
+ TypeCoercionUtils.castIfNotSameType(ifFunc,
f.getDataType())));
+ } else if (supportedDevLikeFunctions.contains(f.getClass())) {
+ If ifFunc = new If(new IsNull(f.child(0)), new
NullLiteral(DoubleType.INSTANCE),
+ new DoubleLiteral(0));
+ newOutput.add((NamedExpression) ne.withChildren(ifFunc));
+ } else {
+ return null;
+ }
+ } else {
+ newOutput.add(ne);
+ }
+ }
+ return PlanUtils.projectOrSelf(newOutput.build(), child);
+ }
+
+ private boolean canRewrite(AggregateFunction f) {
+ if (allFunctionsExceptCount.contains(f.getClass())) {
+ return true;
+ }
+ if (f instanceof Count) {
+ return ((Count) f).isStar() || 1 == f.arity();
+ }
+ return false;
}
}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
index 22e2f2ce60b..784f3fba0f3 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/literal/Literal.java
@@ -30,6 +30,7 @@ import org.apache.doris.nereids.exceptions.UnboundException;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.shape.LeafExpression;
import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
+import org.apache.doris.nereids.types.BigIntType;
import org.apache.doris.nereids.types.CharType;
import org.apache.doris.nereids.types.DataType;
import org.apache.doris.nereids.types.DateTimeType;
@@ -37,8 +38,12 @@ import org.apache.doris.nereids.types.DateTimeV2Type;
import org.apache.doris.nereids.types.DateType;
import org.apache.doris.nereids.types.DecimalV2Type;
import org.apache.doris.nereids.types.DecimalV3Type;
+import org.apache.doris.nereids.types.DoubleType;
+import org.apache.doris.nereids.types.IntegerType;
import org.apache.doris.nereids.types.LargeIntType;
+import org.apache.doris.nereids.types.SmallIntType;
import org.apache.doris.nereids.types.StringType;
+import org.apache.doris.nereids.types.TinyIntType;
import org.apache.doris.nereids.types.VarcharType;
import org.apache.doris.nereids.types.coercion.IntegralType;
@@ -672,4 +677,21 @@ public abstract class Literal extends Expression
implements LeafExpression {
// different environment
return new VarcharLiteral(new String(bytes, StandardCharsets.UTF_8));
}
+
+ /**convertToTypedLiteral*/
+ public static Literal convertToTypedLiteral(Object value, DataType
dataType) {
+ Number number = (Number) value;
+ if (dataType.equals(TinyIntType.INSTANCE)) {
+ return new TinyIntLiteral(number.byteValue());
+ } else if (dataType.equals(SmallIntType.INSTANCE)) {
+ return new SmallIntLiteral(number.shortValue());
+ } else if (dataType.equals(IntegerType.INSTANCE)) {
+ return new IntegerLiteral(number.intValue());
+ } else if (dataType.equals(BigIntType.INSTANCE)) {
+ return new BigIntLiteral(number.longValue());
+ } else if (dataType.equals(DoubleType.INSTANCE)) {
+ return new DoubleLiteral(number.doubleValue());
+ }
+ return null;
+ }
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniqueTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniqueTest.java
index 391bc82021f..fbcf0c8028f 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniqueTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/properties/UniqueTest.java
@@ -69,7 +69,7 @@ class UniqueTest extends TestWithFeService {
Assertions.assertTrue(plan.getLogicalProperties().getTrait()
.isUnique(plan.getOutput().get(0)));
plan = PlanChecker.from(connectContext)
- .analyze("select id, sum(id), avg(id), max(id), min(id) from
agg group by id")
+ .analyze("select id, sum(id), avg(id), max(id), min(id),
topn(id,2) from agg group by id")
.rewrite()
.getPlan();
Assertions.assertTrue(plan.getLogicalProperties().getTrait()
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java
index f2a9e480f32..6d27469c469 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByTest.java
@@ -98,4 +98,51 @@ class EliminateGroupByTest extends TestWithFeService
implements MemoPatternMatch
)
);
}
+
+ @Test
+ void eliminateAvg() {
+ String sql = "select id, avg(age) from t group by id";
+
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(
+ logicalEmptyRelation().when(p ->
p.getProjects().get(0).toSql().equals("id")
+ &&
p.getProjects().get(1).toSql().equals("cast(age as DOUBLE) AS `avg(age)`")
+ &&
p.getProjects().get(1).getDataType().isDoubleType()
+ )
+ );
+ }
+
+ @Test
+ void eliminateCountStar() {
+ String sql = "select id, count(*) from t group by id";
+
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(
+ logicalEmptyRelation().when(p ->
p.getProjects().get(0).toSql().equals("id")
+ && p.getProjects().get(1).toSql().equals("1 AS
`count(*)`")
+ &&
p.getProjects().get(1).getDataType().isBigIntType()
+ )
+ );
+ }
+
+ @Test
+ void eliminateExpr() {
+ String sql = "select id, avg(age+1), min(abs(age)) from t group by id";
+
+ PlanChecker.from(connectContext)
+ .analyze(sql)
+ .rewrite()
+ .matches(
+ logicalEmptyRelation().when(p ->
p.getProjects().get(0).toSql().equals("id")
+ &&
p.getProjects().get(1).toSql().equals("cast((age + 1) as DOUBLE) AS
`avg(age+1)`")
+ &&
p.getProjects().get(1).getDataType().isDoubleType()
+ &&
p.getProjects().get(2).toSql().equals("abs(age) AS `min(abs(age))`")
+ &&
p.getProjects().get(2).getDataType().isBigIntType()
+ )
+ );
+ }
}
diff --git
a/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.out
b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.out
new file mode 100644
index 00000000000..ab48af56fae
Binary files /dev/null and
b/regression-test/data/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.out
differ
diff --git
a/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.groovy
b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.groovy
new file mode 100644
index 00000000000..97a858a5c2e
--- /dev/null
+++
b/regression-test/suites/nereids_rules_p0/eliminate_gby_key/eliminate_group_by.groovy
@@ -0,0 +1,54 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+suite("eliminate_group_by") {
+// sql "set disable_nereids_rules='ELIMINATE_GROUP_BY'"
+ sql "SET ignore_shape_nodes='PhysicalDistribute,PhysicalQuickSort'"
+ sql "drop table if exists test_unique2;"
+ sql """create table test_unique2(a int not null, b int) unique key(a)
distributed by hash(a) properties("replication_num"="1");"""
+ sql "insert into test_unique2 values(1,2),(2,2),(3,4),(4,4),(5,null);"
+ qt_count_star "select a,count(*) from test_unique2 group by a order by
1,2;"
+ qt_count_1 "select a,count(1) from test_unique2 group by a order by 1,2;"
+ qt_avg "select a,avg(b) from test_unique2 group by a order by 1,2;"
+ qt_expr "select a,max(a+1),avg(abs(a+100)),sum(a+b) from test_unique2
group by a order by 1,2,3,4;"
+ qt_window "select a,avg(sum(b) over(partition by b order by a)) from
test_unique2 group by a order by 1,2"
+ qt_two_args_func_min_by "select min_by(3,b),min_by(b,2),min_by(b,a),
min_by(null,a), min_by(b, null), min_by(null, null) from test_unique2 group by
a order by 1,2,3,4,5,6 "
+ qt_two_args_func_max_by "select max_by(3,a),max_by(a,2),max_by(b,a),
max_by(null,a), max_by(b, null), max_by(null, null) from test_unique2 group by
a order by 1,2,3,4,5,6 "
+ qt_two_args_func_avg_weighted "select avg_weighted(b,2),
avg_weighted(b,a), avg_weighted(null,a), avg_weighted(b, null),
avg_weighted(null, null) from test_unique2 group by a order by 1,2,3,4,5 "
+ qt_two_args_func_percentile "select percentile(b, null), percentile(null,
null),percentile(b, 0.3) from test_unique2 group by a order by 1,2,3"
+ qt_stddev "select a,stddev(b),stddev(null) from test_unique2 group by a
order by 1,2,3;"
+ qt_stddev_samp "select a,stddev_samp(b),stddev_samp(null) from
test_unique2 group by a order by 1,2,3;"
+ qt_variance "select a,variance(b),variance(null) from test_unique2 group
by a order by 1,2,3;"
+ qt_variance_samp "select a,variance_samp(b),variance_samp(null) from
test_unique2 group by a order by 1,2,3;"
+ qt_sum0 "select a,sum0(b),sum0(null) from test_unique2 group by a order by
1,2,3;"
+ qt_median "select
a,median(b),any_value(b),percentile(a,0.1),percentile(b,0.9),percentile(b,0.4)
from test_unique2 group by a order by 1,2,3,4,5,6;"
+
+ qt_count_star_shape "explain shape plan select a,count(*) from
test_unique2 group by a order by 1,2;"
+ qt_count_1_shape "explain shape plan select a,count(1) from test_unique2
group by a order by 1,2;"
+ qt_avg_shape "explain shape plan select a,avg(b) from test_unique2 group
by a order by 1,2;"
+ qt_expr_shape "explain shape plan select
a,max(a+1),avg(abs(a+100)),sum(a+b) from test_unique2 group by a order by
1,2,3,4;"
+ qt_window_shape "explain shape plan select a,avg(sum(b) over(partition by
b order by a)) from test_unique2 group by a order by 1,2"
+ qt_two_args_func_min_by_shape "explain shape plan select
min_by(3,b),min_by(b,2),min_by(b,a), min_by(null,a), min_by(b, null),
min_by(null, null) from test_unique2 group by a order by 1,2,3,4,5,6 "
+ qt_two_args_func_max_by_shape "explain shape plan select
max_by(3,a),max_by(a,2),max_by(b,a), max_by(null,a), max_by(b, null),
max_by(null, null) from test_unique2 group by a order by 1,2,3,4,5,6 "
+ qt_two_args_func_avg_weighted_shape "explain shape plan select
avg_weighted(b,2), avg_weighted(b,a), avg_weighted(null,a), avg_weighted(b,
null), avg_weighted(null, null) from test_unique2 group by a order by 1,2,3,4,5
"
+ qt_two_args_func_percentile_shape "explain shape plan select percentile(b,
null), percentile(null, null),percentile(b, 0.3) from test_unique2 group by a
order by 1,2,3"
+ qt_stddev_shape "explain shape plan select a,stddev(b),stddev(null) from
test_unique2 group by a order by 1,2,3;"
+ qt_stddev_samp_shape "explain shape plan select
a,stddev_samp(b),stddev_samp(null) from test_unique2 group by a order by 1,2,3;"
+ qt_variance_shape "explain shape plan select a,variance(b),variance(null)
from test_unique2 group by a order by 1,2,3;"
+ qt_variance_samp_shape "explain shape plan select
a,variance_samp(b),variance_samp(null) from test_unique2 group by a order by
1,2,3;"
+ qt_sum0_shape "explain shape plan select a,sum0(b),sum0(null) from
test_unique2 group by a order by 1,2,3;"
+ qt_median_shape "explain shape plan select
a,median(b),any_value(b),percentile(a,0.1),percentile(b,0.9),percentile(b,0.4)
from test_unique2 group by a order by 1,2,3,4,5,6;"
+}
\ No newline at end of file
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]