This is an automated email from the ASF dual-hosted git repository.

yiguolei pushed a commit to branch branch-2.1
in repository https://gitbox.apache.org/repos/asf/doris.git

commit 8c535c51b5a87b949bfb3fec75b575f716687476
Author: Pxl <pxl...@qq.com>
AuthorDate: Thu Apr 18 10:33:24 2024 +0800

    [Improvement](materialized-view) support multiple agg functions sharing the same base table slot (#33774)
    
    Support multiple aggregate functions that share the same base table slot.
---
 .../glue/translator/PhysicalPlanTranslator.java    |   8 +-
 .../nereids/rules/analysis/BindExpression.java     |  26 +---
 .../mv/AbstractSelectMaterializedIndexRule.java    |  19 +--
 .../mv/SelectMaterializedIndexWithAggregate.java   | 173 +++++++++++----------
 .../org/apache/doris/nereids/util/PlanUtils.java   |  25 +++
 .../multi_agg_with_same_slot.out                   |  48 ++++++
 .../data/mv_p0/mv_percentile/mv_percentile.out     |  36 +++++
 .../multi_agg_with_same_slot.groovy                |  73 +++++++++
 .../mv_p0/mv_percentile/mv_percentile.groovy       |  66 ++++++++
 9 files changed, 359 insertions(+), 115 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
index 8117d9122d1..fc9a88cf2e0 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java
@@ -2224,9 +2224,11 @@ public class PhysicalPlanTranslator extends 
DefaultPlanVisitor<PlanFragment, Pla
                 .map(expr -> ExpressionTranslator.translate(expr, 
context)).collect(ImmutableList.toImmutableList());
 
         // outputSlots's order need same with preRepeatExprs
-        List<Slot> outputSlots = Stream.concat(
-                repeat.getOutputExpressions().stream().filter(output -> 
flattenGroupingSetExprs.contains(output)),
-                repeat.getOutputExpressions().stream().filter(output -> 
!flattenGroupingSetExprs.contains(output)))
+        List<Slot> outputSlots = Stream
+                .concat(repeat.getOutputExpressions().stream()
+                        .filter(output -> 
flattenGroupingSetExprs.contains(output)),
+                        repeat.getOutputExpressions().stream()
+                                .filter(output -> 
!flattenGroupingSetExprs.contains(output)).distinct())
                 
.map(NamedExpression::toSlot).collect(ImmutableList.toImmutableList());
 
         // NOTE: we should first translate preRepeatExprs, then generate 
output tuple,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
index e20d3e8d551..8957800c7ed 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindExpression.java
@@ -643,7 +643,7 @@ public class BindExpression implements AnalysisRuleFactory {
             boundGroupingSetsBuilder.add(boundGroupingSet);
         }
         List<List<Expression>> boundGroupingSets = 
boundGroupingSetsBuilder.build();
-        List<NamedExpression> nullableOutput = 
adjustNullableForRepeat(boundGroupingSets, boundRepeatOutput);
+        List<NamedExpression> nullableOutput = 
PlanUtils.adjustNullableForRepeat(boundGroupingSets, boundRepeatOutput);
         for (List<Expression> groupingSet : boundGroupingSets) {
             checkIfOutputAliasNameDuplicatedForGroupBy(groupingSet, 
nullableOutput);
         }
@@ -800,30 +800,6 @@ public class BindExpression implements AnalysisRuleFactory 
{
         return new LogicalSort<>(boundOrderKeys.build(), sort.child());
     }
 
-    /**
-     * For the columns whose output exists in grouping sets, they need to be 
assigned as nullable.
-     */
-    private List<NamedExpression> adjustNullableForRepeat(
-            List<List<Expression>> groupingSets,
-            List<NamedExpression> outputs) {
-        Set<Slot> groupingSetsUsedSlots = groupingSets.stream()
-                .flatMap(Collection::stream)
-                .map(Expression::getInputSlots)
-                .flatMap(Set::stream)
-                .collect(Collectors.toSet());
-        Builder<NamedExpression> nullableOutputs = 
ImmutableList.builderWithExpectedSize(outputs.size());
-        for (NamedExpression output : outputs) {
-            Expression nullableOutput = output.rewriteUp(expr -> {
-                if (expr instanceof Slot && 
groupingSetsUsedSlots.contains(expr)) {
-                    return ((Slot) expr).withNullable(true);
-                }
-                return expr;
-            });
-            nullableOutputs.add((NamedExpression) nullableOutput);
-        }
-        return nullableOutputs.build();
-    }
-
     private LogicalTVFRelation 
bindTableValuedFunction(MatchingContext<UnboundTVFRelation> ctx) {
         UnboundTVFRelation unboundTVFRelation = ctx.root;
         StatementContext statementContext = ctx.statementContext;
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
index e0518b2c117..03bef1a6b47 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/AbstractSelectMaterializedIndexRule.java
@@ -57,6 +57,7 @@ import 
org.apache.doris.nereids.trees.plans.logical.LogicalProject;
 import org.apache.doris.nereids.trees.plans.logical.LogicalRepeat;
 import org.apache.doris.nereids.trees.plans.visitor.DefaultPlanVisitor;
 import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.PlanUtils;
 import org.apache.doris.planner.PlanNode;
 
 import com.google.common.collect.ImmutableList;
@@ -580,20 +581,20 @@ public abstract class AbstractSelectMaterializedIndexRule 
{
         public LogicalRepeat<Plan> visitLogicalRepeat(LogicalRepeat<? extends 
Plan> repeat, Void ctx) {
             Plan child = repeat.child(0).accept(this, ctx);
             List<List<Expression>> groupingSets = repeat.getGroupingSets();
-            ImmutableList.Builder<List<Expression>> newGroupingExprs = 
ImmutableList.builder();
+            List<List<Expression>> newGroupingExprs = Lists.newArrayList();
             for (List<Expression> expressions : groupingSets) {
-                newGroupingExprs.add(expressions.stream()
-                        .map(expr -> new 
ReplaceExpressionWithMvColumn(slotContext).replace(expr))
-                        .collect(ImmutableList.toImmutableList())
-                );
+                newGroupingExprs.add(
+                        expressions.stream().map(expr -> new 
ReplaceExpressionWithMvColumn(slotContext).replace(expr))
+                                .collect(ImmutableList.toImmutableList()));
             }
 
             List<NamedExpression> outputExpressions = 
repeat.getOutputExpressions();
-            List<NamedExpression> newOutputExpressions = 
outputExpressions.stream()
-                    .map(expr -> (NamedExpression) new 
ReplaceExpressionWithMvColumn(slotContext).replace(expr))
-                    .collect(ImmutableList.toImmutableList());
+            List<NamedExpression> newOutputExpressions = 
PlanUtils.adjustNullableForRepeat(newGroupingExprs,
+                    outputExpressions.stream()
+                            .map(expr -> (NamedExpression) new 
ReplaceExpressionWithMvColumn(slotContext).replace(expr))
+                            .collect(ImmutableList.toImmutableList()));
 
-            return repeat.withNormalizedExpr(newGroupingExprs.build(), 
newOutputExpressions, child);
+            return repeat.withNormalizedExpr(newGroupingExprs, 
newOutputExpressions, child);
         }
 
         @Override
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
index c84c5212a5b..f28b3952fa9 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/mv/SelectMaterializedIndexWithAggregate.java
@@ -39,7 +39,6 @@ import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.SlotNotFromChildren;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
 import org.apache.doris.nereids.trees.expressions.VirtualSlotReference;
 import org.apache.doris.nereids.trees.expressions.WhenClause;
 import 
org.apache.doris.nereids.trees.expressions.functions.agg.AggregateFunction;
@@ -81,6 +80,7 @@ import com.google.common.base.Suppliers;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.ImmutableMap;
 import com.google.common.collect.ImmutableSet;
+import com.google.common.collect.Lists;
 import com.google.common.collect.Maps;
 import com.google.common.collect.Sets;
 import com.google.common.collect.Streams;
@@ -207,22 +207,19 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                             LogicalOlapScan mvPlan = 
createLogicalOlapScan(scan, result);
                             SlotContext slotContext = 
generateBaseScanExprToMvExpr(mvPlan);
 
-                            List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                            List<NamedExpression> newProjectList = 
replaceOutput(project.getProjects(),
                                     result.exprRewriteMap.projectExprMap);
                             LogicalProject<LogicalOlapScan> newProject = new 
LogicalProject<>(
                                     generateNewOutputsWithMvOutputs(mvPlan, 
newProjectList),
                                     
scan.withMaterializedIndexSelected(result.preAggStatus, result.indexId));
-                            return new LogicalProject<>(
-                                generateProjectsAlias(agg.getOutputs(), 
slotContext),
-                                    new 
ReplaceExpressions(slotContext).replace(
-                                        new LogicalAggregate<>(
-                                            agg.getGroupByExpressions(),
-                                            replaceAggOutput(agg, 
Optional.of(project), Optional.of(newProject),
-                                                    result.exprRewriteMap),
-                                            agg.isNormalized(),
-                                            agg.getSourceRepeat(),
-                                            newProject
-                                        ), mvPlan));
+                            return new 
LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
+                                    new ReplaceExpressions(slotContext)
+                                            .replace(
+                                                    new 
LogicalAggregate<>(agg.getGroupByExpressions(),
+                                                            
replaceAggOutput(agg, Optional.of(project),
+                                                                    
Optional.of(newProject), result.exprRewriteMap),
+                                                            
agg.isNormalized(), agg.getSourceRepeat(), newProject),
+                                                    mvPlan));
                         
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_PROJECT_SCAN),
 
                 // filter could push down and project.
@@ -274,7 +271,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                                 mvPlanWithoutAgg)));
                             }
 
-                            List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                            List<NamedExpression> newProjectList = 
replaceOutput(project.getProjects(),
                                     result.exprRewriteMap.projectExprMap);
                             LogicalProject<Plan> newProject = new 
LogicalProject<>(
                                     generateNewOutputsWithMvOutputs(mvPlan, 
newProjectList),
@@ -322,7 +319,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                     .map(e -> 
result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                     filter.getConjuncts());
 
-                            List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                            List<NamedExpression> newProjectList = 
replaceOutput(project.getProjects(),
                                     result.exprRewriteMap.projectExprMap);
                             LogicalProject<Plan> newProject = new 
LogicalProject<>(
                                     generateNewOutputsWithMvOutputs(mvPlan, 
newProjectList), mvPlan);
@@ -342,33 +339,31 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
 
                 // only agg above scan
                 // Aggregate(Repeat(Scan))
-                logicalAggregate(
-                    
logicalRepeat(logicalOlapScan().when(this::shouldSelectIndexWithAgg))).thenApplyNoThrow(ctx
 -> {
-                        LogicalAggregate<LogicalRepeat<LogicalOlapScan>> agg = 
ctx.root;
-                        LogicalRepeat<LogicalOlapScan> repeat = agg.child();
-                        LogicalOlapScan scan = repeat.child();
-                        SelectResult result = select(
-                                scan,
-                                agg.getInputSlots(),
-                                ImmutableSet.of(),
-                                extractAggFunctionAndReplaceSlot(agg, 
Optional.empty()),
-                                nonVirtualGroupByExprs(agg),
-                                new HashSet<>(agg.getExpressions()));
-
-                        LogicalOlapScan mvPlan = createLogicalOlapScan(scan, 
result);
-                        SlotContext slotContext = 
generateBaseScanExprToMvExpr(mvPlan);
-
-                        return new 
LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
-                                new 
ReplaceExpressions(slotContext).replace(new LogicalAggregate<>(
-                                        agg.getGroupByExpressions(),
-                                        replaceAggOutput(
-                                                agg, Optional.empty(), 
Optional.empty(), result.exprRewriteMap),
-                                        agg.isNormalized(), 
agg.getSourceRepeat(),
-                                        repeat.withAggOutputAndChild(
-                                                replaceRepeatOutput(repeat, 
result.exprRewriteMap.projectExprMap),
-                                                mvPlan)),
-                                        mvPlan));
-                    }).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_SCAN),
+                
logicalAggregate(logicalRepeat(logicalOlapScan().when(this::shouldSelectIndexWithAgg)))
+                        .thenApplyNoThrow(ctx -> {
+                            LogicalAggregate<LogicalRepeat<LogicalOlapScan>> 
agg = ctx.root;
+                            LogicalRepeat<LogicalOlapScan> repeat = 
agg.child();
+                            LogicalOlapScan scan = repeat.child();
+                            SelectResult result = select(scan, 
agg.getInputSlots(), ImmutableSet.of(),
+                                    extractAggFunctionAndReplaceSlot(agg, 
Optional.empty()),
+                                    nonVirtualGroupByExprs(agg), new 
HashSet<>(agg.getExpressions()));
+
+                            LogicalOlapScan mvPlan = 
createLogicalOlapScan(scan, result);
+                            SlotContext slotContext = 
generateBaseScanExprToMvExpr(mvPlan);
+
+                            return new 
LogicalProject<>(generateProjectsAlias(agg.getOutputs(), slotContext),
+                                    new ReplaceExpressions(slotContext)
+                                            .replace(
+                                                    new 
LogicalAggregate<>(agg.getGroupByExpressions(),
+                                                            
replaceAggOutput(agg, Optional.empty(), Optional.empty(),
+                                                                    
result.exprRewriteMap),
+                                                            
agg.isNormalized(), agg.getSourceRepeat(),
+                                                            
repeat.withAggOutputAndChild(
+                                                                    
replaceOutput(repeat.getOutputs(),
+                                                                            
result.exprRewriteMap.projectExprMap),
+                                                                    mvPlan)),
+                                                    mvPlan));
+                        }).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_SCAN),
 
                 // filter could push down scan.
                 // Aggregate(Repeat(Filter(Scan)))
@@ -411,9 +406,10 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                             // because the slots to replace
                                             // are value columns, which 
shouldn't appear in filters.
                                             repeat.withAggOutputAndChild(
-                                                    
replaceRepeatOutput(repeat, result.exprRewriteMap.projectExprMap),
-                                                    
filter.withChildren(mvPlan))
-                                        ), mvPlan));
+                                                    
replaceOutput(repeat.getOutputs(),
+                                                            
result.exprRewriteMap.projectExprMap),
+                                                    
filter.withChildren(mvPlan))),
+                                            mvPlan));
                         
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_FILTER_SCAN),
 
                 // column pruning or other projections such as alias, etc.
@@ -438,7 +434,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                             LogicalOlapScan mvPlan = 
createLogicalOlapScan(scan, result);
                             SlotContext slotContext = 
generateBaseScanExprToMvExpr(mvPlan);
 
-                            List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                            List<NamedExpression> newProjectList = 
replaceOutput(project.getProjects(),
                                     result.exprRewriteMap.projectExprMap);
                             LogicalProject<LogicalOlapScan> newProject = new 
LogicalProject<>(
                                     generateNewOutputsWithMvOutputs(mvPlan, 
newProjectList), mvPlan);
@@ -449,7 +445,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                                     replaceAggOutput(agg, 
Optional.of(project), Optional.of(newProject),
                                                             
result.exprRewriteMap),
                                                     agg.isNormalized(), 
agg.getSourceRepeat(),
-                                                    
repeat.withAggOutputAndChild(replaceRepeatOutput(repeat,
+                                                    
repeat.withAggOutputAndChild(replaceOutput(repeat.getOutputs(),
                                                             
result.exprRewriteMap.projectExprMap), newProject)),
                                             mvPlan));
                         
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_SCAN),
@@ -487,7 +483,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                     .map(e -> 
result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                     filter.getConjuncts());
 
-                            List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                            List<NamedExpression> newProjectList = 
replaceOutput(project.getProjects(),
                                     result.exprRewriteMap.projectExprMap);
                             LogicalProject<Plan> newProject = new 
LogicalProject<>(
                                     generateNewOutputsWithMvOutputs(mvPlan, 
newProjectList),
@@ -499,7 +495,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                                     replaceAggOutput(agg, 
Optional.of(project), Optional.of(newProject),
                                                             
result.exprRewriteMap),
                                                     agg.isNormalized(), 
agg.getSourceRepeat(),
-                                                    
repeat.withAggOutputAndChild(replaceRepeatOutput(repeat,
+                                                    
repeat.withAggOutputAndChild(replaceOutput(repeat.getOutputs(),
                                                             
result.exprRewriteMap.projectExprMap), newProject)),
                                             mvPlan));
                         
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_PROJECT_FILTER_SCAN),
@@ -535,7 +531,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                     .map(e -> 
result.exprRewriteMap.replaceAgg(e)).collect(Collectors.toSet()),
                                     filter.getConjuncts());
 
-                            List<NamedExpression> newProjectList = 
replaceProjectList(project,
+                            List<NamedExpression> newProjectList = 
replaceOutput(project.getProjects(),
                                     result.exprRewriteMap.projectExprMap);
                             LogicalProject<Plan> newProject = new 
LogicalProject<>(
                                     generateNewOutputsWithMvOutputs(mvPlan, 
newProjectList),
@@ -547,7 +543,8 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                                     Optional.of(newProject), 
result.exprRewriteMap),
                                             agg.isNormalized(), 
agg.getSourceRepeat(),
                                             repeat.withAggOutputAndChild(
-                                                    
replaceRepeatOutput(repeat, result.exprRewriteMap.projectExprMap),
+                                                    
replaceOutput(repeat.getOutputs(),
+                                                            
result.exprRewriteMap.projectExprMap),
                                                     
filter.withChildren(newProject))),
                                             mvPlan));
                         
}).toRule(RuleType.MATERIALIZED_INDEX_AGG_REPEAT_FILTER_PROJECT_SCAN)
@@ -1085,9 +1082,13 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
 
     private static class ExprRewriteMap {
         /**
-         * Replace map for expressions in project.
+         * Replace map for expressions in project. For example, if the query has avg(v) and
+         * stddev_samp(v), projectExprMap will contain v -> [mva_GENERIC__avg_state(`v`),
+         * mva_GENERIC__stddev_samp_state(CAST(`v` AS DOUBLE))], and some LogicalPlans
+         * will then output [mva_GENERIC__avg_state(`v`),
+         * mva_GENERIC__stddev_samp_state(CAST(`v` AS DOUBLE))] in place of column v.
          */
-        public final Map<Expression, Expression> projectExprMap;
+        public final Map<Expression, List<Expression>> projectExprMap;
         /**
          * Replace map for aggregate functions.
          */
@@ -1124,6 +1125,13 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
             buildStrMap();
             return aggFuncStrMap.getOrDefault(e.toSql(), (AggregateFunction) 
e);
         }
+
+        public void putIntoProjectExprMap(Expression key, Expression value) {
+            if (!projectExprMap.containsKey(key)) {
+                projectExprMap.put(key, Lists.newArrayList());
+            }
+            projectExprMap.get(key).add(value);
+        }
     }
 
     private static class AggRewriteResult {
@@ -1199,7 +1207,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                 .orElseThrow(() -> new AnalysisException(
                                         "cannot find bitmap union slot when 
select mv"));
 
-                        
context.exprRewriteMap.projectExprMap.put(slotOpt.get(), bitmapUnionSlot);
+                        
context.exprRewriteMap.putIntoProjectExprMap(slotOpt.get(), bitmapUnionSlot);
                         BitmapUnionCount bitmapUnionCount = new 
BitmapUnionCount(bitmapUnionSlot);
                         context.exprRewriteMap.aggFuncMap.put(count, 
bitmapUnionCount);
                         return bitmapUnionCount;
@@ -1229,7 +1237,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                             .filter(s -> 
countColumn.equalsIgnoreCase(normalizeName(s.getName()))).findFirst()
                             .orElseThrow(() -> new AnalysisException("cannot 
find count slot when select mv"));
 
-                    context.exprRewriteMap.projectExprMap.put(child, 
countSlot);
+                    context.exprRewriteMap.putIntoProjectExprMap(child, 
countSlot);
                     Sum sum = new Sum(countSlot);
                     context.exprRewriteMap.aggFuncMap.put(count, sum);
                     return sum;
@@ -1265,7 +1273,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                 .findFirst().orElseThrow(
                                         () -> new AnalysisException("cannot 
find bitmap union slot when select mv"));
 
-                        context.exprRewriteMap.projectExprMap.put(toBitmap, 
bitmapUnionSlot);
+                        context.exprRewriteMap.putIntoProjectExprMap(toBitmap, 
bitmapUnionSlot);
                         BitmapUnion newBitmapUnion = new 
BitmapUnion(bitmapUnionSlot);
                         context.exprRewriteMap.aggFuncMap.put(bitmapUnion, 
newBitmapUnion);
                         return newBitmapUnion;
@@ -1284,7 +1292,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                             .stream().filter(s -> 
bitmapUnionColumn.equalsIgnoreCase(normalizeName(s.getName())))
                             .findFirst()
                             .orElseThrow(() -> new AnalysisException("cannot 
find bitmap union slot when select mv"));
-                    context.exprRewriteMap.projectExprMap.put(child, 
bitmapUnionSlot);
+                    context.exprRewriteMap.putIntoProjectExprMap(child, 
bitmapUnionSlot);
                     BitmapUnion newBitmapUnion = new 
BitmapUnion(bitmapUnionSlot);
                     context.exprRewriteMap.aggFuncMap.put(bitmapUnion, 
newBitmapUnion);
                     return newBitmapUnion;
@@ -1323,7 +1331,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                 .orElseThrow(() -> new AnalysisException(
                                         "cannot find bitmap union count slot 
when select mv"));
 
-                        context.exprRewriteMap.projectExprMap.put(toBitmap, 
bitmapUnionCountSlot);
+                        context.exprRewriteMap.putIntoProjectExprMap(toBitmap, 
bitmapUnionCountSlot);
                         BitmapUnionCount newBitmapUnionCount = new 
BitmapUnionCount(bitmapUnionCountSlot);
                         
context.exprRewriteMap.aggFuncMap.put(bitmapUnionCount, newBitmapUnionCount);
                         return newBitmapUnionCount;
@@ -1342,7 +1350,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                             .stream().filter(s -> 
bitmapUnionCountColumn.equalsIgnoreCase(normalizeName(s.getName())))
                             .findFirst().orElseThrow(
                                     () -> new AnalysisException("cannot find 
bitmap union count slot when select mv"));
-                    context.exprRewriteMap.projectExprMap.put(child, 
bitmapUnionCountSlot);
+                    context.exprRewriteMap.putIntoProjectExprMap(child, 
bitmapUnionCountSlot);
                     BitmapUnionCount newBitmapUnionCount = new 
BitmapUnionCount(bitmapUnionCountSlot);
                     context.exprRewriteMap.aggFuncMap.put(bitmapUnionCount, 
newBitmapUnionCount);
                     return newBitmapUnionCount;
@@ -1378,7 +1386,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                 .orElseThrow(() -> new AnalysisException(
                                         "cannot find hll union slot when 
select mv"));
 
-                        context.exprRewriteMap.projectExprMap.put(hllHash, 
hllUnionSlot);
+                        context.exprRewriteMap.putIntoProjectExprMap(hllHash, 
hllUnionSlot);
                         HllUnion newHllUnion = new HllUnion(hllUnionSlot);
                         context.exprRewriteMap.aggFuncMap.put(hllUnion, 
newHllUnion);
                         return newHllUnion;
@@ -1415,7 +1423,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                                 .orElseThrow(() -> new AnalysisException(
                                         "cannot find hll union slot when 
select mv"));
 
-                        context.exprRewriteMap.projectExprMap.put(hllHash, 
hllUnionSlot);
+                        context.exprRewriteMap.putIntoProjectExprMap(hllHash, 
hllUnionSlot);
                         HllUnionAgg newHllUnionAgg = new 
HllUnionAgg(hllUnionSlot);
                         context.exprRewriteMap.aggFuncMap.put(hllUnionAgg, 
newHllUnionAgg);
                         return newHllUnionAgg;
@@ -1453,7 +1461,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                             .orElseThrow(() -> new AnalysisException(
                                     "cannot find hll union slot when select 
mv"));
 
-                    context.exprRewriteMap.projectExprMap.put(slotOpt.get(), 
hllUnionSlot);
+                    
context.exprRewriteMap.putIntoProjectExprMap(slotOpt.get(), hllUnionSlot);
                     HllUnionAgg hllUnionAgg = new HllUnionAgg(hllUnionSlot);
                     context.exprRewriteMap.aggFuncMap.put(ndv, hllUnionAgg);
                     return hllUnionAgg;
@@ -1477,7 +1485,7 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                     Slot sumSlot = 
context.checkContext.scan.getOutputByIndex(context.checkContext.index).stream()
                             .filter(s -> 
sumColumn.equalsIgnoreCase(normalizeName(s.getName()))).findFirst()
                             .orElseThrow(() -> new AnalysisException("cannot 
find sum slot when select mv"));
-                    context.exprRewriteMap.projectExprMap.put(sum.child(), 
sumSlot);
+                    context.exprRewriteMap.putIntoProjectExprMap(sum.child(), 
sumSlot);
                     Sum newSum = new Sum(sumSlot);
                     context.exprRewriteMap.aggFuncMap.put(sum, newSum);
                     return newSum;
@@ -1501,9 +1509,8 @@ public class SelectMaterializedIndexWithAggregate extends 
AbstractSelectMaterial
                         .filter(s -> 
aggStateName.equalsIgnoreCase(normalizeName(s.getName()))).findFirst()
                         .orElseThrow(() -> new AnalysisException("cannot find 
agg state slot when select mv"));
 
-                Set<Slot> slots = 
aggregateFunction.collect(SlotReference.class::isInstance);
-                for (Slot slot : slots) {
-                    context.exprRewriteMap.projectExprMap.put(slot, 
aggStateSlot);
+                for (Expression child : aggregateFunction.children()) {
+                    context.exprRewriteMap.putIntoProjectExprMap(child, 
aggStateSlot);
                 }
 
                 MergeCombinator mergeCombinator = new 
MergeCombinator(Arrays.asList(aggStateSlot), aggregateFunction);
@@ -1574,19 +1581,29 @@ public class SelectMaterializedIndexWithAggregate 
extends AbstractSelectMaterial
         }
     }
 
-    private List<NamedExpression> replaceProjectList(
-            LogicalProject<? extends Plan> project,
-            Map<Expression, Expression> projectMap) {
-        return project.getProjects().stream()
-                .map(expr -> (NamedExpression) 
ExpressionUtils.replaceNameExpression(expr, projectMap))
-                .collect(Collectors.toList());
-    }
+    private List<NamedExpression> replaceOutput(List<NamedExpression> outputs,
+            Map<Expression, List<Expression>> projectMap) {
+        Map<String, List<Expression>> strToExprs = Maps.newHashMap();
+        for (Expression expr : projectMap.keySet()) {
+            strToExprs.put(expr.toSql(), projectMap.get(expr));
+        }
 
-    private List<NamedExpression> replaceRepeatOutput(LogicalRepeat<? extends 
Plan> repeat,
-            Map<Expression, Expression> projectMap) {
-        return repeat.getOutputs().stream()
-                .map(expr -> (NamedExpression) 
ExpressionUtils.replaceNameExpression(expr, projectMap))
-                .collect(Collectors.toList());
+        List<NamedExpression> results = Lists.newArrayList();
+        for (NamedExpression expr : outputs) {
+            results.add(expr);
+
+            if (!strToExprs.containsKey(expr.toSql())) {
+                continue;
+            }
+            for (Expression newExpr : strToExprs.get(expr.toSql())) {
+                if (newExpr instanceof NamedExpression) {
+                    results.add((NamedExpression) newExpr);
+                } else {
+                    results.add(new Alias(expr.getExprId(), newExpr, 
expr.getName()));
+                }
+            }
+        }
+        return results;
     }
 
     private List<Expression> nonVirtualGroupByExprs(LogicalAggregate<? extends 
Plan> agg) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
index 3955b2d0f0c..9c5e6b318e8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java
@@ -45,6 +45,7 @@ import java.util.List;
 import java.util.Map;
 import java.util.Optional;
 import java.util.Set;
+import java.util.stream.Collectors;
 
 /**
  * Util for plan
@@ -91,6 +92,30 @@ public class PlanUtils {
         }
     }
 
+    /**
+     * For the columns whose output exists in grouping sets, they need to be 
assigned as nullable.
+     */
+    public static List<NamedExpression> adjustNullableForRepeat(
+            List<List<Expression>> groupingSets,
+            List<NamedExpression> outputs) {
+        Set<Slot> groupingSetsUsedSlots = groupingSets.stream()
+                .flatMap(Collection::stream)
+                .map(Expression::getInputSlots)
+                .flatMap(Set::stream)
+                .collect(Collectors.toSet());
+        Builder<NamedExpression> nullableOutputs = 
ImmutableList.builderWithExpectedSize(outputs.size());
+        for (NamedExpression output : outputs) {
+            Expression nullableOutput = output.rewriteUp(expr -> {
+                if (expr instanceof Slot && 
groupingSetsUsedSlots.contains(expr)) {
+                    return ((Slot) expr).withNullable(true);
+                }
+                return expr;
+            });
+            nullableOutputs.add((NamedExpression) nullableOutput);
+        }
+        return nullableOutputs.build();
+    }
+
     /**
      * merge childProjects with parentProjects
      */
diff --git 
a/regression-test/data/mv_p0/multi_agg_with_same_slot/multi_agg_with_same_slot.out
 
b/regression-test/data/mv_p0/multi_agg_with_same_slot/multi_agg_with_same_slot.out
new file mode 100644
index 00000000000..05e88880925
--- /dev/null
+++ 
b/regression-test/data/mv_p0/multi_agg_with_same_slot/multi_agg_with_same_slot.out
@@ -0,0 +1,48 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select_star --
+\N     4       \N      d       4
+-4     -4      -4      d       -4
+1      1       1       a       1
+2      2       2       b       2
+3      -3      \N      c       -3
+
+-- !select_mv --
+\N     4       \N      \N
+-4     -4      -4.0    -4
+1      1       1.0     1
+2      2       2.0     2
+3      -3      \N      \N
+
+-- !select_mv --
+\N     4       \N      \N
+-4     -4      -4.0    -4
+1      1       1.0     1
+2      2       2.0     2
+3      -3      \N      \N
+
+-- !select_mv --
+\N     \N      \N      \N
+\N     \N      -0.3333333333333333     2
+\N     4       \N      \N
+-4     \N      -4.0    -4
+-4     -4      -4.0    -4
+1      \N      1.0     1
+1      1       1.0     1
+2      \N      2.0     2
+2      2       2.0     2
+3      \N      \N      \N
+3      -3      \N      \N
+
+-- !select_mv --
+\N     \N      0.0     4
+\N     \N      4.0     4
+\N     4       4.0     4
+-4     \N      -4.0    -4
+-4     -4      -4.0    -4
+1      \N      1.0     1
+1      1       1.0     1
+2      \N      2.0     2
+2      2       2.0     2
+3      \N      -3.0    -3
+3      -3      -3.0    -3
+
diff --git a/regression-test/data/mv_p0/mv_percentile/mv_percentile.out 
b/regression-test/data/mv_p0/mv_percentile/mv_percentile.out
new file mode 100644
index 00000000000..32e5595dac7
--- /dev/null
+++ b/regression-test/data/mv_p0/mv_percentile/mv_percentile.out
@@ -0,0 +1,36 @@
+-- This file is automatically generated. You should know what you did if you 
want to edit this
+-- !select_star --
+\N     4       \N      d
+-4     -4      -4.000000       d
+1      1       1.000000        a
+2      2       2.000000        b
+3      -3      \N      c
+
+-- !select_mv --
+\N     4       \N      \N
+-4     -4      -4.0    -4.0
+1      1       1.0     1.0
+2      2       2.0     2.0
+3      -3      \N      \N
+
+-- !select_mv --
+\N     \N      \N      \N
+\N     \N      -3.0    1.8
+\N     4       \N      \N
+-4     \N      -4.0    -4.0
+-4     -4      -4.0    -4.0
+1      \N      1.0     1.0
+1      1       1.0     1.0
+2      \N      2.0     2.0
+2      2       2.0     2.0
+3      \N      \N      \N
+3      -3      \N      \N
+
+-- !select_mv --
+\N
+\N
+-4.0
+-3.0
+1.0
+2.0
+
diff --git 
a/regression-test/suites/mv_p0/multi_agg_with_same_slot/multi_agg_with_same_slot.groovy
 
b/regression-test/suites/mv_p0/multi_agg_with_same_slot/multi_agg_with_same_slot.groovy
new file mode 100644
index 00000000000..e92147dc51f
--- /dev/null
+++ 
b/regression-test/suites/mv_p0/multi_agg_with_same_slot/multi_agg_with_same_slot.groovy
@@ -0,0 +1,73 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import org.codehaus.groovy.runtime.IOGroovyMethods
+
+suite ("multi_agg_with_same_slot") {
+    sql "set enable_fallback_to_original_planner = false"
+
+    sql """DROP TABLE IF EXISTS d_table;"""
+
+    sql """
+            create table d_table(
+                k1 int null,
+                k2 int not null,
+                k3 bigint null,
+                k4 varchar(100) null,
+                k5 int not null
+            )
+            duplicate key (k1,k2,k3)
+            distributed BY hash(k1) buckets 3
+            properties("replication_num" = "1");
+        """
+
+    sql "insert into d_table select 1,1,1,'a',1;"
+    sql "insert into d_table select 2,2,2,'b',2;"
+    sql "insert into d_table select 3,-3,null,'c',-3;"
+
+    createMV("create materialized view kmv as select k1,k2,avg(k3),max(k3) 
from d_table group by k1,k2;")
+    createMV("create materialized view kmv2 as select k1,k2,avg(k5),max(k5) 
from d_table group by k1,k2;")
+
+    sql "insert into d_table select -4,-4,-4,'d',-4;"
+    sql "insert into d_table(k4,k2,k5) values('d',4,4);"
+
+    qt_select_star "select * from d_table order by k1;"
+
+    explain {
+        sql("select k1,k2,avg(k3),max(k3) from d_table group by k1,k2 order by 
1,2;")
+        contains "(kmv)"
+    }
+    qt_select_mv "select k1,k2,avg(k3),max(k3) from d_table group by k1,k2 
order by 1,2;"
+
+    explain {
+        sql("select k1,k2,avg(k3)+max(k3) from d_table group by k1,k2 order by 
1,2;")
+        contains "(kmv)"
+    }
+    qt_select_mv "select k1,k2,avg(k3),max(k3) from d_table group by k1,k2 
order by 1,2;"
+
+    explain {
+        sql("select k1,k2,avg(k3)+max(k3) from d_table group by grouping 
sets((k1),(k1,k2),()) order by 1,2;")
+        contains "(kmv)"
+    }
+    qt_select_mv "select k1,k2,avg(k3),max(k3) from d_table group by grouping 
sets((k1),(k1,k2),()) order by 1,2,3;"
+
+    explain {
+        sql("select k1,k2,max(k5) from d_table group by grouping 
sets((k1),(k1,k2),()) order by 1,2;")
+        contains "(kmv2)"
+    }
+    qt_select_mv "select k1,k2,avg(k5),max(k5) from d_table group by grouping 
sets((k1),(k1,k2),()) order by 1,2,3;"
+}
diff --git a/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy 
b/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy
new file mode 100644
index 00000000000..dd6cb453305
--- /dev/null
+++ b/regression-test/suites/mv_p0/mv_percentile/mv_percentile.groovy
@@ -0,0 +1,66 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+import org.codehaus.groovy.runtime.IOGroovyMethods
+
+suite ("mv_percentile") {
+    sql "set enable_fallback_to_original_planner = false"
+
+    sql """DROP TABLE IF EXISTS d_table;"""
+
+    sql """
+            create table d_table(
+                k1 int null,
+                k2 int not null,
+                k3 decimal(28,6) null,
+                k4 varchar(100) null
+            )
+            duplicate key (k1,k2,k3)
+            distributed BY hash(k1) buckets 3
+            properties("replication_num" = "1");
+        """
+
+    sql "insert into d_table select 1,1,1,'a';"
+    sql "insert into d_table select 2,2,2,'b';"
+    sql "insert into d_table select 3,-3,null,'c';"
+
+    createMV("create materialized view kp as select k1,k2,percentile(k3, 
0.1),percentile(k3, 0.9) from d_table group by k1,k2;")
+
+    sql "insert into d_table select -4,-4,-4,'d';"
+    sql "insert into d_table(k4,k2) values('d',4);"
+
+    qt_select_star "select * from d_table order by k1;"
+
+    explain {
+        sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table 
group by k1,k2 order by k1,k2;")
+        contains "(kp)"
+    }
+    qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from 
d_table group by k1,k2 order by k1,k2;"
+
+    explain {
+        sql("select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from d_table 
group by grouping sets((k1),(k1,k2),()) order by 1,2;")
+        contains "(kp)"
+    }
+    qt_select_mv "select k1,k2,percentile(k3, 0.1),percentile(k3, 0.9) from 
d_table group by grouping sets((k1),(k1,k2),()) order by 1,2,3;"
+
+
+    explain {
+        sql("select percentile(k3, 0.1) from d_table group by grouping 
sets((k1),()) order by 1;")
+        contains "(kp)"
+    }
+    qt_select_mv "select percentile(k3, 0.1) from d_table group by grouping 
sets((k1),()) order by 1;"
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to