924060929 commented on code in PR #10659: URL: https://github.com/apache/doris/pull/10659#discussion_r915456668
########## fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java: ########## @@ -114,60 +134,96 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) { * Translate Agg. */ @Override - public PlanFragment visitPhysicalAggregation( - PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) { - + public PlanFragment visitPhysicalAggregate( + PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, PlanTranslatorContext context) { PlanFragment inputPlanFragment = visit(agg.child(0), context); - - AggregationNode aggregationNode; - List<Slot> slotList = new ArrayList<>(); - PhysicalAggregation physicalAggregation = agg.getOperator(); - AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec(); - - List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList(); + PhysicalAggregate physicalAggregate = agg.getOperator(); + + // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts: + // 1. group by expressions: removing duplicate expressions add to tuple + // 2. agg functions: only removing duplicate agg functions in output expression should appear in tuple. + // e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple + // We need: + // 1. add a project after agg, if output expressions include agg function as a expression tree leaf. Review Comment: ```suggestion // 1. add a project after agg, if agg function is not the top output expression. ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java: ########## @@ -100,8 +103,25 @@ private static Expression swapEqualToForChildrenOrder(EqualTo<?, ?> equalTo, Lis } } - public void translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) { - visit(physicalPlan, context); + /** + * Translate Nereids Physical Plan tree to Stale Planner PlanFragment tree. 
+ * + * @param physicalPlan Nereids Physical Plan tree + * @param context context to help translate + * @return Stale Planner PlanFragment tree + */ + public PlanFragment translatePlan(PhysicalPlan physicalPlan, PlanTranslatorContext context) { + PlanFragment rootFragment = visit(physicalPlan, context); + if (rootFragment.isPartitioned() && rootFragment.getPlanRoot().getNumInstances() > 1) { + rootFragment = createMergeFragment(rootFragment, context); + context.addPlanFragment(rootFragment); Review Comment: Rename `createMergeFragment()` to `exchangeToMergeFragment()`, and move the `context.addPlanFragment(rootFragment)` call into `exchangeToMergeFragment(rootFragment, context)`. ########## fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java: ########## @@ -114,60 +134,96 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) { * Translate Agg. */ @Override - public PlanFragment visitPhysicalAggregation( - PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) { - + public PlanFragment visitPhysicalAggregate( + PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, PlanTranslatorContext context) { PlanFragment inputPlanFragment = visit(agg.child(0), context); - - AggregationNode aggregationNode; - List<Slot> slotList = new ArrayList<>(); - PhysicalAggregation physicalAggregation = agg.getOperator(); - AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec(); - - List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList(); + PhysicalAggregate physicalAggregate = agg.getOperator(); + + // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts: + // 1. group by expressions: removing duplicate expressions add to tuple + // 2. agg functions: only removing duplicate agg functions in output expression should appear in tuple. + // e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple + // We need: + // 1.
add a project after agg, if output expressions include agg function as a expression tree leaf. + // 2. introduce canonicalized, semanticEquals and deterministic in Expression + // for removing duplicate. + List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList(); + List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList(); + + // 1. generate slot reference for each group expression + List<SlotReference> groupSlotList = Lists.newArrayList(); + for (Expression e : groupByExpressionList) { + if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.contains(e::equals))) { + groupSlotList.add((SlotReference) e); + } else { + groupSlotList.add(new SlotReference(e.sql(), e.getDataType(), e.nullable(), Collections.emptyList())); + } + } ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream() - // Since output of plan doesn't contain the slots of groupBy, which is actually needed by - // the BE execution, so we have to collect them and add to the slotList to generate corresponding - // TupleDesc. - .peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance))) .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new)); - slotList.addAll(agg.getOutput()); - TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null); - - List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList(); - ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream() - .map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance)) + // 2. 
collect agg functions and generate agg function to slot reference map + List<Slot> aggFunctionOutput = Lists.newArrayList(); + List<AggregateFunction> aggregateFunctionList = outputExpressionList.stream() + .filter(o -> o.contains(AggregateFunction.class::isInstance)) + .peek(o -> aggFunctionOutput.add(o.toSlot())) + .map(o -> (List<AggregateFunction>) o.collect(AggregateFunction.class::isInstance)) .flatMap(List::stream) + .collect(Collectors.toList()); + ArrayList<FunctionCallExpr> execAggExpressions = aggregateFunctionList.stream() .map(x -> (FunctionCallExpr) ExpressionTranslator.translate(x, context)) .collect(Collectors.toCollection(ArrayList::new)); - List<Expression> partitionExpressionList = physicalAggregation.getPartitionExprList(); + // 3. generate output tuple + // TODO: currently, we only support sum(a), if we want to support sum(a) + 1, we need to + // split merge agg to project(agg) and generate tuple like what first phase agg do. + List<Slot> slotList = Lists.newArrayList(); + TupleDescriptor outputTupleDesc; + if (agg.getOperator().getAggPhase() == AggPhase.FIRST_MERGE) { + slotList.addAll(groupSlotList); + slotList.addAll(aggFunctionOutput); + outputTupleDesc = generateTupleDesc(slotList, null, context); + } else { + outputTupleDesc = generateTupleDesc(agg.getOutput(), null, context); + } + + // process partition list + List<Expression> partitionExpressionList = physicalAggregate.getPartitionExprList(); List<Expr> execPartitionExpressions = partitionExpressionList.stream() - .map(e -> (FunctionCallExpr) ExpressionTranslator.translate(e, context)).collect(Collectors.toList()); + .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList()); + DataPartition mergePartition = DataPartition.UNPARTITIONED; + if (CollectionUtils.isNotEmpty(execPartitionExpressions)) { Review Comment: Storing execPartitionExpressions in the merge LogicalAggregate doesn't seem reasonable, because the sender (input fragment) executes the partition expression,
not receiver(merge fragment). ########## fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java: ########## @@ -364,25 +416,43 @@ private TupleDescriptor generateTupleDesc(List<Slot> slotList, PlanTranslatorCon } private PlanFragment createParentFragment(PlanFragment childFragment, DataPartition parentPartition, Review Comment: ```suggestion private PlanFragment exchangeToMergeFragment(PlanFragment childFragment, DataPartition parentPartition, ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java: ########## @@ -364,25 +416,43 @@ private TupleDescriptor generateTupleDesc(List<Slot> slotList, PlanTranslatorCon } private PlanFragment createParentFragment(PlanFragment childFragment, DataPartition parentPartition, - PlanTranslatorContext ctx) { - ExchangeNode exchangeNode = new ExchangeNode(ctx.nextNodeId(), childFragment.getPlanRoot(), false); + PlanTranslatorContext context) { + ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot(), false); exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances()); - PlanFragment parentFragment = new PlanFragment(ctx.nextFragmentId(), exchangeNode, parentPartition); + PlanFragment parentFragment = new PlanFragment(context.nextFragmentId(), exchangeNode, parentPartition); childFragment.setDestination(exchangeNode); childFragment.setOutputPartition(parentPartition); + context.addPlanFragment(parentFragment); return parentFragment; } private void connectChildFragment(PlanNode node, int childIdx, PlanFragment parentFragment, PlanFragment childFragment, PlanTranslatorContext context) { - ExchangeNode exchangeNode = new ExchangeNode(context.nextNodeId(), childFragment.getPlanRoot(), false); + ExchangeNode exchangeNode = new ExchangeNode(context.nextPlanNodeId(), childFragment.getPlanRoot(), false); exchangeNode.setNumInstances(childFragment.getPlanRoot().getNumInstances()); 
exchangeNode.setFragment(parentFragment); node.setChild(childIdx, exchangeNode); childFragment.setDestination(exchangeNode); } + /** + * Return unpartitioned fragment that merges the input fragment's output via + * an ExchangeNode. + * Requires that input fragment be partitioned. + */ + private PlanFragment createMergeFragment(PlanFragment inputFragment, PlanTranslatorContext context) { Review Comment: ```suggestion private PlanFragment exchangeToMergeFragment(PlanFragment inputFragment, PlanTranslatorContext context) { ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java: ########## @@ -114,60 +134,96 @@ public PlanFragment visit(Plan plan, PlanTranslatorContext context) { * Translate Agg. */ @Override - public PlanFragment visitPhysicalAggregation( - PhysicalUnaryPlan<PhysicalAggregation, Plan> agg, PlanTranslatorContext context) { - + public PlanFragment visitPhysicalAggregate( + PhysicalUnaryPlan<PhysicalAggregate, Plan> agg, PlanTranslatorContext context) { PlanFragment inputPlanFragment = visit(agg.child(0), context); - - AggregationNode aggregationNode; - List<Slot> slotList = new ArrayList<>(); - PhysicalAggregation physicalAggregation = agg.getOperator(); - AggregateInfo.AggPhase phase = physicalAggregation.getAggPhase().toExec(); - - List<Expression> groupByExpressionList = physicalAggregation.getGroupByExprList(); + PhysicalAggregate physicalAggregate = agg.getOperator(); + + // TODO: stale planner generate aggregate tuple in a special way. tuple include 2 parts: + // 1. group by expressions: removing duplicate expressions add to tuple + // 2. agg functions: only removing duplicate agg functions in output expression should appear in tuple. + // e.g. select sum(v1) + 1, sum(v1) + 2 from t1 should only generate one sum(v1) in tuple + // We need: + // 1. add a project after agg, if output expressions include agg function as a expression tree leaf. + // 2. 
introduce canonicalized, semanticEquals and deterministic in Expression + // for removing duplicate. + List<Expression> groupByExpressionList = physicalAggregate.getGroupByExprList(); + List<NamedExpression> outputExpressionList = physicalAggregate.getOutputExpressionList(); + + // 1. generate slot reference for each group expression + List<SlotReference> groupSlotList = Lists.newArrayList(); + for (Expression e : groupByExpressionList) { + if (e instanceof SlotReference && outputExpressionList.stream().anyMatch(o -> o.contains(e::equals))) { + groupSlotList.add((SlotReference) e); + } else { + groupSlotList.add(new SlotReference(e.sql(), e.getDataType(), e.nullable(), Collections.emptyList())); + } + } ArrayList<Expr> execGroupingExpressions = groupByExpressionList.stream() - // Since output of plan doesn't contain the slots of groupBy, which is actually needed by - // the BE execution, so we have to collect them and add to the slotList to generate corresponding - // TupleDesc. - .peek(x -> slotList.addAll(x.collect(SlotReference.class::isInstance))) .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toCollection(ArrayList::new)); - slotList.addAll(agg.getOutput()); - TupleDescriptor outputTupleDesc = generateTupleDesc(slotList, context, null); - - List<NamedExpression> outputExpressionList = physicalAggregation.getOutputExpressionList(); - ArrayList<FunctionCallExpr> execAggExpressions = outputExpressionList.stream() - .map(e -> e.<List<AggregateFunction>>collect(AggregateFunction.class::isInstance)) + // 2. 
collect agg functions and generate agg function to slot reference map + List<Slot> aggFunctionOutput = Lists.newArrayList(); + List<AggregateFunction> aggregateFunctionList = outputExpressionList.stream() + .filter(o -> o.contains(AggregateFunction.class::isInstance)) + .peek(o -> aggFunctionOutput.add(o.toSlot())) + .map(o -> (List<AggregateFunction>) o.collect(AggregateFunction.class::isInstance)) .flatMap(List::stream) + .collect(Collectors.toList()); + ArrayList<FunctionCallExpr> execAggExpressions = aggregateFunctionList.stream() .map(x -> (FunctionCallExpr) ExpressionTranslator.translate(x, context)) .collect(Collectors.toCollection(ArrayList::new)); - List<Expression> partitionExpressionList = physicalAggregation.getPartitionExprList(); + // 3. generate output tuple + // TODO: currently, we only support sum(a), if we want to support sum(a) + 1, we need to + // split merge agg to project(agg) and generate tuple like what first phase agg do. + List<Slot> slotList = Lists.newArrayList(); + TupleDescriptor outputTupleDesc; + if (agg.getOperator().getAggPhase() == AggPhase.FIRST_MERGE) { + slotList.addAll(groupSlotList); + slotList.addAll(aggFunctionOutput); + outputTupleDesc = generateTupleDesc(slotList, null, context); + } else { + outputTupleDesc = generateTupleDesc(agg.getOutput(), null, context); + } + + // process partition list + List<Expression> partitionExpressionList = physicalAggregate.getPartitionExprList(); List<Expr> execPartitionExpressions = partitionExpressionList.stream() - .map(e -> (FunctionCallExpr) ExpressionTranslator.translate(e, context)).collect(Collectors.toList()); + .map(e -> ExpressionTranslator.translate(e, context)).collect(Collectors.toList()); + DataPartition mergePartition = DataPartition.UNPARTITIONED; + if (CollectionUtils.isNotEmpty(execPartitionExpressions)) { + mergePartition = DataPartition.hashPartitioned(execGroupingExpressions); + } + // todo: support DISTINCT + AggregationNode aggregationNode; AggregateInfo aggInfo; - 
switch (phase) { + switch (physicalAggregate.getAggPhase()) { case FIRST: aggInfo = AggregateInfo.create(execGroupingExpressions, execAggExpressions, outputTupleDesc, outputTupleDesc, AggregateInfo.AggPhase.FIRST); - aggregationNode = new AggregationNode(context.nextNodeId(), inputPlanFragment.getPlanRoot(), aggInfo); + aggregationNode = new AggregationNode(context.nextPlanNodeId(), + inputPlanFragment.getPlanRoot(), aggInfo); aggregationNode.unsetNeedsFinalize(); - aggregationNode.setUseStreamingPreagg(physicalAggregation.isUsingStream()); + aggregationNode.setUseStreamingPreagg(physicalAggregate.isUsingStream()); aggregationNode.setIntermediateTuple(); - if (!partitionExpressionList.isEmpty()) { - inputPlanFragment.setOutputPartition(DataPartition.hashPartitioned(execPartitionExpressions)); - } - break; + inputPlanFragment.setPlanRoot(aggregationNode); + PlanFragment mergeFragment = createParentFragment(inputPlanFragment, mergePartition, context); Review Comment: ```suggestion PlanFragment mergeFragment = exchangeToMergeFragment(inputPlanFragment, mergePartition, context); ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java: ########## @@ -115,46 +136,16 @@ public Rule<Plan> build() { }).toRule(RuleType.AGGREGATE_DISASSEMBLE); } - private org.apache.doris.catalog.AggregateFunction findAggFunc(AggregateFunction functionCall) { - FunctionName functionName = new FunctionName(functionCall.getName()); - List<Expression> expressionList = functionCall.getArguments(); - List<Type> staleTypeList = expressionList.stream().map(Expression::getDataType) - .map(DataType::toCatalogDataType).collect(Collectors.toList()); - Function staleFuncDesc = new Function(functionName, staleTypeList, - functionCall.getDataType().toCatalogDataType(), - // I think an aggregate function will never have a variable length parameters - false); - Function staleFunc = Catalog.getCurrentCatalog() - .getFunction(staleFuncDesc, 
CompareMode.IS_IDENTICAL); - Preconditions.checkArgument(staleFunc instanceof org.apache.doris.catalog.AggregateFunction); - return (org.apache.doris.catalog.AggregateFunction) staleFunc; - } - - @SuppressWarnings("unchecked") - private <T extends Expression> void replaceSlot(Map<Slot, Slot> staleToNew, - List<T> expressionList, Expression root, int index) { - if (index != -1) { - if (root instanceof Slot) { - Slot v = staleToNew.get(root); - if (v == null) { - return; - } - expressionList.set(index, (T) v); - return; - } - } - List<Expression> children = root.children(); - for (int i = 0; i < children.size(); i++) { - Expression cur = children.get(i); - if (!(cur instanceof Slot)) { - replaceSlot(staleToNew, expressionList, cur, -1); - continue; - } - Expression v = staleToNew.get(cur); - if (v == null) { - continue; + private static class AggregateFunctionParamsRewriter + extends DefaultExpressionRewriter<Map<AggregateFunction, NamedExpression>> { + @Override + public Expression visitBoundFunction(BoundFunction boundFunction, Review Comment: You can use the `visitAggregateFunction` method here instead. ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java: ########## @@ -64,49 +65,69 @@ public Rule<Plan> build() { Operator operator = plan.getOperator(); LogicalAggregate agg = (LogicalAggregate) operator; List<NamedExpression> outputExpressionList = agg.getOutputExpressionList(); - List<NamedExpression> intermediateAggExpressionList = Lists.newArrayList(); - // TODO: shouldn't extract agg function from this field. - for (NamedExpression namedExpression : outputExpressionList) { - namedExpression = (NamedExpression) namedExpression.clone(); - List<AggregateFunction> functionCallList = - namedExpression.collect(org.apache.doris.catalog.AggregateFunction.class::isInstance); - // TODO: we will have another mechanism to get corresponding stale agg func.
- for (AggregateFunction functionCall : functionCallList) { - org.apache.doris.catalog.AggregateFunction staleAggFunc = findAggFunc(functionCall); - Type staleIntermediateType = staleAggFunc.getIntermediateType(); - Type staleRetType = staleAggFunc.getReturnType(); - if (staleIntermediateType != null && !staleIntermediateType.equals(staleRetType)) { - functionCall.setIntermediate(DataType.convertFromCatalogDataType(staleIntermediateType)); + List<Expression> groupByExpressionList = agg.getGroupByExpressionList(); + + Map<AggregateFunction, NamedExpression> aggregateFunctionAliasMap = Maps.newHashMap(); + for (NamedExpression outputExpression : outputExpressionList) { + outputExpression.foreach(e -> { + if (e instanceof AggregateFunction) { + AggregateFunction a = (AggregateFunction) e; + aggregateFunctionAliasMap.put(a, new Alias<>(a, a.sql())); + } + }); + } + + List<Expression> updateGroupByExpressionList = groupByExpressionList; + List<NamedExpression> updateGroupByAliasList = updateGroupByExpressionList.stream() + .map(g -> new Alias<>(g, g.sql())) + .collect(Collectors.toList()); + + List<NamedExpression> updateOutputExpressionList = Lists.newArrayList(); + updateOutputExpressionList.addAll(updateGroupByAliasList); + updateOutputExpressionList.addAll(aggregateFunctionAliasMap.values()); + + List<Expression> mergeGroupByExpressionList = updateGroupByAliasList.stream() + .map(NamedExpression::toSlot).collect(Collectors.toList()); + + List<NamedExpression> mergeOutputExpressionList = Lists.newArrayList(); + for (NamedExpression o : outputExpressionList) { + if (o.contains(AggregateFunction.class::isInstance)) { + mergeOutputExpressionList.add((NamedExpression) new AggregateFunctionParamsRewriter() + .visit(o, aggregateFunctionAliasMap)); + } else { + for (int i = 0; i < updateGroupByAliasList.size(); i++) { + // TODO: we need to do sub tree match and replace. but we do not have semanticEquals now. + // e.g. 
a + 1 + 2 in output expression should be replaced by + // (slot reference to update phase out (a + 1)) + 2, if we do group by a + 1 + // currently, we could only handle output expression same with group by expression + if (o instanceof SlotReference) { + // a in output expression will be SLotReference + if (o.equals(updateGroupByExpressionList.get(i))) { + mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot()); + break; + } + } else if (o instanceof Alias) { + // a + 1 in output expression will be Alias + if (o.child(0).equals(updateGroupByExpressionList.get(i))) { + mergeOutputExpressionList.add(updateGroupByAliasList.get(i).toSlot()); + break; + } + } } } - intermediateAggExpressionList.add(namedExpression); } + LogicalAggregate localAgg = new LogicalAggregate( - agg.getGroupByExprList().stream().map(Expression::clone).collect(Collectors.toList()), - intermediateAggExpressionList, + updateGroupByExpressionList, + updateOutputExpressionList, Review Comment: I think these variable names and the computation logic are confusing. How about this: 1. localGroupByExprs = originGroupByExprs 2. localOutputExprs = originOutput.withAlias 3. globalGroupByWithAlias = originGroupByExprs.withAlias 4. globalOutputAlias = originOutput.replaceAggregateFunctionArguments. The advantages are: 1. each variable name carries position (local/global) and member information; 2. each assignment statement contains the simplest possible computation logic. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org