englefly commented on code in PR #64849:
URL: https://github.com/apache/doris/pull/64849#discussion_r3497105394
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/EliminateGroupByKey.java:
##########
@@ -38,69 +42,162 @@
import java.util.Map.Entry;
import java.util.Set;
-
/**
* Eliminate group by key based on fd item information.
* such as:
* for a -> b, we can get:
* group by a, b, c => group by a, c
+ *
+ * When a group-by key is FD-redundant but still needed in the output,
+ * it is wrapped with any_value() and assigned a fresh ExprId.
+ * Upper plan references are rewritten via ExprIdRewriter so that
+ * all ancestor nodes see the new ExprIds.
*/
-@DependsRules({EliminateGroupBy.class, ColumnPruning.class})
-public class EliminateGroupByKey implements RewriteRuleFactory {
+public class EliminateGroupByKey extends DefaultPlanRewriter<Map<ExprId,
ExprId>> implements CustomRewriter {
+ private ExprIdRewriter exprIdReplacer;
@Override
- public List<Rule> buildRules() {
- return ImmutableList.of(
- RuleType.ELIMINATE_GROUP_BY_KEY.build(
- logicalProject(logicalAggregate().when(agg ->
!agg.getSourceRepeat().isPresent()))
- .then(proj -> {
- LogicalAggregate<? extends Plan> agg =
proj.child();
- LogicalAggregate<Plan> newAgg =
eliminateGroupByKey(agg, proj.getInputSlots());
- if (newAgg == null) {
- return null;
- }
- return proj.withChildren(newAgg);
- })),
- RuleType.ELIMINATE_FILTER_GROUP_BY_KEY.build(
- logicalProject(logicalFilter(logicalAggregate()
- .when(agg ->
!agg.getSourceRepeat().isPresent())))
- .then(proj -> {
- LogicalAggregate<? extends Plan> agg =
proj.child().child();
- Set<Slot> requireSlots = new
HashSet<>(proj.getInputSlots());
-
requireSlots.addAll(proj.child(0).getInputSlots());
- LogicalAggregate<Plan> newAgg =
eliminateGroupByKey(agg, requireSlots);
- if (newAgg == null) {
- return null;
- }
- return
proj.withChildren(proj.child().withChildren(newAgg));
- })
- )
- );
+ public Plan rewriteRoot(Plan plan, JobContext jobContext) {
+ if (!plan.containsType(Aggregate.class)) {
+ return plan;
+ }
+ Map<ExprId, ExprId> replaceMap = new HashMap<>();
+ ExprIdRewriter.ReplaceRule replaceRule = new
ExprIdRewriter.ReplaceRule(replaceMap, false);
+ exprIdReplacer = new ExprIdRewriter(replaceRule, jobContext);
+ return plan.accept(this, replaceMap);
}
- LogicalAggregate<Plan> eliminateGroupByKey(LogicalAggregate<? extends
Plan> agg, Set<Slot> requireOutput) {
- Set<Expression> removeExpression = findCanBeRemovedExpressions(agg,
requireOutput,
+ @Override
+ public Plan visit(Plan plan, Map<ExprId, ExprId> replaceMap) {
+ plan = visitChildren(this, plan, replaceMap);
+ plan = exprIdReplacer.rewriteExpr(plan, replaceMap);
+ return plan;
+ }
+
+ @Override
+ public Plan visitLogicalProject(LogicalProject<? extends Plan> proj,
Map<ExprId, ExprId> replaceMap) {
+ proj = visitChildren(this, proj, replaceMap);
+
+ // Find the Aggregate child, possibly through a Filter
+ Plan child = proj.child(0);
+ LogicalAggregate<? extends Plan> agg;
+ boolean hasFilter = child instanceof LogicalFilter;
+ if (hasFilter && child.child(0) instanceof LogicalAggregate) {
+ agg = (LogicalAggregate<? extends Plan>) child.child(0);
+ } else if (child instanceof LogicalAggregate) {
+ agg = (LogicalAggregate<? extends Plan>) child;
+ } else {
+ return exprIdReplacer.rewriteExpr(proj, replaceMap);
+ }
+
+ // Don't transform if source repeat is present
+ if (agg.getSourceRepeat().isPresent()) {
+ return exprIdReplacer.rewriteExpr(proj, replaceMap);
+ }
+
+ // Compute requireOutput: slots needed by the Project (and Filter, if
present)
+ Set<Slot> requireOutput = new HashSet<>(proj.getInputSlots());
+ if (hasFilter) {
+ requireOutput.addAll(child.getInputSlots());
+ }
+
+ // Transform the aggregate
+ EliminateResult result = eliminateGroupByKeyWithMap(agg,
requireOutput);
+ if (!result.changed) {
+ return exprIdReplacer.rewriteExpr(proj, replaceMap);
+ }
+
+ // Merge into the global replaceMap so that all ancestor nodes get
rewritten
+ replaceMap.putAll(result.replaceMap);
+
+ // Rebuild the child chain with the new aggregate,
+ // and rewrite the Filter (if present) and Project expressions
+ Plan newChild;
+ if (hasFilter) {
+ Plan updatedFilter = child.withChildren(result.newAgg);
+ newChild = exprIdReplacer.rewriteExpr(updatedFilter, replaceMap);
+ } else {
+ newChild = result.newAgg;
+ }
+ Plan newProj = exprIdReplacer.rewriteExpr(proj.withChildren(newChild),
replaceMap);
+ return newProj;
+ }
+
+ /** Result of eliminateGroupByKey: the new aggregate and a map of old->new
ExprIds. */
+ private static class EliminateResult {
+ final LogicalAggregate<Plan> newAgg;
+ final Map<ExprId, ExprId> replaceMap;
+ final boolean changed;
+
+ EliminateResult(LogicalAggregate<Plan> newAgg, Map<ExprId, ExprId>
replaceMap, boolean changed) {
+ this.newAgg = newAgg;
+ this.replaceMap = replaceMap;
+ this.changed = changed;
+ }
+ }
+
+ EliminateResult eliminateGroupByKeyWithMap(LogicalAggregate<? extends
Plan> agg, Set<Slot> requireOutput) {
+ FindResult result = findCanBeRemovedExpressionsInternal(agg,
requireOutput,
agg.child().getLogicalProperties().getTrait());
+ Set<Expression> removeExpression = result.removeExpression;
+ Set<Expression> wrapWithAnyValue = result.wrapWithAnyValue;
+
List<Expression> newGroupExpression = new ArrayList<>();
for (Expression expression : agg.getGroupByExpressions()) {
- if (!removeExpression.contains(expression)) {
+ if (!removeExpression.contains(expression)
+ && !wrapWithAnyValue.contains(expression)) {
newGroupExpression.add(expression);
Review Comment:
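这个 rule 通过 FD（函数依赖）消除 group by key，所以至少会保留一个 key。
(This rule eliminates group-by keys based on functional dependencies (FD), so at least one key will be retained.)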
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]