924060929 commented on code in PR #49096:
URL: https://github.com/apache/doris/pull/49096#discussion_r2086315730
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -2088,4 +2107,130 @@ private boolean couldConvertToMulti(LogicalAggregate<?
extends Plan> aggregate)
}
return true;
}
+
+ /**
+ * LogicalAggregate(groupByExpr=[a], outputExpr=[a,count(distinct b)])
+ * ->
+ * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a,
count(partial_count(m))]
+ * +--PhysicalDistribute(shuffleColumn=[a])
+ * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a,
partial_count(m)]
+ * +--PhysicalHashAggregate(groupByExpr=[a, saltExpr],
outputExpr=[a, multi_distinct_count(b) as m])
+ * +--PhysicalDistribute(shuffleColumn=[a, saltExpr])
+ * +--PhysicalProject(projects=[a, b, xxhash_32(b)%512 as
saltExpr])
+ * +--PhysicalHashAggregate(groupByExpr=[a, b], outputExpr=[a,
b])
+ * */
+ private PhysicalHashAggregate<Plan>
countDistinctSkewRewrite(LogicalAggregate<GroupPlan> logicalAgg,
+ CascadesContext cascadesContext) {
+ if (!logicalAgg.canSkewRewrite()) {
+ return null;
+ }
+
+ // 1.local agg
+ ImmutableList.Builder<Expression> localAggGroupByBuilder =
ImmutableList.builderWithExpectedSize(
+ logicalAgg.getGroupByExpressions().size() + 1);
+ localAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+ Count count = (Count)
logicalAgg.getAggregateFunctions().iterator().next();
+ if (!(count.child(0) instanceof Slot)) {
+ return null;
+ }
+ localAggGroupByBuilder.add(count.child(0));
+ List<Expression> localAggGroupBy = localAggGroupByBuilder.build();
+ List<NamedExpression> localAggOutput =
Utils.fastToImmutableList((List) localAggGroupBy);
+ RequireProperties requireAny =
RequireProperties.of(PhysicalProperties.ANY);
+ boolean maybeUsingStreamAgg =
maybeUsingStreamAgg(cascadesContext.getConnectContext(),
+ localAggGroupBy);
+ boolean couldBanned = false;
+ AggregateParam localParam = new AggregateParam(AggPhase.LOCAL,
AggMode.INPUT_TO_BUFFER, couldBanned);
+ PhysicalHashAggregate<Plan> localAgg = new
PhysicalHashAggregate<>(localAggGroupBy, localAggOutput,
+ Optional.empty(), localParam, maybeUsingStreamAgg,
Optional.empty(), null,
+ requireAny, logicalAgg.child());
+ // add shuffle expr in project
+ ImmutableList.Builder<NamedExpression> projections =
ImmutableList.builderWithExpectedSize(
+ localAgg.getOutputs().size() + 1);
+ projections.addAll(localAgg.getOutputs());
+ Alias modAlias = getShuffleExpr(count, cascadesContext);
+ projections.add(modAlias);
+ PhysicalProject<Plan> physicalProject = new
PhysicalProject<>(projections.build(), null, localAgg);
+
+ // 2.second phase agg: multi_distinct_count(b) group by a,h
+ ImmutableList.Builder<Expression> secondPhaseAggGroupByBuilder =
ImmutableList.builderWithExpectedSize(
+ logicalAgg.getGroupByExpressions().size() + 1);
+
secondPhaseAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+ secondPhaseAggGroupByBuilder.add(modAlias.toSlot());
+ List<Expression> secondPhaseAggGroupBy =
secondPhaseAggGroupByBuilder.build();
+ ImmutableList.Builder<NamedExpression> secondPhaseAggOutput =
ImmutableList.builderWithExpectedSize(
+ secondPhaseAggGroupBy.size() + 1);
+ secondPhaseAggOutput.addAll((List) secondPhaseAggGroupBy);
+ Alias aliasTarget = new Alias(new TinyIntLiteral((byte) 0));
+ for (NamedExpression ne : logicalAgg.getOutputExpressions()) {
+ if (ne instanceof Alias) {
+ if (((Alias) ne).child().equals(count)) {
+ aliasTarget = (Alias) ne;
+ }
+ }
+ }
+ AggregateParam secondParam = new AggregateParam(AggPhase.GLOBAL,
AggMode.INPUT_TO_RESULT, couldBanned);
+ AggregateFunction multiDistinct = count.convertToMultiDistinct();
+ Alias multiDistinctAlias = new Alias(new
AggregateExpression(multiDistinct, secondParam));
+ secondPhaseAggOutput.add(multiDistinctAlias);
+ List<ExprId> shuffleIds = new ArrayList<>();
+ for (Expression expr : secondPhaseAggGroupBy) {
+ if (expr instanceof Slot) {
+ shuffleIds.add(((Slot) expr).getExprId());
+ }
+ }
+ RequireProperties secondRequireProperties = RequireProperties.of(
+ PhysicalProperties.createHash(shuffleIds,
ShuffleType.REQUIRE));
+ PhysicalHashAggregate<Plan> secondPhaseAgg = new
PhysicalHashAggregate<>(
+ secondPhaseAggGroupBy, secondPhaseAggOutput.build(),
+ Optional.empty(), secondParam, false, Optional.empty(), null,
+ secondRequireProperties, physicalProject);
+
+ // 3. third phase agg
+ List<Expression> thirdPhaseAggGroupBy =
Utils.fastToImmutableList(logicalAgg.getGroupByExpressions());
+ ImmutableList.Builder<NamedExpression> thirdPhaseAggOutput =
ImmutableList.builderWithExpectedSize(
+ thirdPhaseAggGroupBy.size() + 1);
+ thirdPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+ AggregateParam thirdParam = new
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+ Count thirdCount = new Count(multiDistinctAlias.toSlot());
+ Alias thirdCountAlias = new Alias(new AggregateExpression(thirdCount,
thirdParam));
+ thirdPhaseAggOutput.add(thirdCountAlias);
+ PhysicalHashAggregate<Plan> thirdPhaseAgg = new
PhysicalHashAggregate<>(
+ thirdPhaseAggGroupBy, thirdPhaseAggOutput.build(),
+ Optional.empty(), thirdParam, false, Optional.empty(), null,
+ secondRequireProperties, secondPhaseAgg);
+
+ // 4. fourth phase agg
+ ImmutableList.Builder<NamedExpression> fourthPhaseAggOutput =
ImmutableList.builderWithExpectedSize(
+ thirdPhaseAggGroupBy.size() + 1);
+ fourthPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+ AggregateParam fourthParam = new
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT,
+ couldBanned);
+ Alias sumAliasFour = new Alias(aliasTarget.getExprId(),
+ new AggregateExpression(thirdCount, fourthParam,
thirdCountAlias.toSlot()),
+ aliasTarget.getName());
+ fourthPhaseAggOutput.add(sumAliasFour);
+ List<ExprId> shuffleIdsFour = new ArrayList<>();
+ for (Expression expr : logicalAgg.getExpressions()) {
+ if (expr instanceof Slot) {
+ shuffleIdsFour.add(((Slot) expr).getExprId());
+ }
+ }
+ RequireProperties fourthRequireProperties = RequireProperties.of(
+ PhysicalProperties.createHash(shuffleIdsFour,
ShuffleType.REQUIRE));
+ return new PhysicalHashAggregate<>(thirdPhaseAggGroupBy,
+ fourthPhaseAggOutput.build(), Optional.empty(), fourthParam,
+ false, Optional.empty(), logicalAgg.getLogicalProperties(),
+ fourthRequireProperties, thirdPhaseAgg);
+ }
+
+ private Alias getShuffleExpr(Count count, CascadesContext cascadesContext)
{
+ int bucketNum =
cascadesContext.getConnectContext().getSessionVariable().aggDistinctSkewBucketNum;
+ DataType type = bucketNum <= 256 ? TinyIntType.INSTANCE :
SmallIntType.INSTANCE;
+ int bucket = bucketNum / 2;
Review Comment:
What is the meaning of `bucketNum / 2`?
##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/implementation/AggregateStrategies.java:
##########
@@ -2088,4 +2107,130 @@ private boolean couldConvertToMulti(LogicalAggregate<?
extends Plan> aggregate)
}
return true;
}
+
+ /**
+ * LogicalAggregate(groupByExpr=[a], outputExpr=[a,count(distinct b)])
+ * ->
+ * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a,
count(partial_count(m))]
+ * +--PhysicalDistribute(shuffleColumn=[a])
+ * +--PhysicalHashAggregate(groupByExpr=[a], outputExpr=[a,
partial_count(m)]
+ * +--PhysicalHashAggregate(groupByExpr=[a, saltExpr],
outputExpr=[a, multi_distinct_count(b) as m])
+ * +--PhysicalDistribute(shuffleColumn=[a, saltExpr])
+ * +--PhysicalProject(projects=[a, b, xxhash_32(b)%512 as
saltExpr])
+ * +--PhysicalHashAggregate(groupByExpr=[a, b], outputExpr=[a,
b])
+ * */
+ private PhysicalHashAggregate<Plan>
countDistinctSkewRewrite(LogicalAggregate<GroupPlan> logicalAgg,
+ CascadesContext cascadesContext) {
+ if (!logicalAgg.canSkewRewrite()) {
+ return null;
+ }
+
+ // 1.local agg
+ ImmutableList.Builder<Expression> localAggGroupByBuilder =
ImmutableList.builderWithExpectedSize(
+ logicalAgg.getGroupByExpressions().size() + 1);
+ localAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+ Count count = (Count)
logicalAgg.getAggregateFunctions().iterator().next();
+ if (!(count.child(0) instanceof Slot)) {
+ return null;
+ }
+ localAggGroupByBuilder.add(count.child(0));
+ List<Expression> localAggGroupBy = localAggGroupByBuilder.build();
+ List<NamedExpression> localAggOutput =
Utils.fastToImmutableList((List) localAggGroupBy);
+ RequireProperties requireAny =
RequireProperties.of(PhysicalProperties.ANY);
+ boolean maybeUsingStreamAgg =
maybeUsingStreamAgg(cascadesContext.getConnectContext(),
+ localAggGroupBy);
+ boolean couldBanned = false;
+ AggregateParam localParam = new AggregateParam(AggPhase.LOCAL,
AggMode.INPUT_TO_BUFFER, couldBanned);
+ PhysicalHashAggregate<Plan> localAgg = new
PhysicalHashAggregate<>(localAggGroupBy, localAggOutput,
+ Optional.empty(), localParam, maybeUsingStreamAgg,
Optional.empty(), null,
+ requireAny, logicalAgg.child());
+ // add shuffle expr in project
+ ImmutableList.Builder<NamedExpression> projections =
ImmutableList.builderWithExpectedSize(
+ localAgg.getOutputs().size() + 1);
+ projections.addAll(localAgg.getOutputs());
+ Alias modAlias = getShuffleExpr(count, cascadesContext);
+ projections.add(modAlias);
+ PhysicalProject<Plan> physicalProject = new
PhysicalProject<>(projections.build(), null, localAgg);
+
+ // 2.second phase agg: multi_distinct_count(b) group by a,h
+ ImmutableList.Builder<Expression> secondPhaseAggGroupByBuilder =
ImmutableList.builderWithExpectedSize(
+ logicalAgg.getGroupByExpressions().size() + 1);
+
secondPhaseAggGroupByBuilder.addAll(logicalAgg.getGroupByExpressions());
+ secondPhaseAggGroupByBuilder.add(modAlias.toSlot());
+ List<Expression> secondPhaseAggGroupBy =
secondPhaseAggGroupByBuilder.build();
+ ImmutableList.Builder<NamedExpression> secondPhaseAggOutput =
ImmutableList.builderWithExpectedSize(
+ secondPhaseAggGroupBy.size() + 1);
+ secondPhaseAggOutput.addAll((List) secondPhaseAggGroupBy);
+ Alias aliasTarget = new Alias(new TinyIntLiteral((byte) 0));
+ for (NamedExpression ne : logicalAgg.getOutputExpressions()) {
+ if (ne instanceof Alias) {
+ if (((Alias) ne).child().equals(count)) {
+ aliasTarget = (Alias) ne;
+ }
+ }
+ }
+ AggregateParam secondParam = new AggregateParam(AggPhase.GLOBAL,
AggMode.INPUT_TO_RESULT, couldBanned);
+ AggregateFunction multiDistinct = count.convertToMultiDistinct();
+ Alias multiDistinctAlias = new Alias(new
AggregateExpression(multiDistinct, secondParam));
+ secondPhaseAggOutput.add(multiDistinctAlias);
+ List<ExprId> shuffleIds = new ArrayList<>();
+ for (Expression expr : secondPhaseAggGroupBy) {
+ if (expr instanceof Slot) {
+ shuffleIds.add(((Slot) expr).getExprId());
+ }
+ }
+ RequireProperties secondRequireProperties = RequireProperties.of(
+ PhysicalProperties.createHash(shuffleIds,
ShuffleType.REQUIRE));
+ PhysicalHashAggregate<Plan> secondPhaseAgg = new
PhysicalHashAggregate<>(
+ secondPhaseAggGroupBy, secondPhaseAggOutput.build(),
+ Optional.empty(), secondParam, false, Optional.empty(), null,
+ secondRequireProperties, physicalProject);
+
+ // 3. third phase agg
+ List<Expression> thirdPhaseAggGroupBy =
Utils.fastToImmutableList(logicalAgg.getGroupByExpressions());
+ ImmutableList.Builder<NamedExpression> thirdPhaseAggOutput =
ImmutableList.builderWithExpectedSize(
+ thirdPhaseAggGroupBy.size() + 1);
+ thirdPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+ AggregateParam thirdParam = new
AggregateParam(AggPhase.DISTINCT_LOCAL, AggMode.INPUT_TO_BUFFER, couldBanned);
+ Count thirdCount = new Count(multiDistinctAlias.toSlot());
+ Alias thirdCountAlias = new Alias(new AggregateExpression(thirdCount,
thirdParam));
+ thirdPhaseAggOutput.add(thirdCountAlias);
+ PhysicalHashAggregate<Plan> thirdPhaseAgg = new
PhysicalHashAggregate<>(
+ thirdPhaseAggGroupBy, thirdPhaseAggOutput.build(),
+ Optional.empty(), thirdParam, false, Optional.empty(), null,
+ secondRequireProperties, secondPhaseAgg);
+
+ // 4. fourth phase agg
+ ImmutableList.Builder<NamedExpression> fourthPhaseAggOutput =
ImmutableList.builderWithExpectedSize(
+ thirdPhaseAggGroupBy.size() + 1);
+ fourthPhaseAggOutput.addAll((List) thirdPhaseAggGroupBy);
+ AggregateParam fourthParam = new
AggregateParam(AggPhase.DISTINCT_GLOBAL, AggMode.BUFFER_TO_RESULT,
+ couldBanned);
+ Alias sumAliasFour = new Alias(aliasTarget.getExprId(),
+ new AggregateExpression(thirdCount, fourthParam,
thirdCountAlias.toSlot()),
+ aliasTarget.getName());
+ fourthPhaseAggOutput.add(sumAliasFour);
+ List<ExprId> shuffleIdsFour = new ArrayList<>();
+ for (Expression expr : logicalAgg.getExpressions()) {
+ if (expr instanceof Slot) {
+ shuffleIdsFour.add(((Slot) expr).getExprId());
+ }
+ }
+ RequireProperties fourthRequireProperties = RequireProperties.of(
+ PhysicalProperties.createHash(shuffleIdsFour,
ShuffleType.REQUIRE));
+ return new PhysicalHashAggregate<>(thirdPhaseAggGroupBy,
+ fourthPhaseAggOutput.build(), Optional.empty(), fourthParam,
+ false, Optional.empty(), logicalAgg.getLogicalProperties(),
+ fourthRequireProperties, thirdPhaseAgg);
+ }
+
+ private Alias getShuffleExpr(Count count, CascadesContext cascadesContext)
{
+ int bucketNum =
cascadesContext.getConnectContext().getSessionVariable().aggDistinctSkewBucketNum;
+ DataType type = bucketNum <= 256 ? TinyIntType.INSTANCE :
SmallIntType.INSTANCE;
Review Comment:
Should you reject `bucketNum` when it is >= 65535 or < 0?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]