This is an automated email from the ASF dual-hosted git repository. englefly pushed a commit to branch branch-3.0 in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/branch-3.0 by this push: new e6678e682ba [fix](Nereids) set correct sort key for aggregate #45369 branch-3.0 (#45706) e6678e682ba is described below commit e6678e682ba57e38969c250971a23231fb762923 Author: minghong <zhoumingh...@selectdb.com> AuthorDate: Sun Jan 5 17:35:30 2025 +0800 [fix](Nereids) set correct sort key for aggregate #45369 branch-3.0 (#45706) ### What problem does this PR solve? pick #45369 --- .../nereids/processor/post/PushTopnToAgg.java | 159 ++++----------- .../nereids/rules/rewrite/LimitAggToTopNAgg.java | 222 ++++++++++++++------- .../trees/plans/logical/LogicalAggregate.java | 5 + .../data/query_p0/limit/test_group_by_limit.out | 26 +-- .../nereids_tpch_p0/tpch/push_topn_to_agg.groovy | 111 +++-------- .../query_p0/limit/test_group_by_limit.groovy | 101 +++++++--- 6 files changed, 307 insertions(+), 317 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java index aca3f21a7d1..d93433768c4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/processor/post/PushTopnToAgg.java @@ -21,157 +21,84 @@ package org.apache.doris.nereids.processor.post; import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.properties.DistributionSpecGather; -import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.plans.AggMode; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.physical.PhysicalDistribute; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate; import org.apache.doris.nereids.trees.plans.physical.PhysicalHashAggregate.TopnPushInfo; -import 
org.apache.doris.nereids.trees.plans.physical.PhysicalLimit; import org.apache.doris.nereids.trees.plans.physical.PhysicalProject; import org.apache.doris.nereids.trees.plans.physical.PhysicalTopN; import org.apache.doris.qe.ConnectContext; -import org.apache.hadoop.util.Lists; - -import java.util.List; -import java.util.stream.Collectors; - /** - * Add SortInfo to Agg. This SortInfo is used as boundary, not used to sort elements. + * Add TopNInfo to Agg. This TopNInfo is used as boundary, not used to sort elements. * example * sql: select count(*) from orders group by o_clerk order by o_clerk limit 1; * plan: topn(1) -> aggGlobal -> shuffle -> aggLocal -> scan * optimization: aggLocal and aggGlobal only need to generate the smallest row with respect to o_clerk. * - * TODO: the following case is not covered: - * sql: select sum(o_shippriority) from orders group by o_clerk limit 1; - * plan: limit -> aggGlobal -> shuffle -> aggLocal -> scan - * aggGlobal may receive partial aggregate results, and hence is not supported now - * instance1: input (key=2, v=1) => localAgg => (2, 1) => aggGlobal inst1 => (2, 1) - * instance2: input (key=1, v=1), (key=2, v=2) => localAgg inst2 => (1, 1) - * (2,1),(1,1) => limit => may output (2, 1), which is not complete, missing (2, 2) in instance2 - * - *TOPN: - * Precondition: topn orderkeys are the prefix of group keys - * TODO: topnKeys could be subset of groupKeys. This will be implemented in future - * Pattern 2-phase agg: - * topn -> aggGlobal -> distribute -> aggLocal - * => - * topn(n) -> aggGlobal(topn=n) -> distribute -> aggLocal(topn=n) - * Pattern 1-phase agg: - * topn->agg->Any(not agg) -> topn -> agg(topn=n) -> any - * - * LIMIT: - * Pattern 1: limit->agg(1phase)->any - * Pattern 2: limit->agg(global)->gather->agg(local) + * This rule only applies to the patterns + * 1. topn->project->agg, or + * 2. topn->agg + * that + * 1. orderKeys and groupkeys are one-one mapping + * 2. 
aggregate is not scalar agg + * Refer to LimitAggToTopNAgg rule. */ public class PushTopnToAgg extends PlanPostProcessor { @Override public Plan visitPhysicalTopN(PhysicalTopN<? extends Plan> topN, CascadesContext ctx) { topN.child().accept(this, ctx); - if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= topN.getLimit() + topN.getOffset()) { + if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= topN.getLimit() + topN.getOffset() + && !ConnectContext.get().getSessionVariable().pushTopnToAgg) { return topN; } - Plan topnChild = topN.child(); - if (topnChild instanceof PhysicalProject) { - topnChild = topnChild.child(0); + Plan topNChild = topN.child(); + if (topNChild instanceof PhysicalProject) { + topNChild = topNChild.child(0); } - if (topnChild instanceof PhysicalHashAggregate) { - PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topnChild; - List<OrderKey> orderKeys = tryGenerateOrderKeyByGroupKeyAndTopnKey(topN, upperAgg); - if (!orderKeys.isEmpty()) { - - if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) { - upperAgg.setTopnPushInfo(new TopnPushInfo( - orderKeys, - topN.getLimit() + topN.getOffset())); - if (upperAgg.child() instanceof PhysicalDistribute - && upperAgg.child().child(0) instanceof PhysicalHashAggregate) { - PhysicalHashAggregate<? extends Plan> bottomAgg = - (PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0); + if (topNChild instanceof PhysicalHashAggregate) { + PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) topNChild; + if (isGroupKeyIdenticalToOrderKey(topN, upperAgg)) { + upperAgg.setTopnPushInfo(new TopnPushInfo( + topN.getOrderKeys(), + topN.getLimit() + topN.getOffset())); + if (upperAgg.child() instanceof PhysicalDistribute + && upperAgg.child().child(0) instanceof PhysicalHashAggregate) { + PhysicalHashAggregate<? 
extends Plan> bottomAgg = + (PhysicalHashAggregate<? extends Plan>) upperAgg.child().child(0); + if (isGroupKeyIdenticalToOrderKey(topN, bottomAgg)) { + bottomAgg.setTopnPushInfo(new TopnPushInfo( + topN.getOrderKeys(), + topN.getLimit() + topN.getOffset())); + } + } else if (upperAgg.child() instanceof PhysicalHashAggregate) { + // multi-distinct plan + PhysicalHashAggregate<? extends Plan> bottomAgg = + (PhysicalHashAggregate<? extends Plan>) upperAgg.child(); + if (isGroupKeyIdenticalToOrderKey(topN, bottomAgg)) { bottomAgg.setTopnPushInfo(new TopnPushInfo( - orderKeys, + topN.getOrderKeys(), topN.getLimit() + topN.getOffset())); } - } else if (upperAgg.getAggPhase().isLocal() && upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) { - // one phase agg - upperAgg.setTopnPushInfo(new TopnPushInfo( - orderKeys, - topN.getLimit() + topN.getOffset())); } } } return topN; } - /** - return true, if topn order-key is prefix of agg group-key, ignore asc/desc and null_first - TODO order-key can be subset of group-key. BE does not support now. - */ - private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(PhysicalTopN<? extends Plan> topN, + private boolean isGroupKeyIdenticalToOrderKey(PhysicalTopN<? extends Plan> topN, PhysicalHashAggregate<? 
extends Plan> agg) { - List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size()); - if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) { - return orderKeys; - } - List<Expression> topnKeys = topN.getOrderKeys().stream() - .map(OrderKey::getExpr).collect(Collectors.toList()); - for (int i = 0; i < topN.getOrderKeys().size(); i++) { - // prefix check - if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) { - return Lists.newArrayList(); - } - orderKeys.add(topN.getOrderKeys().get(i)); - } - for (int i = topN.getOrderKeys().size(); i < agg.getGroupByExpressions().size(); i++) { - orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), true, false)); - } - return orderKeys; - } - - @Override - public Plan visitPhysicalLimit(PhysicalLimit<? extends Plan> limit, CascadesContext ctx) { - limit.child().accept(this, ctx); - if (ConnectContext.get().getSessionVariable().topnOptLimitThreshold <= limit.getLimit() + limit.getOffset()) { - return limit; + if (topN.getOrderKeys().size() != agg.getGroupByExpressions().size()) { + return false; } - Plan limitChild = limit.child(); - if (limitChild instanceof PhysicalProject) { - limitChild = limitChild.child(0); - } - if (limitChild instanceof PhysicalHashAggregate) { - PhysicalHashAggregate<? extends Plan> upperAgg = (PhysicalHashAggregate<? extends Plan>) limitChild; - if (upperAgg.getAggPhase().isGlobal() && upperAgg.getAggMode() == AggMode.BUFFER_TO_RESULT) { - Plan child = upperAgg.child(); - Plan grandChild = child.child(0); - if (child instanceof PhysicalDistribute - && ((PhysicalDistribute<?>) child).getDistributionSpec() instanceof DistributionSpecGather - && grandChild instanceof PhysicalHashAggregate) { - upperAgg.setTopnPushInfo(new TopnPushInfo( - generateOrderKeyByGroupKey(upperAgg), - limit.getLimit() + limit.getOffset())); - PhysicalHashAggregate<? extends Plan> bottomAgg = - (PhysicalHashAggregate<? 
extends Plan>) grandChild; - bottomAgg.setTopnPushInfo(new TopnPushInfo( - generateOrderKeyByGroupKey(bottomAgg), - limit.getLimit() + limit.getOffset())); - } - } else if (upperAgg.getAggMode() == AggMode.INPUT_TO_RESULT) { - // 1-phase agg - upperAgg.setTopnPushInfo(new TopnPushInfo( - generateOrderKeyByGroupKey(upperAgg), - limit.getLimit() + limit.getOffset())); + for (int i = 0; i < agg.getGroupByExpressions().size(); i++) { + Expression groupByKey = agg.getGroupByExpressions().get(i); + Expression orderKey = topN.getOrderKeys().get(i).getExpr(); + if (!groupByKey.equals(orderKey)) { + return false; } } - return limit; - } - - private List<OrderKey> generateOrderKeyByGroupKey(PhysicalHashAggregate<? extends Plan> agg) { - return agg.getGroupByExpressions().stream() - .map(key -> new OrderKey(key, true, false)) - .collect(Collectors.toList()); + return true; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java index dfa1230a8f8..049709dd23a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/LimitAggToTopNAgg.java @@ -17,24 +17,29 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.properties.OrderKey; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; -import 
org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.trees.plans.logical.LogicalTopN; import org.apache.doris.qe.ConnectContext; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; +import com.google.common.collect.Sets; +import java.util.HashMap; import java.util.List; -import java.util.Optional; +import java.util.Map; +import java.util.Set; import java.util.stream.Collectors; /** @@ -45,6 +50,7 @@ import java.util.stream.Collectors; * 2. push limit to local agg */ public class LimitAggToTopNAgg implements RewriteRuleFactory { + @Override public List<Rule> buildRules() { return ImmutableList.of( @@ -54,109 +60,171 @@ public class LimitAggToTopNAgg implements RewriteRuleFactory { && ConnectContext.get().getSessionVariable().pushTopnToAgg && ConnectContext.get().getSessionVariable().topnOptLimitThreshold >= limit.getLimit() + limit.getOffset()) + .when(limit -> { + LogicalAggregate<? extends Plan> agg = limit.child(); + return !agg.getGroupByExpressions().isEmpty() && !agg.getSourceRepeat().isPresent(); + }) .then(limit -> { LogicalAggregate<? 
extends Plan> agg = limit.child(); - Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg); - if (!orderKeysOpt.isPresent()) { - return null; - } - List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get()); + List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg); return new LogicalTopN<>(orderKeys, limit.getLimit(), limit.getOffset(), agg); }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG), - //limit->project->agg to topn->project->agg + //limit->project->agg to project->topn->agg logicalLimit(logicalProject(logicalAggregate())) .when(limit -> ConnectContext.get() != null && ConnectContext.get().getSessionVariable().pushTopnToAgg && ConnectContext.get().getSessionVariable().topnOptLimitThreshold >= limit.getLimit() + limit.getOffset()) + .when(limit -> { + LogicalAggregate<? extends Plan> agg = limit.child().child(); + return !agg.getGroupByExpressions().isEmpty() && !agg.getSourceRepeat().isPresent(); + }) .then(limit -> { LogicalProject<? extends Plan> project = limit.child(); - LogicalAggregate<? extends Plan> agg - = (LogicalAggregate<? 
extends Plan>) project.child(); - Optional<OrderKey> orderKeysOpt = tryGenerateOrderKeyByTheFirstGroupKey(agg); - if (!orderKeysOpt.isPresent()) { - return null; - } - List<OrderKey> orderKeys = Lists.newArrayList(orderKeysOpt.get()); - Plan result; - - if (outputAllGroupKeys(limit, agg)) { - result = new LogicalTopN<>(orderKeys, limit.getLimit(), - limit.getOffset(), project); - } else { - // add the first group by key to topn, and prune this key by upper project - // topn order keys are prefix of group by keys - // refer to PushTopnToAgg.tryGenerateOrderKeyByGroupKeyAndTopnKey() - Expression firstGroupByKey = agg.getGroupByExpressions().get(0); - if (!(firstGroupByKey instanceof SlotReference)) { - return null; - } - boolean shouldPruneFirstGroupByKey = true; - if (project.getOutputs().contains(firstGroupByKey)) { - shouldPruneFirstGroupByKey = false; - } else { - List<NamedExpression> bottomProjections = Lists.newArrayList(project.getProjects()); - bottomProjections.add((SlotReference) firstGroupByKey); - project = project.withProjects(bottomProjections); - } - LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(), - limit.getOffset(), project); - if (shouldPruneFirstGroupByKey) { - List<NamedExpression> limitOutput = limit.getOutput().stream() - .map(e -> (NamedExpression) e).collect(Collectors.toList()); - result = new LogicalProject<>(limitOutput, topn); - } else { - result = topn; - } - } - return result; + LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) project.child(); + List<OrderKey> orderKeys = generateOrderKeyByGroupKey(agg); + LogicalTopN topn = new LogicalTopN<>(orderKeys, limit.getLimit(), + limit.getOffset(), agg); + project = (LogicalProject<? 
extends Plan>) project.withChildren(topn); + return project; }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG), - // topn -> agg: add all group key to sort key, if sort key is prefix of group key + // topn -> agg: append group key(if it is not sort key) to sort key logicalTopN(logicalAggregate()) .when(topn -> ConnectContext.get() != null && ConnectContext.get().getSessionVariable().pushTopnToAgg && ConnectContext.get().getSessionVariable().topnOptLimitThreshold >= topn.getLimit() + topn.getOffset()) + .when(topn -> { + LogicalAggregate<? extends Plan> agg = topn.child(); + return !agg.getGroupByExpressions().isEmpty() && !agg.getSourceRepeat().isPresent(); + }) + .then(topn -> { + LogicalAggregate<? extends Plan> agg = topn.child(); + Pair<List<OrderKey>, List<Expression>> pair = + supplementOrderKeyByGroupKeyIfCompatible(topn, agg); + if (pair != null) { + agg = agg.withGroupBy(pair.second); + topn = (LogicalTopN) topn.withChildren(agg); + topn = (LogicalTopN) topn.withOrderKeys(pair.first); + } + return topn; + }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG), + //topn -> project ->agg: add all group key to sort key, and prune column + logicalTopN(logicalProject(logicalAggregate())) + .when(topn -> ConnectContext.get() != null + && ConnectContext.get().getSessionVariable().pushTopnToAgg + && ConnectContext.get().getSessionVariable().topnOptLimitThreshold + >= topn.getLimit() + topn.getOffset()) + .when(topn -> { + LogicalAggregate<? extends Plan> agg = topn.child().child(); + return !agg.getGroupByExpressions().isEmpty() && !agg.getSourceRepeat().isPresent(); + }) .then(topn -> { - LogicalAggregate<? extends Plan> agg = (LogicalAggregate<? extends Plan>) topn.child(); - List<OrderKey> newOrders = tryGenerateOrderKeyByGroupKeyAndTopnKey(topn, agg); - if (newOrders.isEmpty()) { - return topn; + LogicalTopN originTopn = topn; + LogicalProject<? extends Plan> project = topn.child(); + LogicalAggregate<? 
extends Plan> agg = (LogicalAggregate) project.child(); + if (!project.isAllSlots()) { + /* + topn(orderKey=[a]) + +-->project(b as a) + +--> agg(groupKey[b] + => + topn(orderKey=[b]) + +-->project(b as a) + +-->agg(groupKey[b]) + and then exchange topn and project + */ + Map<SlotReference, SlotReference> keyAsKey = new HashMap<>(); + for (NamedExpression e : project.getProjects()) { + if (e instanceof Alias && e.child(0) instanceof SlotReference) { + keyAsKey.put((SlotReference) e.toSlot(), (SlotReference) e.child(0)); + } + } + List<OrderKey> projectOrderKeys = Lists.newArrayList(); + boolean hasNew = false; + for (OrderKey orderKey : topn.getOrderKeys()) { + if (keyAsKey.containsKey(orderKey.getExpr())) { + projectOrderKeys.add(orderKey.withExpression(keyAsKey.get(orderKey.getExpr()))); + hasNew = true; + } else { + projectOrderKeys.add(orderKey); + } + } + if (hasNew) { + topn = (LogicalTopN) topn.withOrderKeys(projectOrderKeys); + } + } + Pair<List<OrderKey>, List<Expression>> pair = + supplementOrderKeyByGroupKeyIfCompatible(topn, agg); + Plan result; + if (pair == null) { + result = originTopn; } else { - return topn.withOrderKeys(newOrders); + agg = agg.withGroupBy(pair.second); + topn = (LogicalTopN) topn.withOrderKeys(pair.first); + if (isOrderKeysInProject(topn, project)) { + project = (LogicalProject<? extends Plan>) project.withChildren(agg); + topn = (LogicalTopN<LogicalProject<LogicalAggregate<Plan>>>) + topn.withChildren(project); + result = topn; + } else { + topn = (LogicalTopN) topn.withChildren(agg); + project = (LogicalProject<? extends Plan>) project.withChildren(topn); + result = project; + } } - }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG)); + return result; + }).toRule(RuleType.LIMIT_AGG_TO_TOPN_AGG) + ); } - private List<OrderKey> tryGenerateOrderKeyByGroupKeyAndTopnKey(LogicalTopN<? extends Plan> topN, - LogicalAggregate<? 
extends Plan> agg) { - List<OrderKey> orderKeys = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size()); - if (topN.getOrderKeys().size() > agg.getGroupByExpressions().size()) { - return orderKeys; - } - List<Expression> topnKeys = topN.getOrderKeys().stream() - .map(OrderKey::getExpr).collect(Collectors.toList()); - for (int i = 0; i < topN.getOrderKeys().size(); i++) { - // prefix check - if (!topnKeys.get(i).equals(agg.getGroupByExpressions().get(i))) { - return Lists.newArrayList(); + private boolean isOrderKeysInProject(LogicalTopN<? extends Plan> topn, LogicalProject project) { + Set<Slot> projectSlots = project.getOutputSet(); + for (OrderKey orderKey : topn.getOrderKeys()) { + if (!projectSlots.contains(orderKey.getExpr())) { + return false; } - orderKeys.add(topN.getOrderKeys().get(i)); - } - for (int i = topN.getOrderKeys().size(); i < agg.getGroupByExpressions().size(); i++) { - orderKeys.add(new OrderKey(agg.getGroupByExpressions().get(i), true, false)); } - return orderKeys; + return true; } - private boolean outputAllGroupKeys(LogicalLimit limit, LogicalAggregate agg) { - return limit.getOutputSet().containsAll(agg.getGroupByExpressions()); + private List<OrderKey> generateOrderKeyByGroupKey(LogicalAggregate<? extends Plan> agg) { + return agg.getGroupByExpressions().stream() + .map(key -> new OrderKey(key, true, false)) + .collect(Collectors.toList()); } - private Optional<OrderKey> tryGenerateOrderKeyByTheFirstGroupKey(LogicalAggregate<? extends Plan> agg) { - if (agg.getGroupByExpressions().isEmpty()) { - return Optional.empty(); + /** + * compatible: if order key is subset of group by keys + * example: + * 1. orderKey[a, b], groupKeys[b, a, c] + * compatible, return Pair(orderKey[a, b, c], groupKey[a, b, c]) + * 2. orderKey[a, b+1], groupKeys[a, b] + * not compatible, return null + */ + private Pair<List<OrderKey>, List<Expression>> supplementOrderKeyByGroupKeyIfCompatible( + LogicalTopN<? extends Plan> topn, LogicalAggregate<? 
extends Plan> agg) { + Set<Expression> groupKeySet = Sets.newHashSet(agg.getGroupByExpressions()); + List<Expression> orderKeyList = topn.getOrderKeys().stream() + .map(OrderKey::getExpr).collect(Collectors.toList()); + Set<Expression> orderKeySet = Sets.newHashSet(orderKeyList); + boolean compatible = groupKeySet.containsAll(orderKeyList); + if (compatible) { + List<OrderKey> newOrderKeys = Lists.newArrayList(topn.getOrderKeys()); + List<Expression> newGroupExpressions = Lists.newArrayListWithCapacity(agg.getGroupByExpressions().size()); + for (OrderKey orderKey : newOrderKeys) { + newGroupExpressions.add(orderKey.getExpr()); + } + + for (Expression groupKey : agg.getGroupByExpressions()) { + if (!orderKeySet.contains(groupKey)) { + newOrderKeys.add(new OrderKey(groupKey, true, false)); + newGroupExpressions.add(groupKey); + } + } + return Pair.of(newOrderKeys, newGroupExpressions); + } else { + return null; } - return Optional.of(new OrderKey(agg.getGroupByExpressions().get(0), true, false)); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java index 31cee19cc43..df8f886451f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java @@ -263,6 +263,11 @@ public class LogicalAggregate<CHILD_TYPE extends Plan> hasPushed, sourceRepeat, groupExpression, Optional.of(getLogicalProperties()), children.get(0)); } + public LogicalAggregate<Plan> withGroupBy(List<Expression> groupByExprList) { + return new LogicalAggregate<>(groupByExprList, outputExpressions, normalized, ordinalIsResolved, generated, + hasPushed, sourceRepeat, Optional.empty(), Optional.empty(), child()); + } + public LogicalAggregate<Plan> withGroupByAndOutput(List<Expression> groupByExprList, List<NamedExpression> 
outputExpressionList) { return new LogicalAggregate<>(groupByExprList, outputExpressionList, normalized, ordinalIsResolved, generated, diff --git a/regression-test/data/query_p0/limit/test_group_by_limit.out b/regression-test/data/query_p0/limit/test_group_by_limit.out index d9ac2a2481a..4a396cdaca8 100644 --- a/regression-test/data/query_p0/limit/test_group_by_limit.out +++ b/regression-test/data/query_p0/limit/test_group_by_limit.out @@ -1,65 +1,65 @@ -- This file is automatically generated. You should know what you did if you want to edit this --- !select -- +-- !select1 -- 253967024 8491 AIR 259556658 8641 FOB 260402265 8669 MAIL --- !select -- +-- !select2 -- 449872500 15000 1 386605746 12900 2 320758616 10717 3 --- !select -- +-- !select3 -- 198674527 6588 0.0 198679731 6563 0.01 198501055 6622 0.02 --- !select -- +-- !select4 -- 27137 1 1992-02-02 45697 1 1992-02-04 114452 5 1992-02-05 --- !select -- +-- !select5 -- 27137 1 1992-02-02T00:00 45697 1 1992-02-04T00:00 114452 5 1992-02-05T00:00 --- !select -- +-- !select6 -- 139015016 4632 1 130287219 4313 2 162309750 5334 3 --- !select -- +-- !select7 -- 64774969 2166 AIR 1 54166166 1804 AIR 2 45538267 1532 AIR 3 --- !select -- +-- !select8 -- 6882631 228 AIR 1 0.0 6756423 228 AIR 1 0.01 7920028 254 AIR 1 0.02 --- !select -- +-- !select9 -- 7618 1 AIR 1 0.0 1992-02-06 2210 1 AIR 1 0.0 1992-03-24 16807 1 AIR 1 0.0 1992-03-29 --- !select -- +-- !select10 -- 6882631 228 AIR 1 0.0 6756423 228 AIR 1 0.01 7920028 254 AIR 1 0.02 --- !select -- +-- !select11 -- 6882631 228 AIR 1 0.0 6756423 228 AIR 1 0.01 7920028 254 AIR 1 0.02 --- !select -- +-- !select12 -- 7707018 238 TRUCK 1 0.0 7467045 233 TRUCK 1 0.01 6927206 245 TRUCK 1 0.02 --- !select -- +-- !select13 -- 7661562 249 TRUCK 1 0.08 6673139 228 TRUCK 1 0.07 8333862 265 TRUCK 1 0.06 diff --git a/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy b/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy index 631656a6b19..5ae587910b6 
100644 --- a/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy +++ b/regression-test/suites/nereids_tpch_p0/tpch/push_topn_to_agg.groovy @@ -50,22 +50,31 @@ suite("push_topn_to_agg") { notContains("STREAMING") } - // order key should be prefix of group key + // order keys are part of group keys, + // 1. adjust group keys (o_custkey, o_clerk) -> o_clerk, o_custkey + // 2. append o_custkey to order key explain{ - sql "select o_custkey, sum(o_shippriority), o_clerk from orders group by o_custkey, o_clerk order by o_clerk, o_custkey limit 11;" - multiContains("sortByGroupKey:false", 2) + sql "select sum(o_shippriority) from orders group by o_custkey, o_clerk order by o_clerk limit 11;" + contains("sortByGroupKey:true") + contains("group by: o_clerk[#10], o_custkey[#9]") + contains("order by: o_clerk[#18] ASC, o_custkey[#19] ASC") } - // order key should be prefix of group key - explain{ - sql "select o_custkey, o_clerk, sum(o_shippriority) as x from orders group by o_custkey, o_clerk order by o_custkey, x limit 12;" - multiContains("sortByGroupKey:false", 2) + + // one distinct + explain { + sql "select sum(distinct o_shippriority) from orders group by o_orderkey limit 13; " + contains("VTOP-N") + contains("order by: o_orderkey") + multiContains("sortByGroupKey:true", 1) } - // one phase agg is optimized + // multi distinct explain { - sql "select sum(o_shippriority) from orders group by o_orderkey limit 13; " - contains("sortByGroupKey:true") + sql "select count(distinct o_clerk), sum(distinct o_shippriority) from orders group by o_orderkey limit 14; " + contains("VTOP-N") + contains("order by: o_orderkey") + multiContains("sortByGroupKey:true", 2) } // use group key as sort key to enable topn-push opt @@ -74,22 +83,17 @@ suite("push_topn_to_agg") { contains("sortByGroupKey:true") } - // group key is part of output of limit, apply opt + // group key is expression explain { - sql "select sum(o_shippriority), o_clerk from orders group by o_clerk limit 
15; " + sql "select sum(o_shippriority), o_clerk+1 from orders group by o_clerk+1 limit 15; " contains("sortByGroupKey:true") } - // order key is not prefix of group key + // order key is not part of group key explain { sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey order by o_custkey+1 limit 16; " contains("sortByGroupKey:false") - } - - // order key is not prefix of group key - explain { - sql "select o_custkey, sum(o_shippriority) from orders group by o_custkey order by o_custkey+1 limit 17; " - contains("sortByGroupKey:false") + notContains("sortByGroupKey:true") } // topn + one phase agg @@ -97,73 +101,4 @@ suite("push_topn_to_agg") { sql "select sum(ps_availqty), ps_partkey, ps_suppkey from partsupp group by ps_partkey, ps_suppkey order by ps_partkey, ps_suppkey limit 18;" contains("sortByGroupKey:true") } - - // sort key is prefix of group key, make all group key to sort key(ps_suppkey) and then apply push-topn-agg rule - explain { - sql "select sum(ps_availqty), ps_partkey, ps_suppkey from partsupp group by ps_partkey, ps_suppkey order by ps_partkey limit 19;" - contains("sortByGroupKey:true") - } - - explain { - sql "select sum(ps_availqty), ps_suppkey, ps_availqty from partsupp group by ps_suppkey, ps_availqty order by ps_suppkey limit 19;" - contains("sortByGroupKey:true") - } - - // sort key is not prefix of group key, deny - explain { - sql "select sum(ps_availqty), ps_partkey, ps_suppkey from partsupp group by ps_partkey, ps_suppkey order by ps_suppkey limit 20;" - contains("sortByGroupKey:false") - } - - multi_sql """ - drop table if exists t1; - CREATE TABLE IF NOT EXISTS t1 - ( - k1 TINYINT - ) - ENGINE=olap - AGGREGATE KEY(k1) - DISTRIBUTED BY HASH(k1) BUCKETS 1 - PROPERTIES ( - "replication_num" = "1" - ); - - insert into t1 values (0),(1); - - drop table if exists t2; - CREATE TABLE IF NOT EXISTS t2 - ( - k1 TINYINT - ) - ENGINE=olap - AGGREGATE KEY(k1) - DISTRIBUTED BY HASH(k1) BUCKETS 1 - PROPERTIES ( - 
"replication_num" = "1" - ); - insert into t2 values(5),(6); - """ - - // the result of following sql may be unstable, run 3 times - qt_stable_1 """ - select * from ( - select k1 from t1 - UNION - select k1 from t2 - ) as b order by k1 limit 2; - """ - qt_stable_2 """ - select * from ( - select k1 from t1 - UNION - select k1 from t2 - ) as b order by k1 limit 2; - """ - qt_stable_3 """ - select * from ( - select k1 from t1 - UNION - select k1 from t2 - ) as b order by k1 limit 2; - """ } \ No newline at end of file diff --git a/regression-test/suites/query_p0/limit/test_group_by_limit.groovy b/regression-test/suites/query_p0/limit/test_group_by_limit.groovy index 271619c4a93..801266f4e92 100644 --- a/regression-test/suites/query_p0/limit/test_group_by_limit.groovy +++ b/regression-test/suites/query_p0/limit/test_group_by_limit.groovy @@ -23,42 +23,97 @@ sql 'set enable_force_spill=false' sql 'set topn_opt_limit_threshold=10' - // different types -qt_select """ select sum(orderkey), count(partkey), shipmode from tpch_tiny_lineitem group by shipmode limit 3; """ - -qt_select """ select sum(orderkey), count(partkey), linenumber from tpch_tiny_lineitem group by linenumber limit 3; """ +qt_select1 """ select sum(orderkey), count(partkey), shipmode from tpch_tiny_lineitem group by shipmode limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), shipmode from tpch_tiny_lineitem group by shipmode limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} -qt_select """ select sum(orderkey), count(partkey), tax from tpch_tiny_lineitem group by tax limit 3; """ +qt_select2 """ select sum(orderkey), count(partkey), linenumber from tpch_tiny_lineitem group by linenumber limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), linenumber from tpch_tiny_lineitem group by linenumber limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} -qt_select """ select sum(orderkey), count(partkey), commitdate from tpch_tiny_lineitem group by 
commitdate limit 3; """ +qt_select3 """ select sum(orderkey), count(partkey), tax from tpch_tiny_lineitem group by tax limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), tax from tpch_tiny_lineitem group by tax limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} +qt_select4 """ select sum(orderkey), count(partkey), commitdate from tpch_tiny_lineitem group by commitdate limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), commitdate from tpch_tiny_lineitem group by commitdate limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} // group by functions -qt_select """ select sum(orderkey), count(partkey), cast(commitdate as datetime) from tpch_tiny_lineitem group by cast(commitdate as datetime) limit 3; """ - -qt_select """ select sum(orderkey), count(partkey), month(commitdate) from tpch_tiny_lineitem group by month(commitdate) limit 3; """ +qt_select5 """ select sum(orderkey), count(partkey), cast(commitdate as datetime) from tpch_tiny_lineitem group by cast(commitdate as datetime) limit 3; """ +explain { + sql " select sum(orderkey), count(partkey), cast(commitdate as datetime) from tpch_tiny_lineitem group by cast(commitdate as datetime) limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} +qt_select6 """ select sum(orderkey), count(partkey), month(commitdate) from tpch_tiny_lineitem group by month(commitdate) limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), month(commitdate) from tpch_tiny_lineitem group by month(commitdate) limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} // mutli column -qt_select """ select sum(orderkey), count(partkey), shipmode, linenumber from tpch_tiny_lineitem group by shipmode, linenumber limit 3; """ - -qt_select """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax limit 3; """ - -qt_select """ select sum(orderkey), count(partkey), shipmode, 
linenumber , tax , commitdate from tpch_tiny_lineitem group by shipmode, linenumber, tax, commitdate limit 3; """ - +qt_select7 """ select sum(orderkey), count(partkey), shipmode, linenumber from tpch_tiny_lineitem group by shipmode, linenumber limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), shipmode, linenumber from tpch_tiny_lineitem group by shipmode, linenumber limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} +qt_select8 """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} +qt_select9 """ select sum(orderkey), count(partkey), shipmode, linenumber , tax , commitdate from tpch_tiny_lineitem group by shipmode, linenumber, tax, commitdate limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), shipmode, linenumber , tax , commitdate from tpch_tiny_lineitem group by shipmode, linenumber, tax, commitdate limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} // group by + order by // group by columns eq order by columns -qt_select """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode, linenumber, tax limit 3; """ - +qt_select10 """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode, linenumber, tax limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode, linenumber, tax limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} // group by columns contains order by columns -qt_select """ select 
sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode limit 3; """ - +qt_select11 """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} // desc order by column -qt_select """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode desc, linenumber, tax limit 3; """ - -qt_select """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode desc, linenumber, tax desc limit 3; """ - +qt_select12 """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode desc, linenumber, tax limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode desc, linenumber, tax limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} +qt_select13 """ select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode desc, linenumber, tax desc limit 3; """ +explain{ + sql " select sum(orderkey), count(partkey), shipmode, linenumber , tax from tpch_tiny_lineitem group by shipmode, linenumber, tax order by shipmode desc, linenumber, tax desc limit 3; " + contains("VTOP-N") + contains("sortByGroupKey:true") +} } --------------------------------------------------------------------- To unsubscribe, e-mail: 
commits-unsubscribe@doris.apache.org For additional commands, e-mail: commits-help@doris.apache.org