This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-2.1 by this push:
     new 7123bbfecb5 branch-2.1: [fix](nereids) fix distinct window compute wrong result (#48987) (#49010)
7123bbfecb5 is described below

commit 7123bbfecb5c95836c564593ace544a9d5c22ed5
Author: 924060929 <lanhuaj...@selectdb.com>
AuthorDate: Fri Mar 14 10:38:16 2025 +0800

    branch-2.1: [fix](nereids) fix distinct window compute wrong result (#48987) (#49010)
    
    cherry pick from #48987
---
 .../doris/nereids/jobs/executor/Analyzer.java      |   9 --
 .../rules/analysis/EliminateDistinctConstant.java  |  48 --------
 .../rules/analysis/ProjectToGlobalAggregate.java   | 126 +++++++++++++++++++--
 .../analysis/ProjectWithDistinctToAggregate.java   |  57 ----------
 .../analysis/ReplaceExpressionByChildOutput.java   |  34 +++++-
 .../nereids_p0/aggregate/select_distinct.groovy    |  48 ++++++++
 6 files changed, 192 insertions(+), 130 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
index 3a111a7f4d7..855c28a7a55 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Analyzer.java
@@ -30,7 +30,6 @@ import org.apache.doris.nereids.rules.analysis.CheckAnalysis;
 import org.apache.doris.nereids.rules.analysis.CheckPolicy;
 import org.apache.doris.nereids.rules.analysis.CollectJoinConstraint;
 import org.apache.doris.nereids.rules.analysis.CollectSubQueryAlias;
-import org.apache.doris.nereids.rules.analysis.EliminateDistinctConstant;
 import org.apache.doris.nereids.rules.analysis.EliminateGroupByConstant;
 import org.apache.doris.nereids.rules.analysis.EliminateLogicalSelectHint;
 import org.apache.doris.nereids.rules.analysis.FillUpMissingSlots;
@@ -40,7 +39,6 @@ import org.apache.doris.nereids.rules.analysis.NormalizeAggregate;
 import org.apache.doris.nereids.rules.analysis.NormalizeRepeat;
 import org.apache.doris.nereids.rules.analysis.OneRowRelationExtractAggregate;
 import org.apache.doris.nereids.rules.analysis.ProjectToGlobalAggregate;
-import org.apache.doris.nereids.rules.analysis.ProjectWithDistinctToAggregate;
 import org.apache.doris.nereids.rules.analysis.ReplaceExpressionByChildOutput;
 import org.apache.doris.nereids.rules.analysis.SubqueryToApply;
 import org.apache.doris.nereids.rules.analysis.VariableToLiteral;
@@ -105,13 +103,6 @@ public class Analyzer extends AbstractBatchJobExecutor {
             bottomUp(new AddInitMaterializationHook()),
             bottomUp(
                     new ProjectToGlobalAggregate(),
-                    // this rule check's the logicalProject node's isDistinct 
property
-                    // and replace the logicalProject node with a 
LogicalAggregate node
-                    // so any rule before this, if create a new logicalProject 
node
-                    // should make sure isDistinct property is correctly 
passed around.
-                    // please see rule BindSlotReference or BindFunction for 
example
-                    new EliminateDistinctConstant(),
-                    new ProjectWithDistinctToAggregate(),
                     new ReplaceExpressionByChildOutput(),
                     new OneRowRelationExtractAggregate()
             ),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateDistinctConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateDistinctConstant.java
deleted file mode 100644
index 0d051ee8c87..00000000000
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/EliminateDistinctConstant.java
+++ /dev/null
@@ -1,48 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.analysis;
-
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.plans.LimitPhase;
-import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-
-/**
- * EliminateDistinctConstant.
- * <p>
- * example sql:
- * <pre>
- * select distinct 1,2,3 from tbl
- *          =>
- * select 1,2,3 from (select 1, 2, 3 from tbl limit 1) as tmp
- *  </pre>
- */
-public class EliminateDistinctConstant extends OneAnalysisRuleFactory {
-    @Override
-    public Rule build() {
-        return RuleType.ELIMINATE_DISTINCT_CONSTANT.build(
-                logicalProject()
-                        .when(LogicalProject::isDistinct)
-                        .when(project -> 
project.getProjects().stream().allMatch(Expression::isConstant))
-                        .then(project -> new 
LogicalProject(project.getProjects(), new LogicalLimit<>(1, 0,
-                                LimitPhase.ORIGIN, project.child())))
-        );
-    }
-}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java
index da642e76610..a6916756c10 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectToGlobalAggregate.java
@@ -17,13 +17,24 @@
 
 package org.apache.doris.nereids.rules.analysis;
 
+import org.apache.doris.nereids.exceptions.AnalysisException;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitors;
+import org.apache.doris.nereids.trees.plans.LimitPhase;
+import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
+import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 
 import com.google.common.collect.ImmutableList;
 
+import java.util.List;
+
 /**
  * ProjectToGlobalAggregate.
  * <p>
@@ -43,17 +54,110 @@ public class ProjectToGlobalAggregate extends OneAnalysisRuleFactory {
     @Override
     public Rule build() {
         return RuleType.PROJECT_TO_GLOBAL_AGGREGATE.build(
-           logicalProject().then(project -> {
-               boolean needGlobalAggregate = project.getProjects()
-                       .stream()
-                       .anyMatch(p -> 
p.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null));
-
-               if (needGlobalAggregate) {
-                   return new LogicalAggregate<>(ImmutableList.of(), 
project.getProjects(), project.child());
-               } else {
-                   return project;
-               }
-           })
+            logicalProject().then(project -> {
+                project = distinctConstantsToLimit1(project);
+                Plan result = projectToAggregate(project);
+                return distinctToAggregate(result, project);
+            })
         );
     }
+
+    // select distinct 1,2,3 from tbl
+    //               ↓
+    // select 1,2,3 from (select 1, 2, 3 from tbl limit 1) as tmp
+    private static LogicalProject<Plan> 
distinctConstantsToLimit1(LogicalProject<Plan> project) {
+        if (!project.isDistinct()) {
+            return project;
+        }
+
+        boolean allSelectItemAreConstants = true;
+        for (NamedExpression selectItem : project.getProjects()) {
+            if (!selectItem.isConstant()) {
+                allSelectItemAreConstants = false;
+                break;
+            }
+        }
+
+        if (allSelectItemAreConstants) {
+            return new LogicalProject<>(
+                    project.getProjects(),
+                    new LogicalLimit<>(1, 0, LimitPhase.ORIGIN, 
project.child())
+            );
+        }
+        return project;
+    }
+
+    // select avg(xxx) from tbl
+    //         ↓
+    // LogicalAggregate(groupBy=[], output=[avg(xxx)])
+    private static Plan projectToAggregate(LogicalProject<Plan> project) {
+        // contains aggregate functions, like sum, avg ?
+        for (NamedExpression selectItem : project.getProjects()) {
+            if 
(selectItem.accept(ExpressionVisitors.CONTAINS_AGGREGATE_CHECKER, null)) {
+                return new LogicalAggregate<>(ImmutableList.of(), 
project.getProjects(), project.child());
+            }
+        }
+        return project;
+    }
+
+    private static Plan distinctToAggregate(Plan result, LogicalProject<Plan> 
originProject) {
+        if (!originProject.isDistinct()) {
+            return result;
+        }
+        if (result instanceof LogicalProject) {
+            // remove distinct: select distinct fun(xxx) as c1 from tbl
+            //
+            // LogicalProject(distinct=true, output=[fun(xxx) as c1])
+            //                  ↓
+            // LogicalAggregate(groupBy=[c1], output=[c1])
+            //                  |
+            //   LogicalProject(output=[fun(xxx) as c1])
+            LogicalProject<?> project = (LogicalProject<?>) result;
+
+            ImmutableList.Builder<NamedExpression> bottomProjectOutput
+                    = 
ImmutableList.builderWithExpectedSize(project.getProjects().size());
+            ImmutableList.Builder<NamedExpression> topAggOutput
+                    = 
ImmutableList.builderWithExpectedSize(project.getProjects().size());
+
+            boolean hasComplexExpr = false;
+            for (NamedExpression selectItem : project.getProjects()) {
+                if (selectItem.isSlot()) {
+                    topAggOutput.add(selectItem);
+                    bottomProjectOutput.add(selectItem);
+                } else if (isAliasLiteral(selectItem)) {
+                    // stay in agg, and eliminate by 
`ELIMINATE_GROUP_BY_CONSTANT`
+                    topAggOutput.add(selectItem);
+                } else {
+                    // `FillUpMissingSlots` not support find complex expr in 
aggregate,
+                    // so we should push down into the bottom project
+                    hasComplexExpr = true;
+                    topAggOutput.add(selectItem.toSlot());
+                    bottomProjectOutput.add(selectItem);
+                }
+            }
+
+            if (!hasComplexExpr) {
+                List<Slot> projects = (List) project.getProjects();
+                return new LogicalAggregate(projects, projects, 
project.child());
+            }
+
+            LogicalProject<?> removeDistinct = new 
LogicalProject<>(bottomProjectOutput.build(), project.child());
+            ImmutableList<NamedExpression> aggOutput = topAggOutput.build();
+            return new LogicalAggregate(aggOutput, aggOutput, removeDistinct);
+        } else if (result instanceof LogicalAggregate) {
+            // remove distinct: select distinct avg(xxx) as c1 from tbl
+            //
+            // LogicalProject(distinct=true, output=[avg(xxx) as c1])
+            //                  ↓
+            //  LogicalAggregate(output=[avg(xxx) as c1])
+            return result;
+        } else {
+            // never reach
+            throw new AnalysisException("Unsupported");
+        }
+    }
+
+    private static boolean isAliasLiteral(NamedExpression selectItem) {
+        return selectItem instanceof Alias && selectItem.child(0) instanceof 
Literal;
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectWithDistinctToAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectWithDistinctToAggregate.java
deleted file mode 100644
index f858820d612..00000000000
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ProjectWithDistinctToAggregate.java
+++ /dev/null
@@ -1,57 +0,0 @@
-// Licensed to the Apache Software Foundation (ASF) under one
-// or more contributor license agreements.  See the NOTICE file
-// distributed with this work for additional information
-// regarding copyright ownership.  The ASF licenses this file
-// to you under the Apache License, Version 2.0 (the
-// "License"); you may not use this file except in compliance
-// with the License.  You may obtain a copy of the License at
-//
-//   http://www.apache.org/licenses/LICENSE-2.0
-//
-// Unless required by applicable law or agreed to in writing,
-// software distributed under the License is distributed on an
-// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-// KIND, either express or implied.  See the License for the
-// specific language governing permissions and limitations
-// under the License.
-
-package org.apache.doris.nereids.rules.analysis;
-
-import org.apache.doris.nereids.rules.Rule;
-import org.apache.doris.nereids.rules.RuleType;
-import org.apache.doris.nereids.trees.expressions.Expression;
-import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
-import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
-
-/**
- * ProjectWithDistinctToAggregate.
- * <p>
- * example sql:
- * <pre>
- * select distinct value from tbl
- *
- * LogicalProject(projects=[distinct value])
- *            |
- * LogicalOlapScan(table=tbl)
- *          =>
- * LogicalAggregate(groupBy=[value], output=[value])
- *           |
- * LogicalOlapScan(table=tbl)
- *  </pre>
- */
-public class ProjectWithDistinctToAggregate extends OneAnalysisRuleFactory {
-    @Override
-    public Rule build() {
-        return RuleType.PROJECT_WITH_DISTINCT_TO_AGGREGATE.build(
-            logicalProject()
-                .when(LogicalProject::isDistinct)
-                .whenNot(project -> 
project.getProjects().stream().anyMatch(this::hasAggregateFunction))
-                .then(project -> new LogicalAggregate<>(project.getProjects(), 
false, project.child()))
-        );
-    }
-
-    private boolean hasAggregateFunction(Expression expression) {
-        return expression.anyMatch(AggregateFunction.class::isInstance);
-    }
-}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java
index cd53086f966..5dc85811d50 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/ReplaceExpressionByChildOutput.java
@@ -53,21 +53,27 @@ public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory {
                 ))
                 .add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
                         logicalSort(logicalAggregate()).then(sort -> {
-                            LogicalAggregate<Plan> aggregate = sort.child();
-                            Map<Expression, Slot> sMap = 
buildOutputAliasMap(aggregate.getOutputExpressions());
+                            LogicalAggregate<Plan> agg = sort.child();
+                            Map<Expression, Slot> sMap = 
buildOutputAliasMap(agg.getOutputExpressions());
+                            if (sMap.isEmpty() && isSelectDistinct(agg)) {
+                                sMap = getSelectDistinctExpressions(agg);
+                            }
                             return replaceSortExpression(sort, sMap);
                         })
                 )).add(RuleType.REPLACE_SORT_EXPRESSION_BY_CHILD_OUTPUT.build(
                         
logicalSort(logicalHaving(logicalAggregate())).then(sort -> {
-                            LogicalAggregate<Plan> aggregate = 
sort.child().child();
-                            Map<Expression, Slot> sMap = 
buildOutputAliasMap(aggregate.getOutputExpressions());
+                            LogicalAggregate<Plan> agg = sort.child().child();
+                            Map<Expression, Slot> sMap = 
buildOutputAliasMap(agg.getOutputExpressions());
+                            if (sMap.isEmpty() && isSelectDistinct(agg)) {
+                                sMap = getSelectDistinctExpressions(agg);
+                            }
                             return replaceSortExpression(sort, sMap);
                         })
                 ))
                 .build();
     }
 
-    private Map<Expression, Slot> buildOutputAliasMap(List<NamedExpression> 
output) {
+    private static Map<Expression, Slot> 
buildOutputAliasMap(List<NamedExpression> output) {
         Map<Expression, Slot> sMap = 
Maps.newHashMapWithExpectedSize(output.size());
         for (NamedExpression expr : output) {
             if (expr instanceof Alias) {
@@ -93,4 +99,22 @@ public class ReplaceExpressionByChildOutput implements AnalysisRuleFactory {
 
         return changed ? new LogicalSort<>(newKeys.build(), sort.child()) : 
sort;
     }
+
+    private static boolean isSelectDistinct(LogicalAggregate<? extends Plan> 
agg) {
+        return agg.getGroupByExpressions().equals(agg.getOutputExpressions())
+                && agg.getGroupByExpressions().equals(agg.child().getOutput());
+    }
+
+    private static Map<Expression, Slot> 
getSelectDistinctExpressions(LogicalAggregate<? extends Plan> agg) {
+        Plan child = agg.child();
+        List<NamedExpression> selectItems;
+        if (child instanceof LogicalProject) {
+            selectItems = ((LogicalProject<?>) child).getProjects();
+        } else if (child instanceof LogicalAggregate) {
+            selectItems = ((LogicalAggregate<?>) child).getOutputExpressions();
+        } else {
+            selectItems = ImmutableList.of();
+        }
+        return buildOutputAliasMap(selectItems);
+    }
 }
diff --git a/regression-test/suites/nereids_p0/aggregate/select_distinct.groovy b/regression-test/suites/nereids_p0/aggregate/select_distinct.groovy
new file mode 100644
index 00000000000..ddcd2d47e82
--- /dev/null
+++ b/regression-test/suites/nereids_p0/aggregate/select_distinct.groovy
@@ -0,0 +1,48 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *  http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+suite("select_distinct") {
+    multi_sql """
+        SET enable_nereids_planner=true;
+        SET enable_fallback_to_original_planner=false;
+        drop table if exists test_distinct_window;
+        create table test_distinct_window(id int) distributed by hash(id) 
properties('replication_num'='1');
+        insert into test_distinct_window values(1), (2), (3);
+        """
+
+    test {
+        sql "select distinct sum(value) over(partition by id) from (select 100 
value, 1 id union all select 100, 2)a"
+        result([[100L]])
+    }
+
+    test {
+        sql "select distinct value+1 from (select 100 value, 1 id union all 
select 100, 2)a order by value+1"
+        result([[101]])
+    }
+
+    test {
+        sql "select distinct 1, 2, 3 from test_distinct_window"
+        result([[1, 2, 3]])
+    }
+
+    test {
+        sql "select distinct sum(id) from test_distinct_window"
+        result([[6L]])
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to