924060929 commented on code in PR #49096:
URL: https://github.com/apache/doris/pull/49096#discussion_r2086315730


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -2088,4 +2107,130 @@ private boolean couldConvertToMulti(LogicalAggregate<? 
extends Plan> aggregate)
         }
         return true;
     }
+
+    /**
+     * LogicalAggregate(groupByExpr=[a], outputExpr=[a,count(distinct b)])
+     * ->
+     * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
count(partial_count(m))]
+     *   +--PhysicalDistribute(shuffleColumn=[a])
+     *     +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
partial_count(m)]
+     *       +--PhysicalHashAggregate(groupByExpr=[a, saltExpr], 
outputExpr=[a, multi_distinct_count(b) as m])
+     *         +--PhysicalDistribute(shuffleColumn=[a, saltExpr])
+     *           +--PhysicalProject(projects=[a, b, xxhash_32(b)%512 as 
saltExpr])
+     *             +--PhysicalHashAggregate(groupByExpr=[a, b], outputExpr=[a, 
b])
+     * */
+    private PhysicalHashAggregate<Plan> 
countDistinctSkewRewrite(LogicalAggregate<GroupPlan> logicalAgg,
+            CascadesContext cascadesContext) {
+        if (!logicalAgg.canSkewRewrite()) {
+            return null;
+        }
+
+        // 1.local agg
+        ImmutableList.Builder<Expression> localAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        localAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        Count count = (Count) 
logicalAgg.getAggregateFunctions().iterator().next();
+        if (!(count.child(0) instanceof Slot)) {
+            return null;
+        }
+        localAggGroupByBuilder.add(count.child(0));
+        List<Expression> localAggGroupBy = localAggGroupByBuilder.build();
+        List<NamedExpression> localAggOutput = 
Utils.fastToImmutableList((List) localAggGroupBy);
+        RequireProperties requireAny = 
RequireProperties.of(PhysicalProperties.ANY);
+        boolean maybeUsingStreamAgg = 
maybeUsingStreamAgg(cascadesContext.getConnectContext(),
+                localAggGroupBy);
+        boolean couldBanned = false;
+        AggregateParam localParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER, couldBanned);
+        PhysicalHashAggregate<Plan> localAgg = new 
PhysicalHashAggregate<>(localAggGroupBy, localAggOutput,
+                Optional.empty(), localParam, maybeUsingStreamAgg, 
Optional.empty(), null,
+                requireAny, logicalAgg.child());
+        // add shuffle expr in project
+        ImmutableList.Builder<NamedExpression> projections = 
ImmutableList.builderWithExpectedSize(
+                localAgg.getOutputs().size() + 1);
+        projections.addAll(localAgg.getOutputs());
+        Alias modAlias = getShuffleExpr(count, cascadesContext);
+        projections.add(modAlias);
+        PhysicalProject<Plan> physicalProject = new 
PhysicalProject<>(projections.build(), null, localAgg);
+
+        // 2.second phase agg: multi_distinct_count(b) group by a,h
+        ImmutableList.Builder<Expression> secondPhaseAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        
secondPhaseAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        secondPhaseAggGroupByBuilder.add(modAlias.toSlot());
+        List<Expression> secondPhaseAggGroupBy = 
secondPhaseAggGroupByBuilder.build();
+        ImmutableList.Builder<NamedExpression> secondPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                secondPhaseAggGroupBy.size() + 1);
+        secondPhaseAggOutput.addAll((List) secondPhaseAggGroupBy);
+        Alias aliasTarget = new Alias(new TinyIntLiteral((byte) 0));
+        for (NamedExpression ne : logicalAgg.getOutputExpressions()) {
+            if (ne instanceof Alias) {
+                if (((Alias) ne).child().equals(count)) {
+                    aliasTarget = (Alias) ne;
+                }
+            }
+        }
+        AggregateParam secondParam = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_RESULT, couldBanned);
+        AggregateFunction multiDistinct = count.convertToMultiDistinct();
+        Alias multiDistinctAlias = new Alias(new 
AggregateExpression(multiDistinct, secondParam));
+        secondPhaseAggOutput.add(multiDistinctAlias);
+        List<ExprId> shuffleIds = new ArrayList<>();
+        for (Expression expr : secondPhaseAggGroupBy) {
+            if (expr instanceof Slot) {
+                shuffleIds.add(((Slot) expr).getExprId());
+            }
+        }
+        RequireProperties secondRequireProperties = RequireProperties.of(
+                PhysicalProperties.createHash(shuffleIds, 
ShuffleType.REQUIRE));
+        PhysicalHashAggregate<Plan> secondPhaseAgg = new 
PhysicalHashAggregate<>(
+                secondPhaseAggGroupBy, secondPhaseAggOutput.build(),
+                Optional.empty(), secondParam, false, Optional.empty(), null,
+                secondRequireProperties, physicalProject);
+
+        // 3. third phase agg
+        List<Expression> thirdPhaseAggGroupBy = 
Utils.fastToImmutableList(logicalAgg.getGroupByExpressions());
+        ImmutableList.Builder<NamedExpression> thirdPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                thirdPhaseAggGroupBy.size() + 1);
+        thirdPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+        AggregateParam thirdParam = new 
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+        Count thirdCount = new Count(multiDistinctAlias.toSlot());
+        Alias thirdCountAlias = new Alias(new AggregateExpression(thirdCount, 
thirdParam));
+        thirdPhaseAggOutput.add(thirdCountAlias);
+        PhysicalHashAggregate<Plan> thirdPhaseAgg = new 
PhysicalHashAggregate<>(
+                thirdPhaseAggGroupBy, thirdPhaseAggOutput.build(),
+                Optional.empty(), thirdParam, false, Optional.empty(), null,
+                secondRequireProperties, secondPhaseAgg);
+
+        // 4. fourth phase agg
+        ImmutableList.Builder<NamedExpression> fourthPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                thirdPhaseAggGroupBy.size() + 1);
+        fourthPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+        AggregateParam fourthParam = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT,
+                couldBanned);
+        Alias sumAliasFour = new Alias(aliasTarget.getExprId(),
+                new AggregateExpression(thirdCount, fourthParam, 
thirdCountAlias.toSlot()),
+                aliasTarget.getName());
+        fourthPhaseAggOutput.add(sumAliasFour);
+        List<ExprId> shuffleIdsFour = new ArrayList<>();
+        for (Expression expr : logicalAgg.getExpressions()) {
+            if (expr instanceof Slot) {
+                shuffleIdsFour.add(((Slot) expr).getExprId());
+            }
+        }
+        RequireProperties fourthRequireProperties = RequireProperties.of(
+                PhysicalProperties.createHash(shuffleIdsFour, 
ShuffleType.REQUIRE));
+        return new PhysicalHashAggregate<>(thirdPhaseAggGroupBy,
+                fourthPhaseAggOutput.build(), Optional.empty(), fourthParam,
+                false, Optional.empty(), logicalAgg.getLogicalProperties(),
+                fourthRequireProperties, thirdPhaseAgg);
+    }
+
+    private Alias getShuffleExpr(Count count, CascadesContext cascadesContext) 
{
+        int bucketNum = 
cascadesContext.getConnectContext().getSessionVariable().aggDistinctSkewBucketNum;
+        DataType type = bucketNum <= 256 ? TinyIntType.INSTANCE : 
SmallIntType.INSTANCE;
+        int bucket = bucketNum / 2;

Review Comment:
   What is the meaning of `bucketNum / 2`?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -2088,4 +2107,130 @@ private boolean couldConvertToMulti(LogicalAggregate<? 
extends Plan> aggregate)
         }
         return true;
     }
+
+    /**
+     * LogicalAggregate(groupByExpr=[a], outputExpr=[a,count(distinct b)])
+     * ->
+     * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
count(partial_count(m))]
+     *   +--PhysicalDistribute(shuffleColumn=[a])
+     *     +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a, 
partial_count(m)]
+     *       +--PhysicalHashAggregate(groupByExpr=[a, saltExpr], 
outputExpr=[a, multi_distinct_count(b) as m])
+     *         +--PhysicalDistribute(shuffleColumn=[a, saltExpr])
+     *           +--PhysicalProject(projects=[a, b, xxhash_32(b)%512 as 
saltExpr])
+     *             +--PhysicalHashAggregate(groupByExpr=[a, b], outputExpr=[a, 
b])
+     * */
+    private PhysicalHashAggregate<Plan> 
countDistinctSkewRewrite(LogicalAggregate<GroupPlan> logicalAgg,
+            CascadesContext cascadesContext) {
+        if (!logicalAgg.canSkewRewrite()) {
+            return null;
+        }
+
+        // 1.local agg
+        ImmutableList.Builder<Expression> localAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        localAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        Count count = (Count) 
logicalAgg.getAggregateFunctions().iterator().next();
+        if (!(count.child(0) instanceof Slot)) {
+            return null;
+        }
+        localAggGroupByBuilder.add(count.child(0));
+        List<Expression> localAggGroupBy = localAggGroupByBuilder.build();
+        List<NamedExpression> localAggOutput = 
Utils.fastToImmutableList((List) localAggGroupBy);
+        RequireProperties requireAny = 
RequireProperties.of(PhysicalProperties.ANY);
+        boolean maybeUsingStreamAgg = 
maybeUsingStreamAgg(cascadesContext.getConnectContext(),
+                localAggGroupBy);
+        boolean couldBanned = false;
+        AggregateParam localParam = new AggregateParam(AggPhase.LOCAL, 
AggMode.INPUT_TO_BUFFER, couldBanned);
+        PhysicalHashAggregate<Plan> localAgg = new 
PhysicalHashAggregate<>(localAggGroupBy, localAggOutput,
+                Optional.empty(), localParam, maybeUsingStreamAgg, 
Optional.empty(), null,
+                requireAny, logicalAgg.child());
+        // add shuffle expr in project
+        ImmutableList.Builder<NamedExpression> projections = 
ImmutableList.builderWithExpectedSize(
+                localAgg.getOutputs().size() + 1);
+        projections.addAll(localAgg.getOutputs());
+        Alias modAlias = getShuffleExpr(count, cascadesContext);
+        projections.add(modAlias);
+        PhysicalProject<Plan> physicalProject = new 
PhysicalProject<>(projections.build(), null, localAgg);
+
+        // 2.second phase agg: multi_distinct_count(b) group by a,h
+        ImmutableList.Builder<Expression> secondPhaseAggGroupByBuilder = 
ImmutableList.builderWithExpectedSize(
+                logicalAgg.getGroupByExpressions().size() + 1);
+        
secondPhaseAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+        secondPhaseAggGroupByBuilder.add(modAlias.toSlot());
+        List<Expression> secondPhaseAggGroupBy = 
secondPhaseAggGroupByBuilder.build();
+        ImmutableList.Builder<NamedExpression> secondPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                secondPhaseAggGroupBy.size() + 1);
+        secondPhaseAggOutput.addAll((List) secondPhaseAggGroupBy);
+        Alias aliasTarget = new Alias(new TinyIntLiteral((byte) 0));
+        for (NamedExpression ne : logicalAgg.getOutputExpressions()) {
+            if (ne instanceof Alias) {
+                if (((Alias) ne).child().equals(count)) {
+                    aliasTarget = (Alias) ne;
+                }
+            }
+        }
+        AggregateParam secondParam = new AggregateParam(AggPhase.GLOBAL, 
AggMode.INPUT_TO_RESULT, couldBanned);
+        AggregateFunction multiDistinct = count.convertToMultiDistinct();
+        Alias multiDistinctAlias = new Alias(new 
AggregateExpression(multiDistinct, secondParam));
+        secondPhaseAggOutput.add(multiDistinctAlias);
+        List<ExprId> shuffleIds = new ArrayList<>();
+        for (Expression expr : secondPhaseAggGroupBy) {
+            if (expr instanceof Slot) {
+                shuffleIds.add(((Slot) expr).getExprId());
+            }
+        }
+        RequireProperties secondRequireProperties = RequireProperties.of(
+                PhysicalProperties.createHash(shuffleIds, 
ShuffleType.REQUIRE));
+        PhysicalHashAggregate<Plan> secondPhaseAgg = new 
PhysicalHashAggregate<>(
+                secondPhaseAggGroupBy, secondPhaseAggOutput.build(),
+                Optional.empty(), secondParam, false, Optional.empty(), null,
+                secondRequireProperties, physicalProject);
+
+        // 3. third phase agg
+        List<Expression> thirdPhaseAggGroupBy = 
Utils.fastToImmutableList(logicalAgg.getGroupByExpressions());
+        ImmutableList.Builder<NamedExpression> thirdPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                thirdPhaseAggGroupBy.size() + 1);
+        thirdPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+        AggregateParam thirdParam = new 
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+        Count thirdCount = new Count(multiDistinctAlias.toSlot());
+        Alias thirdCountAlias = new Alias(new AggregateExpression(thirdCount, 
thirdParam));
+        thirdPhaseAggOutput.add(thirdCountAlias);
+        PhysicalHashAggregate<Plan> thirdPhaseAgg = new 
PhysicalHashAggregate<>(
+                thirdPhaseAggGroupBy, thirdPhaseAggOutput.build(),
+                Optional.empty(), thirdParam, false, Optional.empty(), null,
+                secondRequireProperties, secondPhaseAgg);
+
+        // 4. fourth phase agg
+        ImmutableList.Builder<NamedExpression> fourthPhaseAggOutput = 
ImmutableList.builderWithExpectedSize(
+                thirdPhaseAggGroupBy.size() + 1);
+        fourthPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+        AggregateParam fourthParam = new 
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT,
+                couldBanned);
+        Alias sumAliasFour = new Alias(aliasTarget.getExprId(),
+                new AggregateExpression(thirdCount, fourthParam, 
thirdCountAlias.toSlot()),
+                aliasTarget.getName());
+        fourthPhaseAggOutput.add(sumAliasFour);
+        List<ExprId> shuffleIdsFour = new ArrayList<>();
+        for (Expression expr : logicalAgg.getExpressions()) {
+            if (expr instanceof Slot) {
+                shuffleIdsFour.add(((Slot) expr).getExprId());
+            }
+        }
+        RequireProperties fourthRequireProperties = RequireProperties.of(
+                PhysicalProperties.createHash(shuffleIdsFour, 
ShuffleType.REQUIRE));
+        return new PhysicalHashAggregate<>(thirdPhaseAggGroupBy,
+                fourthPhaseAggOutput.build(), Optional.empty(), fourthParam,
+                false, Optional.empty(), logicalAgg.getLogicalProperties(),
+                fourthRequireProperties, thirdPhaseAgg);
+    }
+
+    private Alias getShuffleExpr(Count count, CascadesContext cascadesContext) 
{
+        int bucketNum = 
cascadesContext.getConnectContext().getSessionVariable().aggDistinctSkewBucketNum;
+        DataType type = bucketNum <= 256 ? TinyIntType.INSTANCE : 
SmallIntType.INSTANCE;

Review Comment:
   Should you reject the bucketNum when it is >= 65535 or < 0?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to