This is an automated email from the ASF dual-hosted git repository.

englefly pushed a commit to branch branch-3.0
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/branch-3.0 by this push:
     new e6678e682ba [fix](Nereids) set correct sort key for aggregate #45369 branch-3.0 (#45706)
e6678e682ba is described below

commit e6678e682ba57e38969c250971a23231fb762923
Author: minghong <zhoumingh...@selectdb.com>
AuthorDate: Sun Jan 5 17:35:30 2025 +0800

    [fix](Nereids) set correct sort key for aggregate #45369 branch-3.0 (#45706)
    
    ### What problem does this PR solve?
    
    pick #45369
---
 .../nereids/processor/post/PushTopnToAgg.java      | 159 ++++-----------
 .../nereids/rules/rewrite/LimitAggToTopNAgg.java   | 222 ++++++++++++++-------
 .../trees/plans/logical/LogicalAggregate.java      |   5 +
 .../data/query_p0/limit/test_group_by_limit.out    |  26 +--
 .../nereids_tpch_p0/tpch/push_topn_to_agg.groovy   | 111 +++--------
 .../query_p0/limit/test_group_by_limit.groovy      | 101 +++++++---
 6 files changed, 307 insertions(+), 317 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java
index aca3f21a7d1..d93433768c4 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java
@@ -21,157 +21,84 @@
 package org.apache.doris.nereids.processor.post;
 
 import org.apache.doris.nereids.CascadesContext;
-import org.apache.doris.nereids.properties.DistributionSpecGather;
-import org.apache.doris.nereids.properties.OrderKey;
 import org.apache.doris.nereids.trees.expressions.Expression;
-import org.apache.doris.nereids.trees.plans.AggMode;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate;
 import 
org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopnPushInfo;
-import org.apache.doris.nereids.trees.plans.physical.PhysicalLimit;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalProject;
 import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN;
 import org.apache.doris.qe.ConnectContext;
 
-import org.apache.hadoop.util.Lists;
-
-import java.util.List;
-import java.util.stream.Collectors;
-
 /**
- * Add SortInfo to Agg. This SortInfo is used as boundary, not used to sort 
elements.
+ * Add TopNInfo to Agg. This TopNInfo is used as boundary, not used to sort 
elements.
  * example
  * sql: select count(*) from orders group by o_clerk order by o_clerk limit 1;
  * plan: topn(1) -> aggGlobal -> shuffle -> aggLocal -> scan
  * optimization: aggLocal and aggGlobal only need to generate the smallest row 
with respect to o_clerk.
  *
- * TODO: the following case is not covered:
- * sql: select sum(o_shippriority) from orders group by o_clerk limit 1;
- * plan: limit -> aggGlobal -> shuffle -> aggLocal -> scan
- * aggGlobal may receive partial aggregate results, and hence is not supported 
now
- * instance1: input (key=2, v=1) => localAgg => (2, 1) => aggGlobal inst1 => 
(2, 1)
- * instance2: input (key=1, v=1), (key=2, v=2) => localAgg inst2 => (1, 1)
- * (2,1),(1,1) => limit => may output (2, 1), which is not complete, missing 
(2, 2) in instance2
- *
- *TOPN:
- *  Precondition: topn orderkeys are the prefix of group keys
- *  TODO: topnKeys could be subset of groupKeys. This will be implemented in 
future
- *  Pattern 2-phase agg:
- *     topn -> aggGlobal -> distribute -> aggLocal
- *     =>
- *     topn(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n)
- *  Pattern 1-phase agg:
- *     topn->agg->Any(not agg) -> topn -> agg(topn=n) -> any
- *
- *  LIMIT:
- *  Pattern 1: limit->agg(1phase)->any
- *  Pattern 2: limit->agg(global)->gather->agg(local)
+ * This rule only applies to the patterns
+ * 1. topn->project->agg, or
+ * 2. topn->agg
+ * where
+ * 1. the order keys and the group keys map one-to-one, and
+ * 2. the aggregate is not a scalar aggregate.
+ * Refer to the LimitAggToTopNAgg rule.
  */
 public class PushTopnToAgg extends PlanPostProcessor {
     @Override
     public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, 
CascadesContext ctx) {
         topN.child().accept(this, ctx);
-        if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= 
topN.getLimit() + topN.getOffset()) {
+        if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= 
topN.getLimit() + topN.getOffset()
+                && !ConnectContext.get().getSessionVariable().pushTopnToAgg) {
             return topN;
         }
-        Plan topnChild = topN.child();
-        if (topnChild instanceof PhysicalProject) {
-            topnChild = topnChild.child(0);
+        Plan topNChild = topN.child();
+        if (topNChild instanceof PhysicalProject) {
+            topNChild = topNChild.child(0);
         }
-        if (topnChild instanceof PhysicalHashAggregate) {
-            PhysicalHashAggregate<? extends Plan> upperAgg = 
(PhysicalHashAggregate<? extends Plan>) topnChild;
-            List<OrderKey> orderKeys = 
tryGenerateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg);
-            if (!orderKeys.isEmpty()) {
-
-                if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() 
== AggMode.BUFFER_TO_RESULT) {
-                    upperAgg.setTopnPushInfo(new TopnPushInfo(
-                            orderKeys,
-                            topN.getLimit() + topN.getOffset()));
-                    if (upperAgg.child() instanceof PhysicalDistribute
-                            && upperAgg.child().child(0) instanceof 
PhysicalHashAggregate) {
-                        PhysicalHashAggregate<? extends Plan> bottomAgg =
-                                (PhysicalHashAggregate<? extends Plan>) 
upperAgg.child().child(0);
+        if (topNChild instanceof PhysicalHashAggregate) {
+            PhysicalHashAggregate<? extends Plan> upperAgg = 
(PhysicalHashAggregate<? extends Plan>) topNChild;
+            if (isGroupKeyIdenticalToOrderKey(topN, upperAgg)) {
+                upperAgg.setTopnPushInfo(new TopnPushInfo(
+                        topN.getOrderKeys(),
+                        topN.getLimit() + topN.getOffset()));
+                if (upperAgg.child() instanceof PhysicalDistribute
+                        && upperAgg.child().child(0) instanceof 
PhysicalHashAggregate) {
+                    PhysicalHashAggregate<? extends Plan> bottomAgg =
+                            (PhysicalHashAggregate<? extends Plan>) 
upperAgg.child().child(0);
+                    if (isGroupKeyIdenticalToOrderKey(topN, bottomAgg)) {
+                        bottomAgg.setTopnPushInfo(new TopnPushInfo(
+                                topN.getOrderKeys(),
+                                topN.getLimit() + topN.getOffset()));
+                    }
+                } else if (upperAgg.child() instanceof PhysicalHashAggregate) {
+                    // multi-distinct plan
+                    PhysicalHashAggregate<? extends Plan> bottomAgg =
+                            (PhysicalHashAggregate<? extends Plan>) 
upperAgg.child();
+                    if (isGroupKeyIdenticalToOrderKey(topN, bottomAgg)) {
                         bottomAgg.setTopnPushInfo(new TopnPushInfo(
-                                orderKeys,
+                                topN.getOrderKeys(),
                                 topN.getLimit() + topN.getOffset()));
                     }
-                } else if (upperAgg.getAggPhase().isLocal() && 
upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
-                    // one phase agg
-                    upperAgg.setTopnPushInfo(new TopnPushInfo(
-                            orderKeys,
-                            topN.getLimit() + topN.getOffset()));
                 }
             }
         }
         return topN;
     }
 
-    /**
-     return true, if topn order-key is prefix of agg group-key, ignore 
asc/desc and null_first
-     TODO order-key can be subset of group-key. BE does not support now.
-     */
-    private List<OrderKey> 
tryGenerateOrderKeyByGroupKeyAndTopnKey(PhysicalTopN<? extends Plan> topN,
+    private boolean isGroupKeyIdenticalToOrderKey(PhysicalTopN<? extends Plan> 
topN,
             PhysicalHashAggregate<? extends Plan> agg) {
-        List<OrderKey> orderKeys = 
Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
-        if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) {
-            return orderKeys;
-        }
-        List<Expression> topnKeys = topN.getOrderKeys().stream()
-                .map(OrderKey::getExpr).collect(Collectors.toList());
-        for (int i = 0; i < topN.getOrderKeys().size(); i++) {
-            // prefix check
-            if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) {
-                return Lists.newArrayList();
-            }
-            orderKeys.add(topN.getOrderKeys().get(i));
-        }
-        for (int i = topN.getOrderKeys().size(); i < 
agg.getGroupByExpressions().size(); i++) {
-            orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), 
true, false));
-        }
-        return orderKeys;
-    }
-
-    @Override
-    public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, 
CascadesContext ctx) {
-        limit.child().accept(this, ctx);
-        if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= 
limit.getLimit() + limit.getOffset()) {
-            return limit;
+        if (topN.getOrderKeys().size() != agg.getGroupByExpressions().size()) {
+            return false;
         }
-        Plan limitChild = limit.child();
-        if (limitChild instanceof PhysicalProject) {
-            limitChild = limitChild.child(0);
-        }
-        if (limitChild instanceof PhysicalHashAggregate) {
-            PhysicalHashAggregate<? extends Plan> upperAgg = 
(PhysicalHashAggregate<? extends Plan>) limitChild;
-            if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == 
AggMode.BUFFER_TO_RESULT) {
-                Plan child = upperAgg.child();
-                Plan grandChild = child.child(0);
-                if (child instanceof PhysicalDistribute
-                        && ((PhysicalDistribute<?>) 
child).getDistributionSpec() instanceof DistributionSpecGather
-                        && grandChild instanceof PhysicalHashAggregate) {
-                    upperAgg.setTopnPushInfo(new TopnPushInfo(
-                            generateOrderKeyByGroupKey(upperAgg),
-                            limit.getLimit() + limit.getOffset()));
-                    PhysicalHashAggregate<? extends Plan> bottomAgg =
-                            (PhysicalHashAggregate<? extends Plan>) grandChild;
-                    bottomAgg.setTopnPushInfo(new TopnPushInfo(
-                            generateOrderKeyByGroupKey(bottomAgg),
-                            limit.getLimit() + limit.getOffset()));
-                }
-            } else if (upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) {
-                // 1-phase agg
-                upperAgg.setTopnPushInfo(new TopnPushInfo(
-                        generateOrderKeyByGroupKey(upperAgg),
-                        limit.getLimit() + limit.getOffset()));
+        for (int i = 0; i < agg.getGroupByExpressions().size(); i++) {
+            Expression groupByKey = agg.getGroupByExpressions().get(i);
+            Expression orderKey = topN.getOrderKeys().get(i).getExpr();
+            if (!groupByKey.equals(orderKey)) {
+                return false;
             }
         }
-        return limit;
-    }
-
-    private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? 
extends Plan> agg) {
-        return agg.getGroupByExpressions().stream()
-                .map(key -> new OrderKey(key, true, false))
-                .collect(Collectors.toList());
+        return true;
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java
index dfa1230a8f8..049709dd23a 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java
@@ -17,24 +17,29 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.properties.OrderKey;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.trees.expressions.Alias;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
+import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.plans.Plan;
 import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
-import org.apache.doris.nereids.trees.plans.logical.LogicalLimit;
 import org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalTopN;
 import org.apache.doris.qe.ConnectContext;
 
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
 
+import java.util.HashMap;
 import java.util.List;
-import java.util.Optional;
+import java.util.Map;
+import java.util.Set;
 import java.util.stream.Collectors;
 
 /**
@@ -45,6 +50,7 @@ import java.util.stream.Collectors;
  * 2. push limit to local agg
  */
 public class LimitAggToTopNAgg implements RewriteRuleFactory {
+
     @Override
     public List<Rule> buildRules() {
         return ImmutableList.of(
@@ -54,109 +60,171 @@ public class LimitAggToTopNAgg implements 
RewriteRuleFactory {
                                 && 
ConnectContext.get().getSessionVariable().pushTopnToAgg
                                 && 
ConnectContext.get().getSessionVariable().topnOptLimitThreshold
                                 >= limit.getLimit() + limit.getOffset())
+                        .when(limit -> {
+                            LogicalAggregate<? extends Plan> agg = 
limit.child();
+                            return !agg.getGroupByExpressions().isEmpty() && 
!agg.getSourceRepeat().isPresent();
+                        })
                         .then(limit -> {
                             LogicalAggregate<? extends Plan> agg = 
limit.child();
-                            Optional<OrderKey> orderKeysOpt = 
tryGenerateOrderKeyByTheFirstGroupKey(agg);
-                            if (!orderKeysOpt.isPresent()) {
-                                return null;
-                            }
-                            List<OrderKey> orderKeys = 
Lists.newArrayList(orderKeysOpt.get());
+                            List<OrderKey> orderKeys = 
generateOrderKeyByGroupKey(agg);
                             return new LogicalTopN<>(orderKeys, 
limit.getLimit(), limit.getOffset(), agg);
                         }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
-                //limit->project->agg to topn->project->agg
+                //limit->project->agg to project->topn->agg
                 logicalLimit(logicalProject(logicalAggregate()))
                         .when(limit -> ConnectContext.get() != null
                                 && 
ConnectContext.get().getSessionVariable().pushTopnToAgg
                                 && 
ConnectContext.get().getSessionVariable().topnOptLimitThreshold
                                 >= limit.getLimit() + limit.getOffset())
+                        .when(limit -> {
+                            LogicalAggregate<? extends Plan> agg = 
limit.child().child();
+                            return !agg.getGroupByExpressions().isEmpty() && 
!agg.getSourceRepeat().isPresent();
+                        })
                         .then(limit -> {
                             LogicalProject<? extends Plan> project = 
limit.child();
-                            LogicalAggregate<? extends Plan> agg
-                                    = (LogicalAggregate<? extends Plan>) 
project.child();
-                            Optional<OrderKey> orderKeysOpt = 
tryGenerateOrderKeyByTheFirstGroupKey(agg);
-                            if (!orderKeysOpt.isPresent()) {
-                                return null;
-                            }
-                            List<OrderKey> orderKeys = 
Lists.newArrayList(orderKeysOpt.get());
-                            Plan result;
-
-                            if (outputAllGroupKeys(limit, agg)) {
-                                result = new LogicalTopN<>(orderKeys, 
limit.getLimit(),
-                                        limit.getOffset(), project);
-                            } else {
-                                // add the first group by key to topn, and 
prune this key by upper project
-                                // topn order keys are prefix of group by keys
-                                // refer to 
PushTopnToAgg.tryGenerateOrderKeyByGroupKeyAndTopnKey()
-                                Expression firstGroupByKey = 
agg.getGroupByExpressions().get(0);
-                                if (!(firstGroupByKey instanceof 
SlotReference)) {
-                                    return null;
-                                }
-                                boolean shouldPruneFirstGroupByKey = true;
-                                if 
(project.getOutputs().contains(firstGroupByKey)) {
-                                    shouldPruneFirstGroupByKey = false;
-                                } else {
-                                    List<NamedExpression> bottomProjections = 
Lists.newArrayList(project.getProjects());
-                                    bottomProjections.add((SlotReference) 
firstGroupByKey);
-                                    project = 
project.withProjects(bottomProjections);
-                                }
-                                LogicalTopN topn = new 
LogicalTopN<>(orderKeys, limit.getLimit(),
-                                        limit.getOffset(), project);
-                                if (shouldPruneFirstGroupByKey) {
-                                    List<NamedExpression> limitOutput = 
limit.getOutput().stream()
-                                            .map(e -> (NamedExpression) 
e).collect(Collectors.toList());
-                                    result = new LogicalProject<>(limitOutput, 
topn);
-                                } else {
-                                    result = topn;
-                                }
-                            }
-                            return result;
+                            LogicalAggregate<? extends Plan> agg = 
(LogicalAggregate<? extends Plan>) project.child();
+                            List<OrderKey> orderKeys = 
generateOrderKeyByGroupKey(agg);
+                            LogicalTopN topn = new LogicalTopN<>(orderKeys, 
limit.getLimit(),
+                                    limit.getOffset(), agg);
+                            project = (LogicalProject<? extends Plan>) 
project.withChildren(topn);
+                            return project;
                         }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
-                // topn -> agg: add all group key to sort key, if sort key is 
prefix of group key
+                // topn -> agg: append group key(if it is not sort key) to 
sort key
                 logicalTopN(logicalAggregate())
                         .when(topn -> ConnectContext.get() != null
                                 && 
ConnectContext.get().getSessionVariable().pushTopnToAgg
                                 && 
ConnectContext.get().getSessionVariable().topnOptLimitThreshold
                                 >= topn.getLimit() + topn.getOffset())
+                        .when(topn -> {
+                            LogicalAggregate<? extends Plan> agg = 
topn.child();
+                            return !agg.getGroupByExpressions().isEmpty() && 
!agg.getSourceRepeat().isPresent();
+                        })
+                        .then(topn -> {
+                            LogicalAggregate<? extends Plan> agg = 
topn.child();
+                            Pair<List<OrderKey>, List<Expression>> pair =
+                                    
supplementOrderKeyByGroupKeyIfCompatible(topn, agg);
+                            if (pair != null) {
+                                agg = agg.withGroupBy(pair.second);
+                                topn = (LogicalTopN) topn.withChildren(agg);
+                                topn = (LogicalTopN) 
topn.withOrderKeys(pair.first);
+                            }
+                            return topn;
+                        }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG),
+                //topn -> project ->agg: add all group key to sort key, and 
prune column
+                logicalTopN(logicalProject(logicalAggregate()))
+                        .when(topn -> ConnectContext.get() != null
+                                && 
ConnectContext.get().getSessionVariable().pushTopnToAgg
+                                && 
ConnectContext.get().getSessionVariable().topnOptLimitThreshold
+                                >= topn.getLimit() + topn.getOffset())
+                        .when(topn -> {
+                            LogicalAggregate<? extends Plan> agg = 
topn.child().child();
+                            return !agg.getGroupByExpressions().isEmpty() && 
!agg.getSourceRepeat().isPresent();
+                        })
                         .then(topn -> {
-                            LogicalAggregate<? extends Plan> agg = 
(LogicalAggregate<? extends Plan>) topn.child();
-                            List<OrderKey> newOrders = 
tryGenerateOrderKeyByGroupKeyAndTopnKey(topn, agg);
-                            if (newOrders.isEmpty()) {
-                                return topn;
+                            LogicalTopN originTopn = topn;
+                            LogicalProject<? extends Plan> project = 
topn.child();
+                            LogicalAggregate<? extends Plan> agg = 
(LogicalAggregate) project.child();
+                            if (!project.isAllSlots()) {
+                                /*
+                                    topn(orderKey=[a])
+                                        +-->project(b as a)
+                                           +--> agg(groupKey[b]
+                                     =>
+                                     topn(orderKey=[b])
+                                        +-->project(b as a)
+                                           +-->agg(groupKey[b])
+                                     and then exchange topn and project
+                                 */
+                                Map<SlotReference, SlotReference> keyAsKey = 
new HashMap<>();
+                                for (NamedExpression e : 
project.getProjects()) {
+                                    if (e instanceof Alias && e.child(0) 
instanceof SlotReference) {
+                                        keyAsKey.put((SlotReference) 
e.toSlot(), (SlotReference) e.child(0));
+                                    }
+                                }
+                                List<OrderKey> projectOrderKeys = 
Lists.newArrayList();
+                                boolean hasNew = false;
+                                for (OrderKey orderKey : topn.getOrderKeys()) {
+                                    if 
(keyAsKey.containsKey(orderKey.getExpr())) {
+                                        
projectOrderKeys.add(orderKey.withExpression(keyAsKey.get(orderKey.getExpr())));
+                                        hasNew = true;
+                                    } else {
+                                        projectOrderKeys.add(orderKey);
+                                    }
+                                }
+                                if (hasNew) {
+                                    topn = (LogicalTopN) 
topn.withOrderKeys(projectOrderKeys);
+                                }
+                            }
+                            Pair<List<OrderKey>, List<Expression>> pair =
+                                    
supplementOrderKeyByGroupKeyIfCompatible(topn, agg);
+                            Plan result;
+                            if (pair == null) {
+                                result = originTopn;
                             } else {
-                                return topn.withOrderKeys(newOrders);
+                                agg = agg.withGroupBy(pair.second);
+                                topn = (LogicalTopN) 
topn.withOrderKeys(pair.first);
+                                if (isOrderKeysInProject(topn, project)) {
+                                    project = (LogicalProject<? extends Plan>) 
project.withChildren(agg);
+                                    topn = 
(LogicalTopN<LogicalProject<LogicalAggregate<Plan>>>)
+                                            topn.withChildren(project);
+                                    result = topn;
+                                } else {
+                                    topn = (LogicalTopN) 
topn.withChildren(agg);
+                                    project = (LogicalProject<? extends Plan>) 
project.withChildren(topn);
+                                    result = project;
+                                }
                             }
-                        }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG));
+                            return result;
+                        }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG)
+        );
     }
 
-    private List<OrderKey> 
tryGenerateOrderKeyByGroupKeyAndTopnKey(LogicalTopN<? extends Plan> topN,
-            LogicalAggregate<? extends Plan> agg) {
-        List<OrderKey> orderKeys = 
Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
-        if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) {
-            return orderKeys;
-        }
-        List<Expression> topnKeys = topN.getOrderKeys().stream()
-                .map(OrderKey::getExpr).collect(Collectors.toList());
-        for (int i = 0; i < topN.getOrderKeys().size(); i++) {
-            // prefix check
-            if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) {
-                return Lists.newArrayList();
+    private boolean isOrderKeysInProject(LogicalTopN<? extends Plan> topn, 
LogicalProject project) {
+        Set<Slot> projectSlots = project.getOutputSet();
+        for (OrderKey orderKey : topn.getOrderKeys()) {
+            if (!projectSlots.contains(orderKey.getExpr())) {
+                return false;
             }
-            orderKeys.add(topN.getOrderKeys().get(i));
-        }
-        for (int i = topN.getOrderKeys().size(); i < 
agg.getGroupByExpressions().size(); i++) {
-            orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), 
true, false));
         }
-        return orderKeys;
+        return true;
     }
 
-    private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate 
agg) {
-        return limit.getOutputSet().containsAll(agg.getGroupByExpressions());
+    private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? 
extends Plan> agg) {
+        return agg.getGroupByExpressions().stream()
+                .map(key -> new OrderKey(key, true, false))
+                .collect(Collectors.toList());
     }
 
-    private Optional<OrderKey> 
tryGenerateOrderKeyByTheFirstGroupKey(LogicalAggregate<? extends Plan> agg) {
-        if (agg.getGroupByExpressions().isEmpty()) {
-            return Optional.empty();
+    /**
+     * compatible: if order key is subset of group by keys
+     * example:
+     * 1. orderKey[a, b], groupKeys[b, a, c]
+     *    compatible, return Pair(orderKey[a, b, c], groupKey[a, b, c])
+     * 2. orderKey[a, b+1], groupKeys[a, b]
+     *    not compatible, return null
+     */
+    private Pair<List<OrderKey>, List<Expression>> 
supplementOrderKeyByGroupKeyIfCompatible(
+            LogicalTopN<? extends Plan> topn, LogicalAggregate<? extends Plan> 
agg) {
+        Set<Expression> groupKeySet = 
Sets.newHashSet(agg.getGroupByExpressions());
+        List<Expression> orderKeyList = topn.getOrderKeys().stream()
+                .map(OrderKey::getExpr).collect(Collectors.toList());
+        Set<Expression> orderKeySet = Sets.newHashSet(orderKeyList);
+        boolean compatible = groupKeySet.containsAll(orderKeyList);
+        if (compatible) {
+            List<OrderKey> newOrderKeys = 
Lists.newArrayList(topn.getOrderKeys());
+            List<Expression> newGroupExpressions = 
Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size());
+            for (OrderKey orderKey : newOrderKeys) {
+                newGroupExpressions.add(orderKey.getExpr());
+            }
+
+            for (Expression groupKey : agg.getGroupByExpressions()) {
+                if (!orderKeySet.contains(groupKey)) {
+                    newOrderKeys.add(new OrderKey(groupKey, true, false));
+                    newGroupExpressions.add(groupKey);
+                }
+            }
+            return Pair.of(newOrderKeys, newGroupExpressions);
+        } else {
+            return null;
         }
-        return Optional.of(new OrderKey(agg.getGroupByExpressions().get(0), 
true, false));
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index 31cee19cc43..df8f886451f 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -263,6 +263,11 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
                 hasPushed, sourceRepeat, groupExpression, 
Optional.of(getLogicalProperties()), children.get(0));
     }
 
+    public LogicalAggregate<Plan> withGroupBy(List<Expression> 
groupByExprList) {
+        return new LogicalAggregate<>(groupByExprList, outputExpressions, 
normalized, ordinalIsResolved, generated,
+                hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), 
child());
+    }
+
     public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> 
groupByExprList,
             List<NamedExpression> outputExpressionList) {
         return new LogicalAggregate<>(groupByExprList, outputExpressionList, 
normalized, ordinalIsResolved, generated,
diff --git a/regression-test/data/query_p0/limit/test_group_by_limit.out 
b/regression-test/data/query_p0/limit/test_group_by_limit.out
index d9ac2a2481a..4a396cdaca8 100644
--- a/regression-test/data/query_p0/limit/test_group_by_limit.out
+++ b/regression-test/data/query_p0/limit/test_group_by_limit.out
@@ -1,65 +1,65 @@
 -- This file is automatically generated. You should know what you did if you want to edit this
--- !select --
+-- !select1 --
 253967024      8491    AIR
 259556658      8641    FOB
 260402265      8669    MAIL
 
--- !select --
+-- !select2 --
 449872500      15000   1
 386605746      12900   2
 320758616      10717   3
 
--- !select --
+-- !select3 --
 198674527      6588    0.0
 198679731      6563    0.01
 198501055      6622    0.02
 
--- !select --
+-- !select4 --
 27137  1       1992-02-02
 45697  1       1992-02-04
 114452 5       1992-02-05
 
--- !select --
+-- !select5 --
 27137  1       1992-02-02T00:00
 45697  1       1992-02-04T00:00
 114452 5       1992-02-05T00:00
 
--- !select --
+-- !select6 --
 139015016      4632    1
 130287219      4313    2
 162309750      5334    3
 
--- !select --
+-- !select7 --
 64774969       2166    AIR     1
 54166166       1804    AIR     2
 45538267       1532    AIR     3
 
--- !select --
+-- !select8 --
 6882631        228     AIR     1       0.0
 6756423        228     AIR     1       0.01
 7920028        254     AIR     1       0.02
 
--- !select --
+-- !select9 --
 7618   1       AIR     1       0.0     1992-02-06
 2210   1       AIR     1       0.0     1992-03-24
 16807  1       AIR     1       0.0     1992-03-29
 
--- !select --
+-- !select10 --
 6882631        228     AIR     1       0.0
 6756423        228     AIR     1       0.01
 7920028        254     AIR     1       0.02
 
--- !select --
+-- !select11 --
 6882631        228     AIR     1       0.0
 6756423        228     AIR     1       0.01
 7920028        254     AIR     1       0.02
 
--- !select --
+-- !select12 --
 7707018        238     TRUCK   1       0.0
 7467045        233     TRUCK   1       0.01
 6927206        245     TRUCK   1       0.02
 
--- !select --
+-- !select13 --
 7661562        249     TRUCK   1       0.08
 6673139        228     TRUCK   1       0.07
 8333862        265     TRUCK   1       0.06
diff --git a/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy b/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy
index 631656a6b19..5ae587910b6 100644
--- a/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy
+++ b/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy
@@ -50,22 +50,31 @@ suite("push_topn_to_agg") {
         notContains("STREAMING")
     }
 
-    // order key should be prefix of group key
+    // order keys are part of group keys, 
+    // 1. adjust group keys (o_custkey, o_clerk) -> o_clerk, o_custkey
+    // 2. append o_custkey to order key 
     explain{
-        sql "select o_custkey, sum(o_shippriority), o_clerk  from orders group 
by o_custkey, o_clerk order by o_clerk, o_custkey limit 11;"
-        multiContains("sortByGroupKey:false", 2)
+        sql "select sum(o_shippriority)  from orders group by o_custkey, o_clerk order by o_clerk limit 11;"
+        contains("sortByGroupKey:true")
+        contains("group by: o_clerk[#10], o_custkey[#9]")
+        contains("order by: o_clerk[#18] ASC, o_custkey[#19] ASC")
     }
 
-    // order key should be prefix of group key
-    explain{
-        sql "select o_custkey, o_clerk, sum(o_shippriority) as x from orders 
group by o_custkey, o_clerk order by o_custkey, x limit 12;"
-        multiContains("sortByGroupKey:false", 2)
+
+    // one distinct 
+    explain {
+        sql "select sum(distinct o_shippriority) from orders group by o_orderkey limit 13; "
+        contains("VTOP-N")
+        contains("order by: o_orderkey")
+        multiContains("sortByGroupKey:true", 1)
     }
 
-    // one phase agg is optimized 
+    // multi distinct 
     explain {
-        sql "select sum(o_shippriority) from orders group by o_orderkey limit 
13; "
-        contains("sortByGroupKey:true")
+        sql "select count(distinct o_clerk), sum(distinct o_shippriority) from orders group by o_orderkey limit 14; "
+        contains("VTOP-N")
+        contains("order by: o_orderkey")
+        multiContains("sortByGroupKey:true", 2)
     }
 
     // use group key as sort key to enable topn-push opt
@@ -74,22 +83,17 @@ suite("push_topn_to_agg") {
         contains("sortByGroupKey:true")
     }
 
-    // group key is part of output of limit, apply opt
+    // group key is expression
     explain {
-        sql "select sum(o_shippriority), o_clerk from orders group by o_clerk 
limit 15; "
+        sql "select sum(o_shippriority), o_clerk+1 from orders group by o_clerk+1 limit 15; "
         contains("sortByGroupKey:true")
     }
 
-    // order key is not prefix of group key
+    // order key is not part of group key
     explain {
         sql "select o_custkey, sum(o_shippriority) from orders group by 
o_custkey order by o_custkey+1 limit 16; "
         contains("sortByGroupKey:false")
-    }
-
-    // order key is not prefix of group key
-    explain {
-        sql "select o_custkey, sum(o_shippriority) from orders group by 
o_custkey order by o_custkey+1 limit 17; "
-        contains("sortByGroupKey:false")
+        notContains("sortByGroupKey:true")
     }
 
     // topn + one phase agg
@@ -97,73 +101,4 @@ suite("push_topn_to_agg") {
         sql "select sum(ps_availqty), ps_partkey, ps_suppkey from partsupp 
group by ps_partkey, ps_suppkey order by ps_partkey, ps_suppkey limit 18;"
         contains("sortByGroupKey:true")
     }
-
-    // sort key is prefix of group key, make all group key to sort 
key(ps_suppkey) and then apply push-topn-agg rule
-    explain {
-        sql "select sum(ps_availqty), ps_partkey, ps_suppkey from partsupp 
group by ps_partkey, ps_suppkey order by ps_partkey limit 19;"
-        contains("sortByGroupKey:true")
-    }
-
-    explain {
-        sql "select sum(ps_availqty), ps_suppkey, ps_availqty from partsupp 
group by ps_suppkey, ps_availqty order by ps_suppkey limit 19;"
-        contains("sortByGroupKey:true")
-    }
-
-    // sort key is not prefix of group key, deny
-    explain {
-        sql "select sum(ps_availqty), ps_partkey, ps_suppkey from partsupp 
group by ps_partkey, ps_suppkey order by ps_suppkey limit 20;"
-        contains("sortByGroupKey:false")
-    }
-
-    multi_sql """
-    drop table if exists t1;
-    CREATE TABLE IF NOT EXISTS t1
-        (
-        k1 TINYINT
-        )
-        ENGINE=olap
-        AGGREGATE KEY(k1)
-        DISTRIBUTED BY HASH(k1) BUCKETS 1
-        PROPERTIES (
-            "replication_num" = "1"
-        );
-
-    insert into t1 values (0),(1);
-
-    drop table if exists t2;
-    CREATE TABLE IF NOT EXISTS t2
-        (
-        k1 TINYINT
-        )
-        ENGINE=olap
-        AGGREGATE KEY(k1)
-        DISTRIBUTED BY HASH(k1) BUCKETS 1
-        PROPERTIES (
-            "replication_num" = "1"
-        );
-    insert into t2 values(5),(6);
-    """
-
-    // the result of following sql may be unstable, run 3 times
-    qt_stable_1 """
-    select * from (
-        select k1 from t1
-        UNION
-        select k1 from t2
-    ) as b order by k1  limit 2;
-    """
-    qt_stable_2 """
-    select * from (
-        select k1 from t1
-        UNION
-        select k1 from t2
-    ) as b order by k1  limit 2;
-    """
-    qt_stable_3 """
-    select * from (
-        select k1 from t1
-        UNION
-        select k1 from t2
-    ) as b order by k1  limit 2;
-    """
 }
\ No newline at end of file
diff --git a/regression-test/suites/query_p0/limit/test_group_by_limit.groovy b/regression-test/suites/query_p0/limit/test_group_by_limit.groovy
index 271619c4a93..801266f4e92 100644
--- a/regression-test/suites/query_p0/limit/test_group_by_limit.groovy
+++ b/regression-test/suites/query_p0/limit/test_group_by_limit.groovy
@@ -23,42 +23,97 @@ sql 'set enable_force_spill=false'
 
 sql 'set topn_opt_limit_threshold=10'
 
-
 // different types
-qt_select """ select  sum(orderkey), count(partkey), shipmode from 
tpch_tiny_lineitem group by shipmode limit 3; """
-
-qt_select """ select  sum(orderkey), count(partkey),  linenumber from 
tpch_tiny_lineitem group by linenumber limit 3; """
+qt_select1 """ select  sum(orderkey), count(partkey), shipmode from 
tpch_tiny_lineitem group by shipmode limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey), shipmode from 
tpch_tiny_lineitem group by shipmode limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 
-qt_select """ select  sum(orderkey), count(partkey),  tax from 
tpch_tiny_lineitem group by tax limit 3; """
+qt_select2 """ select  sum(orderkey), count(partkey),  linenumber from 
tpch_tiny_lineitem group by linenumber limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey),  linenumber from 
tpch_tiny_lineitem group by linenumber limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 
-qt_select """ select  sum(orderkey), count(partkey),  commitdate from 
tpch_tiny_lineitem group by commitdate limit 3; """
+qt_select3 """ select  sum(orderkey), count(partkey),  tax from 
tpch_tiny_lineitem group by tax limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey),  tax from tpch_tiny_lineitem 
group by tax limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 
+qt_select4 """ select  sum(orderkey), count(partkey),  commitdate from 
tpch_tiny_lineitem group by commitdate limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey),  commitdate from 
tpch_tiny_lineitem group by commitdate limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 
 // group by functions
-qt_select """ select  sum(orderkey), count(partkey),  cast(commitdate as 
datetime) from tpch_tiny_lineitem group by cast(commitdate as datetime) limit 
3; """
-
-qt_select """ select  sum(orderkey), count(partkey),  month(commitdate) from 
tpch_tiny_lineitem group by month(commitdate) limit 3; """
+qt_select5 """ select  sum(orderkey), count(partkey),  cast(commitdate as 
datetime) from tpch_tiny_lineitem group by cast(commitdate as datetime) limit 
3; """
+explain {
+    sql " select  sum(orderkey), count(partkey),  cast(commitdate as datetime) 
from tpch_tiny_lineitem group by cast(commitdate as datetime) limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 
+qt_select6 """ select  sum(orderkey), count(partkey),  month(commitdate) from 
tpch_tiny_lineitem group by month(commitdate) limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey),  month(commitdate) from 
tpch_tiny_lineitem group by month(commitdate) limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 
 // mutli column
-qt_select """ select  sum(orderkey), count(partkey), shipmode, linenumber from 
tpch_tiny_lineitem group by shipmode, linenumber limit 3; """
-
-qt_select """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax limit 3; """
-
-qt_select """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax , commitdate from tpch_tiny_lineitem group by shipmode, linenumber, tax, 
commitdate  limit 3; """
-
+qt_select7 """ select  sum(orderkey), count(partkey), shipmode, linenumber 
from tpch_tiny_lineitem group by shipmode, linenumber limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey), shipmode, linenumber from 
tpch_tiny_lineitem group by shipmode, linenumber limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
+qt_select8 """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey), shipmode, linenumber , tax 
from tpch_tiny_lineitem group by shipmode, linenumber, tax limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
+qt_select9 """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax , commitdate from tpch_tiny_lineitem group by shipmode, linenumber, tax, 
commitdate  limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey), shipmode, linenumber , tax , 
commitdate from tpch_tiny_lineitem group by shipmode, linenumber, tax, 
commitdate  limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 
 // group by + order by 
 
 // group by columns eq order by columns
-qt_select """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by 
shipmode, linenumber, tax limit 3; """
-
+qt_select10 """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by 
shipmode, linenumber, tax limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey), shipmode, linenumber , tax 
from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode, 
linenumber, tax limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 // group by columns contains order by columns
-qt_select """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by 
shipmode limit 3; """
-
+qt_select11 """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by 
shipmode limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey), shipmode, linenumber , tax 
from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode 
limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 // desc order by column
-qt_select """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by 
shipmode desc, linenumber, tax limit 3; """
-
-qt_select """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by 
shipmode desc, linenumber, tax desc limit 3; """
-
+qt_select12 """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by 
shipmode desc, linenumber, tax limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey), shipmode, linenumber , tax 
from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode 
desc, linenumber, tax limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
+qt_select13 """ select  sum(orderkey), count(partkey), shipmode, linenumber , 
tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by 
shipmode desc, linenumber, tax desc limit 3; """
+explain{
+    sql " select  sum(orderkey), count(partkey), shipmode, linenumber , tax 
from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode 
desc, linenumber, tax desc limit 3; "
+    contains("VTOP-N")
+    contains("sortByGroupKey:true")
+}
 }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to