morrySnow commented on code in PR #45209: URL: https://github.com/apache/doris/pull/45209#discussion_r1879470123
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: ########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import 
org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * +--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? 
extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + + @Override + public Rule build() { + return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) + .thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext)) + .toRule(RuleType.DISTINCT_SPLIT); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; + } + + private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) { + List<Alias> aliases = new ArrayList<>(); Review Comment: distinctFuncWithAlias ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/CheckAnalysis.java: ########## @@ -165,10 +165,10 @@ private void checkAggregate(LogicalAggregate<? extends Plan> aggregate) { distinctFunctionNum += aggregateFunction.isDistinct() ? 1 : 0; } - if (distinctMultiColumns && distinctFunctionNum > 1) { - throw new AnalysisException( - "The query contains multi count distinct or sum distinct, each can't have multi columns"); - } + // if (distinctMultiColumns && distinctFunctionNum > 1) { + // throw new AnalysisException( + // "The query contains multi count distinct or sum distinct, each can't have multi columns"); + // } Review Comment: remove useless code ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: ########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. 
The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; 
+import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * +--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); Review Comment: Add a new trait `SupportMultiDistinct` and add `convertToMultiDistinct` to it. Then we could use `instanceof SupportMultiDistinct` to determine whether a function is supported. ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: ########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * 
LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * +--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + + @Override + public Rule build() { + return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) + .thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext)) + .toRule(RuleType.DISTINCT_SPLIT); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; + } + + private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) { + List<Alias> aliases = new ArrayList<>(); + Set<Expression> distinctFunc = new HashSet<>(); + List<Alias> otherAggFuncs = new ArrayList<>(); + if (!needTransform(agg, aliases, distinctFunc, otherAggFuncs)) { + return null; + } + + LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.getStatementContext().getNextCTEId(), + agg.child()); + // construct cte consumer and aggregate + List<Expression> groupBy = agg.getGroupByExpressions(); + List<LogicalAggregate<Plan>> newAggs = new ArrayList<>(); + // All aggFunc except count distinct are placed in the first one Review Comment: Should all occurrences of 'count distinct' in the comments be replaced by 'multi distinct agg'? 
########## fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/executor/Rewriter.java: ########## @@ -548,6 +549,9 @@ private static List<RewriteJob> getWholeTreeRewriteJobs( rewriteJobs.addAll(jobs(topic("or expansion", custom(RuleType.OR_EXPANSION, () -> OrExpansion.INSTANCE)))); } + rewriteJobs.addAll(jobs(topic("count distinct split", topDown(new DistinctSplit()) Review Comment: ```suggestion rewriteJobs.addAll(jobs(topic("distinct split", topDown(new DistinctSplit()) ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: ########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * 
+--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + + @Override + public Rule build() { + return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) Review Comment: Add a TODO to support it? Maybe we could join on the internal grouping id? ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: ########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * 
+--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + + @Override + public Rule build() { + return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) + .thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext)) + .toRule(RuleType.DISTINCT_SPLIT); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; + } + + private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) { + List<Alias> aliases = new ArrayList<>(); + Set<Expression> distinctFunc = new HashSet<>(); + List<Alias> otherAggFuncs = new ArrayList<>(); + if (!needTransform(agg, aliases, distinctFunc, otherAggFuncs)) { + return null; + } + + LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.getStatementContext().getNextCTEId(), + agg.child()); + // construct cte consumer and aggregate + List<Expression> groupBy = agg.getGroupByExpressions(); + List<LogicalAggregate<Plan>> newAggs = new ArrayList<>(); + // All aggFunc except count distinct are placed in the first one + Map<Alias, Alias> joinOutput = new HashMap<>(); + List<Expression> outputJoinGroupBys = new ArrayList<>(); + for (int i = 0; i < aliases.size(); ++i) { + Expression expr = aliases.get(i).child(0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(), + producer.getCteId(), "", producer); + ctx.putCTEIdToConsumer(consumer); + 
Map<Slot, Slot> replaceMap = new HashMap<>(); + for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + replaceMap.put(entry.getValue(), entry.getKey()); + } + List<Expression> replacedGroupBy = ExpressionUtils.replace(groupBy, replaceMap); + Expression count = ExpressionUtils.replace(expr, replaceMap); Review Comment: ```suggestion Expression newDistinctAggFunc = ExpressionUtils.replace(distinctAggFunc, replaceMap); ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: ########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * 
+--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + + @Override + public Rule build() { + return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) + .thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext)) + .toRule(RuleType.DISTINCT_SPLIT); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; + } + + private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) { + List<Alias> aliases = new ArrayList<>(); + Set<Expression> distinctFunc = new HashSet<>(); + List<Alias> otherAggFuncs = new ArrayList<>(); + if (!needTransform(agg, aliases, distinctFunc, otherAggFuncs)) { + return null; + } + + LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.getStatementContext().getNextCTEId(), + agg.child()); + // construct cte consumer and aggregate + List<Expression> groupBy = agg.getGroupByExpressions(); + List<LogicalAggregate<Plan>> newAggs = new ArrayList<>(); + // All aggFunc except count distinct are placed in the first one + Map<Alias, Alias> joinOutput = new HashMap<>(); + List<Expression> outputJoinGroupBys = new ArrayList<>(); + for (int i = 0; i < aliases.size(); ++i) { + Expression expr = aliases.get(i).child(0); Review Comment: ```suggestion Expression distinctAggFunc = aliases.get(i).child(0); ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: 
########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import 
org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * +--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? 
extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + + @Override + public Rule build() { + return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) + .thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext)) + .toRule(RuleType.DISTINCT_SPLIT); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; + } + + private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) { + List<Alias> aliases = new ArrayList<>(); + Set<Expression> distinctFunc = new HashSet<>(); + List<Alias> otherAggFuncs = new ArrayList<>(); + if (!needTransform(agg, aliases, distinctFunc, otherAggFuncs)) { + return null; + } + + LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.getStatementContext().getNextCTEId(), + agg.child()); + // construct cte consumer and aggregate + List<Expression> groupBy = agg.getGroupByExpressions(); + List<LogicalAggregate<Plan>> newAggs = new ArrayList<>(); + // All aggFunc except count distinct are placed in the first one + Map<Alias, Alias> joinOutput = new HashMap<>(); + List<Expression> outputJoinGroupBys = new ArrayList<>(); + for (int i = 0; i < aliases.size(); ++i) { + Expression expr = aliases.get(i).child(0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(), + producer.getCteId(), "", producer); + ctx.putCTEIdToConsumer(consumer); + Map<Slot, Slot> replaceMap = new HashMap<>(); + for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + replaceMap.put(entry.getValue(), entry.getKey()); + } + List<Expression> replacedGroupBy = 
ExpressionUtils.replace(groupBy, replaceMap); + Expression count = ExpressionUtils.replace(expr, replaceMap); + List<NamedExpression> outputExpressions = replacedGroupBy.stream() + .map(e -> (Slot) e).collect(Collectors.toList()); Review Comment: ```suggestion .map(Slot.class::cast).collect(Collectors.toList()); ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: ########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * 
+--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + + @Override + public Rule build() { + return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) + .thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext)) + .toRule(RuleType.DISTINCT_SPLIT); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; + } + + private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) { + List<Alias> aliases = new ArrayList<>(); + Set<Expression> distinctFunc = new HashSet<>(); + List<Alias> otherAggFuncs = new ArrayList<>(); + if (!needTransform(agg, aliases, distinctFunc, otherAggFuncs)) { + return null; + } + + LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.getStatementContext().getNextCTEId(), + agg.child()); + // construct cte consumer and aggregate + List<Expression> groupBy = agg.getGroupByExpressions(); + List<LogicalAggregate<Plan>> newAggs = new ArrayList<>(); + // All aggFunc except count distinct are placed in the first one + Map<Alias, Alias> joinOutput = new HashMap<>(); + List<Expression> outputJoinGroupBys = new ArrayList<>(); + for (int i = 0; i < aliases.size(); ++i) { + Expression expr = aliases.get(i).child(0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(), + producer.getCteId(), "", producer); + ctx.putCTEIdToConsumer(consumer); + 
Map<Slot, Slot> replaceMap = new HashMap<>(); + for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + replaceMap.put(entry.getValue(), entry.getKey()); + } + List<Expression> replacedGroupBy = ExpressionUtils.replace(groupBy, replaceMap); + Expression count = ExpressionUtils.replace(expr, replaceMap); + List<NamedExpression> outputExpressions = replacedGroupBy.stream() + .map(e -> (Slot) e).collect(Collectors.toList()); + Alias alias = new Alias(count); + outputExpressions.add(alias); + if (i == 0) { + List<Expression> otherAggFuncAliases = otherAggFuncs.stream() + .map(e -> ExpressionUtils.replace(e, replaceMap)).collect(Collectors.toList()); + for (Expression otherAggFuncAlias : otherAggFuncAliases) { + // otherAggFunc is instance of Alias + Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0)); + outputExpressions.add(outputOtherFunc); + joinOutput.put(outputOtherFunc, (Alias) otherAggFuncAlias); + } + // save replacedGroupBy + outputJoinGroupBys.addAll(replacedGroupBy); + } + LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer); + newAggs.add(newAgg); + joinOutput.put(alias, aliases.get(i)); + } + LogicalJoin<Plan, Plan> join = constructJoin(newAggs, groupBy); + LogicalProject<Plan> project = constructProject(groupBy, joinOutput, outputJoinGroupBys, join); + return new LogicalCTEAnchor<Plan, Plan>(producer.getCteId(), producer, project); + } + + private static boolean needTransform(LogicalAggregate<Plan> agg, List<Alias> aliases, + Set<Expression> distinctFunc, List<Alias> otherAggFuncs) { + boolean distinctMultiColumns = false; + for (NamedExpression namedExpression : agg.getOutputExpressions()) { + if (!(namedExpression instanceof Alias) || !(namedExpression.child(0) instanceof AggregateFunction)) { + continue; + } + AggregateFunction aggFunc = (AggregateFunction) namedExpression.child(0); + if (supportedFunctions.contains(aggFunc.getClass()) && 
aggFunc.isDistinct()) { + aliases.add((Alias) namedExpression); + distinctFunc.add(aggFunc); + distinctMultiColumns |= isDistinctMultiColumns(aggFunc); Review Comment: Use a short-circuiting logical OR (`||`) instead of `|=` here for better performance. ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/DistinctSplit.java: ########## @@ -0,0 +1,213 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite; + +import org.apache.doris.nereids.CascadesContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.OrderExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction; +import org.apache.doris.nereids.trees.expressions.functions.agg.Avg; +import org.apache.doris.nereids.trees.expressions.functions.agg.Count; +import org.apache.doris.nereids.trees.expressions.functions.agg.GroupConcat; +import org.apache.doris.nereids.trees.expressions.functions.agg.Sum; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEAnchor; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEConsumer; +import org.apache.doris.nereids.trees.plans.logical.LogicalCTEProducer; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableSet; + +import java.util.ArrayList; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * LogicalAggregate(output:count(distinct a) as c1, count(distinct b) as c2) + * +--Plan + * -> + * LogicalCTEAnchor + * +--LogicalCTEProducer + * +--Plan + * +--LogicalProject(c1, c2) + * +--LogicalJoin + * 
+--LogicalAggregate(output:count(distinct a)) + * +--LogicalCTEConsumer + * +--LogicalAggregate(output:count(distinct b)) + * +--LogicalCTEConsumer + * */ +public class DistinctSplit extends OneRewriteRuleFactory { + private static final ImmutableSet<Class<? extends AggregateFunction>> supportedFunctions = + ImmutableSet.of(Count.class, Sum.class, Avg.class, GroupConcat.class); + + @Override + public Rule build() { + return logicalAggregate() + .whenNot(agg -> agg.getSourceRepeat().isPresent()) + .thenApply(ctx -> doSplit(ctx.root, ctx.cascadesContext)) + .toRule(RuleType.DISTINCT_SPLIT); + } + + private static boolean isDistinctMultiColumns(AggregateFunction func) { + if (func.arity() <= 1) { + return false; + } + for (int i = 1; i < func.arity(); ++i) { + // think about group_concat(distinct col_1, ',') + if (!(func.child(i) instanceof OrderExpression) && !func.child(i).getInputSlots().isEmpty()) { + return true; + } + } + return false; + } + + private static Plan doSplit(LogicalAggregate<Plan> agg, CascadesContext ctx) { + List<Alias> aliases = new ArrayList<>(); + Set<Expression> distinctFunc = new HashSet<>(); + List<Alias> otherAggFuncs = new ArrayList<>(); + if (!needTransform(agg, aliases, distinctFunc, otherAggFuncs)) { + return null; + } + + LogicalCTEProducer<Plan> producer = new LogicalCTEProducer<>(ctx.getStatementContext().getNextCTEId(), + agg.child()); + // construct cte consumer and aggregate + List<Expression> groupBy = agg.getGroupByExpressions(); + List<LogicalAggregate<Plan>> newAggs = new ArrayList<>(); + // All aggFunc except count distinct are placed in the first one + Map<Alias, Alias> joinOutput = new HashMap<>(); + List<Expression> outputJoinGroupBys = new ArrayList<>(); + for (int i = 0; i < aliases.size(); ++i) { + Expression expr = aliases.get(i).child(0); + LogicalCTEConsumer consumer = new LogicalCTEConsumer(ctx.getStatementContext().getNextRelationId(), + producer.getCteId(), "", producer); + ctx.putCTEIdToConsumer(consumer); + 
Map<Slot, Slot> replaceMap = new HashMap<>(); + for (Map.Entry<Slot, Slot> entry : consumer.getConsumerToProducerOutputMap().entrySet()) { + replaceMap.put(entry.getValue(), entry.getKey()); + } + List<Expression> replacedGroupBy = ExpressionUtils.replace(groupBy, replaceMap); + Expression count = ExpressionUtils.replace(expr, replaceMap); + List<NamedExpression> outputExpressions = replacedGroupBy.stream() + .map(e -> (Slot) e).collect(Collectors.toList()); + Alias alias = new Alias(count); + outputExpressions.add(alias); + if (i == 0) { + List<Expression> otherAggFuncAliases = otherAggFuncs.stream() + .map(e -> ExpressionUtils.replace(e, replaceMap)).collect(Collectors.toList()); + for (Expression otherAggFuncAlias : otherAggFuncAliases) { + // otherAggFunc is instance of Alias + Alias outputOtherFunc = new Alias(otherAggFuncAlias.child(0)); + outputExpressions.add(outputOtherFunc); + joinOutput.put(outputOtherFunc, (Alias) otherAggFuncAlias); + } + // save replacedGroupBy + outputJoinGroupBys.addAll(replacedGroupBy); + } + LogicalAggregate<Plan> newAgg = new LogicalAggregate<>(replacedGroupBy, outputExpressions, consumer); + newAggs.add(newAgg); + joinOutput.put(alias, aliases.get(i)); + } + LogicalJoin<Plan, Plan> join = constructJoin(newAggs, groupBy); + LogicalProject<Plan> project = constructProject(groupBy, joinOutput, outputJoinGroupBys, join); Review Comment: Are the group-by key IDs the same in the producer's output and in the output after this project? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org