This is an automated email from the ASF dual-hosted git repository.
jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push:
new ceb7b60a64 [fix](Nereids) update immutable LogicalAggregate attribute by mistake (#13740)
ceb7b60a64 is described below
commit ceb7b60a64725b430f2e124a382a05032fa115ff
Author: jakevin <[email protected]>
AuthorDate: Mon Oct 31 14:11:55 2022 +0800
[fix](Nereids) update immutable LogicalAggregate attribute by mistake (#13740)
---
.../apache/doris/nereids/memo/GroupExpression.java | 12 +-
.../java/org/apache/doris/nereids/memo/Memo.java | 13 +-
.../rules/rewrite/AggregateDisassemble.java | 3 +-
.../trees/plans/logical/LogicalAggregate.java | 4 +-
.../org/apache/doris/nereids/memo/MemoTest.java | 6 +-
.../rewrite/logical/AggregateDisassembleTest.java | 282 +++++++++------------
.../doris/nereids/stats/StatsCalculatorTest.java | 12 +-
7 files changed, 146 insertions(+), 186 deletions(-)
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
index 9e95f9e8d7..92067f607d 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java
@@ -27,6 +27,7 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.statistics.StatsDeriveResult;
import com.google.common.base.Preconditions;
+import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
@@ -43,7 +44,7 @@ public class GroupExpression {
private double cost = 0.0;
private CostEstimate costEstimate = null;
private Group ownerGroup;
- private List<Group> children;
+ private ImmutableList<Group> children;
private final Plan plan;
private final BitSet ruleMasks;
private boolean statDerived;
@@ -66,7 +67,7 @@ public class GroupExpression {
public GroupExpression(Plan plan, List<Group> children) {
this.plan = Objects.requireNonNull(plan, "plan can not be null")
.withGroupExpression(Optional.of(this));
- this.children = Lists.newArrayList(Objects.requireNonNull(children,
"children can not be null"));
+ this.children = ImmutableList.copyOf(Objects.requireNonNull(children,
"children can not be null"));
this.ruleMasks = new BitSet(RuleType.SENTINEL.ordinal());
this.statDerived = false;
this.lowestCostTable = Maps.newHashMap();
@@ -84,10 +85,6 @@ public class GroupExpression {
return children.size();
}
- public void addChild(Group child) {
- children.add(child);
- }
-
public Group getOwnerGroup() {
return ownerGroup;
}
@@ -108,12 +105,13 @@ public class GroupExpression {
return children;
}
- public void setChildren(List<Group> children) {
+ public void setChildren(ImmutableList<Group> children) {
this.children = children;
}
/**
* replaceChild.
+ *
* @param originChild origin child group
* @param newChild new child group
*/
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
index e9b9bbfce9..f5616a71c7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
@@ -203,7 +203,7 @@ public class Memo {
/**
* add or replace the plan into the target group.
- *
+ * <p>
* the result truth table:
* <pre>
*
+---------------------------------------+-----------------------------------+--------------------------------+
@@ -296,8 +296,7 @@ public class Memo {
}
}
plan = replaceChildrenToGroupPlan(plan, childrenGroups);
- GroupExpression newGroupExpression = new GroupExpression(plan);
- newGroupExpression.setChildren(childrenGroups);
+ GroupExpression newGroupExpression = new GroupExpression(plan,
childrenGroups);
return insertGroupExpression(newGroupExpression, targetGroup,
plan.getLogicalProperties());
// TODO: need to derive logical property if generate new group.
currently we not copy logical plan into
}
@@ -388,13 +387,15 @@ public class Memo {
}
for (GroupExpression groupExpression : needReplaceChild) {
groupExpressions.remove(groupExpression);
- List<Group> children = groupExpression.children();
+ List<Group> children = new ArrayList<>(groupExpression.children());
// TODO: use a better way to replace child, avoid traversing all
groupExpression
for (int i = 0; i < children.size(); i++) {
if (children.get(i).equals(source)) {
children.set(i, destination);
}
}
+ groupExpression.setChildren(ImmutableList.copyOf(children));
+
GroupExpression that = groupExpressions.get(groupExpression);
if (that != null && that.getOwnerGroup() != null
&&
!that.getOwnerGroup().equals(groupExpression.getOwnerGroup())) {
@@ -487,14 +488,14 @@ public class Memo {
/**
* eliminate fromGroup, clear targetGroup, then move the logical group
expressions in the fromGroup to the toGroup.
- *
+ * <p>
* the scenario is:
* ```
* Group 1(project, the targetGroup) Group
1(logicalOlapScan, the targetGroup)
* | =>
* Group 0(logicalOlapScan, the fromGroup)
* ```
- *
+ * <p>
* we should recycle the group 0, and recycle all group expressions in
group 1, then move the logicalOlapScan to
* the group 1, and reset logical properties of the group 1.
*/
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
index 911b6735ac..3deb794412 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/AggregateDisassemble.java
@@ -34,6 +34,7 @@ import com.google.common.base.Preconditions;
import com.google.common.collect.Lists;
import com.google.common.collect.Maps;
+import java.util.ArrayList;
import java.util.List;
import java.util.Map;
import java.util.Optional;
@@ -147,7 +148,7 @@ public class AggregateDisassemble extends
OneRewriteRuleFactory {
//
+-----------+---------------------+-------------------------+--------------------------------+
// NOTICE: Ref: SlotReference, A: Alias, AF: AggregateFunction, #x:
ExprId x
// 2. collect local aggregate output expressions and local aggregate
group by expression list
- List<Expression> localGroupByExprs = aggregate.getGroupByExpressions();
+ List<Expression> localGroupByExprs = new
ArrayList<>(aggregate.getGroupByExpressions());
List<NamedExpression> localOutputExprs = Lists.newArrayList();
for (Expression originGroupByExpr : originGroupByExprs) {
if (inputSubstitutionMap.containsKey(originGroupByExpr)) {
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
index 3dfd2ab06d..626c74dab9 100644
---
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
+++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalAggregate.java
@@ -119,8 +119,8 @@ public class LogicalAggregate<CHILD_TYPE extends Plan>
extends LogicalUnary<CHIL
Optional<LogicalProperties> logicalProperties,
CHILD_TYPE child) {
super(PlanType.LOGICAL_AGGREGATE, groupExpression, logicalProperties,
child);
- this.groupByExpressions = groupByExpressions;
- this.outputExpressions = outputExpressions;
+ this.groupByExpressions = ImmutableList.copyOf(groupByExpressions);
+ this.outputExpressions = ImmutableList.copyOf(outputExpressions);
this.partitionExpressions = partitionExpressions;
this.disassembled = disassembled;
this.normalized = normalized;
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
index 9ad9bf8d16..f9b04d9758 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java
@@ -71,7 +71,7 @@ class MemoTest implements PatternMatchSupported {
JoinType.INNER_JOIN, logicalJoinAB,
PlanConstructor.newLogicalOlapScan(2, "C", 0));
@Test
- void mergeGroup() throws Exception {
+ void mergeGroup() {
Memo memo = new Memo();
GroupId gid2 = new GroupId(2);
Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()),
new LogicalProperties(ArrayList::new));
@@ -85,13 +85,13 @@ class MemoTest implements PatternMatchSupported {
GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup));
GroupId gid1 = new GroupId(1);
Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new));
- Map<GroupId, Group> groups = (Map<GroupId, Group>)
Deencapsulation.getField(memo, "groups");
+ Map<GroupId, Group> groups = Deencapsulation.getField(memo, "groups");
groups.put(gid2, srcGroup);
groups.put(gid3, dstGroup);
groups.put(gid0, g1);
groups.put(gid1, g2);
Map<GroupExpression, GroupExpression> groupExpressions =
- (Map<GroupExpression, GroupExpression>)
Deencapsulation.getField(memo, "groupExpressions");
+ Deencapsulation.getField(memo, "groupExpressions");
groupExpressions.put(ge1, ge1);
groupExpressions.put(ge2, ge2);
memo.mergeGroup(srcGroup, dstGroup);
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
index 6fa37387b9..d92d6efb4e 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/AggregateDisassembleTest.java
@@ -22,7 +22,6 @@ import org.apache.doris.nereids.trees.expressions.Add;
import org.apache.doris.nereids.trees.expressions.Alias;
import org.apache.doris.nereids.trees.expressions.Expression;
import org.apache.doris.nereids.trees.expressions.NamedExpression;
-import org.apache.doris.nereids.trees.expressions.SlotReference;
import org.apache.doris.nereids.trees.expressions.functions.agg.Count;
import org.apache.doris.nereids.trees.expressions.functions.agg.Sum;
import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral;
@@ -31,14 +30,13 @@ import org.apache.doris.nereids.trees.plans.Plan;
import org.apache.doris.nereids.trees.plans.RelationId;
import org.apache.doris.nereids.trees.plans.logical.LogicalAggregate;
import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan;
-import org.apache.doris.nereids.trees.plans.logical.LogicalUnary;
+import org.apache.doris.nereids.util.MemoTestUtils;
+import org.apache.doris.nereids.util.PatternMatchSupported;
+import org.apache.doris.nereids.util.PlanChecker;
import org.apache.doris.nereids.util.PlanConstructor;
-import org.apache.doris.nereids.util.PlanRewriter;
-import org.apache.doris.qe.ConnectContext;
import com.google.common.collect.ImmutableList;
import com.google.common.collect.Lists;
-import org.junit.jupiter.api.Assertions;
import org.junit.jupiter.api.BeforeAll;
import org.junit.jupiter.api.Test;
import org.junit.jupiter.api.TestInstance;
@@ -46,15 +44,17 @@ import org.junit.jupiter.api.TestInstance;
import java.util.List;
@TestInstance(TestInstance.Lifecycle.PER_CLASS)
-public class AggregateDisassembleTest {
+public class AggregateDisassembleTest implements PatternMatchSupported {
private Plan rStudent;
@BeforeAll
public final void beforeAll() {
- rStudent = new
LogicalOlapScan(RelationId.createGenerator().getNextId(),
PlanConstructor.student, ImmutableList.of(""));
+ rStudent = new
LogicalOlapScan(RelationId.createGenerator().getNextId(),
PlanConstructor.student,
+ ImmutableList.of(""));
}
/**
+ * <pre>
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [age, SUM(id) as sum],
groupByExpr: [age])
* +--childPlan(id, name, age)
@@ -62,6 +62,7 @@ public class AggregateDisassembleTest {
* Aggregate(phase: [GLOBAL], outputExpr: [a, SUM(b) as c], groupByExpr:
[a])
* +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b],
groupByExpr: [age])
* +--childPlan(id, name, age)
+ * </pre>
*/
@Test
public void slotReferenceGroupBy() {
@@ -70,50 +71,43 @@ public class AggregateDisassembleTest {
List<NamedExpression> outputExpressionList = Lists.newArrayList(
rStudent.getOutput().get(2).toSlot(),
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()),
"sum"));
- Plan root = new LogicalAggregate(groupExpressionList,
outputExpressionList, rStudent);
-
- Plan after = rewrite(root);
-
- Assertions.assertTrue(after instanceof LogicalUnary);
- Assertions.assertTrue(after instanceof LogicalAggregate);
- Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
- LogicalAggregate<Plan> global = (LogicalAggregate) after;
- LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
- Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
- Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
Expression localOutput1 = new
Sum(rStudent.getOutput().get(0).toSlot());
Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
- Assertions.assertEquals(2, local.getOutputExpressions().size());
- Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof
SlotReference);
- Assertions.assertEquals(localOutput0,
local.getOutputExpressions().get(0));
- Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof
Alias);
- Assertions.assertEquals(localOutput1,
local.getOutputExpressions().get(1).child(0));
- Assertions.assertEquals(1, local.getGroupByExpressions().size());
- Assertions.assertEquals(localGroupBy,
local.getGroupByExpressions().get(0));
-
- Expression globalOutput0 =
local.getOutputExpressions().get(0).toSlot();
- Expression globalOutput1 = new
Sum(local.getOutputExpressions().get(1).toSlot());
- Expression globalGroupBy =
local.getOutputExpressions().get(0).toSlot();
-
- Assertions.assertEquals(2, global.getOutputExpressions().size());
- Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof
SlotReference);
- Assertions.assertEquals(globalOutput0,
global.getOutputExpressions().get(0));
- Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof
Alias);
- Assertions.assertEquals(globalOutput1,
global.getOutputExpressions().get(1).child(0));
- Assertions.assertEquals(1, global.getGroupByExpressions().size());
- Assertions.assertEquals(globalGroupBy,
global.getGroupByExpressions().get(0));
-
- // check id:
- Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
- global.getOutputExpressions().get(0).getExprId());
- Assertions.assertEquals(outputExpressionList.get(1).getExprId(),
- global.getOutputExpressions().get(1).getExprId());
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new AggregateDisassemble())
+ .printlnTree()
+ .matchesFromRoot(
+ logicalAggregate(
+ logicalAggregate()
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
+ .when(agg ->
agg.getOutputExpressions().size() == 2)
+ .when(agg ->
agg.getOutputExpressions().get(0).equals(localOutput0))
+ .when(agg ->
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+ .when(agg ->
agg.getGroupByExpressions().size() == 1)
+ .when(agg ->
agg.getGroupByExpressions().get(0).equals(localGroupBy))
+ ).when(agg ->
agg.getAggPhase().equals(AggPhase.GLOBAL))
+ .when(agg -> agg.getOutputExpressions().size()
== 2)
+ .when(agg -> agg.getOutputExpressions().get(0)
+
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
+ .when(agg ->
agg.getOutputExpressions().get(1).child(0)
+ .equals(new
Sum(agg.child().getOutputExpressions().get(1).toSlot())))
+ .when(agg ->
agg.getGroupByExpressions().size() == 1)
+ .when(agg -> agg.getGroupByExpressions().get(0)
+
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
+ // check id:
+ .when(agg ->
agg.getOutputExpressions().get(0).getExprId()
+
.equals(outputExpressionList.get(0).getExprId()))
+ .when(agg ->
agg.getOutputExpressions().get(1).getExprId()
+
.equals(outputExpressionList.get(1).getExprId()))
+ );
}
/**
+ * <pre>
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr:
[])
* +--childPlan(id, name, age)
@@ -121,44 +115,41 @@ public class AggregateDisassembleTest {
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as b], groupByExpr: [])
* +--Aggregate(phase: [LOCAL], outputExpr: [SUM(id) as a], groupByExpr:
[])
* +--childPlan(id, name, age)
+ * </pre>
*/
@Test
public void globalAggregate() {
List<Expression> groupExpressionList = Lists.newArrayList();
List<NamedExpression> outputExpressionList = Lists.newArrayList(
- new Alias(new Sum(rStudent.getOutput().get(0).toSlot()),
"sum"));
- Plan root = new LogicalAggregate(groupExpressionList,
outputExpressionList, rStudent);
-
- Plan after = rewrite(root);
-
- Assertions.assertTrue(after instanceof LogicalUnary);
- Assertions.assertTrue(after instanceof LogicalAggregate);
- Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
- LogicalAggregate<Plan> global = (LogicalAggregate) after;
- LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
- Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
- Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+ new Alias(new Sum(rStudent.getOutput().get(0)), "sum"));
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
Expression localOutput0 = new
Sum(rStudent.getOutput().get(0).toSlot());
- Assertions.assertEquals(1, local.getOutputExpressions().size());
- Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof
Alias);
- Assertions.assertEquals(localOutput0,
local.getOutputExpressions().get(0).child(0));
- Assertions.assertEquals(0, local.getGroupByExpressions().size());
-
- Expression globalOutput0 = new
Sum(local.getOutputExpressions().get(0).toSlot());
-
- Assertions.assertEquals(1, global.getOutputExpressions().size());
- Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof
Alias);
- Assertions.assertEquals(globalOutput0,
global.getOutputExpressions().get(0).child(0));
- Assertions.assertEquals(0, global.getGroupByExpressions().size());
-
- // check id:
- Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
- global.getOutputExpressions().get(0).getExprId());
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new AggregateDisassemble())
+ .printlnTree()
+ .matchesFromRoot(
+ logicalAggregate(
+ logicalAggregate()
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
+ .when(agg ->
agg.getOutputExpressions().size() == 1)
+ .when(agg ->
agg.getOutputExpressions().get(0).child(0).equals(localOutput0))
+ .when(agg ->
agg.getGroupByExpressions().size() == 0)
+ ).when(agg ->
agg.getAggPhase().equals(AggPhase.GLOBAL))
+ .when(agg -> agg.getOutputExpressions().size()
== 1)
+ .when(agg -> agg.getOutputExpressions().get(0)
instanceof Alias)
+ .when(agg ->
agg.getOutputExpressions().get(0).child(0)
+ .equals(new
Sum(agg.child().getOutputExpressions().get(0).toSlot())))
+ .when(agg ->
agg.getGroupByExpressions().size() == 0)
+ // check id:
+ .when(agg ->
agg.getOutputExpressions().get(0).getExprId()
+
.equals(outputExpressionList.get(0).getExprId()))
+ );
}
/**
+ * <pre>
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(id) as sum], groupByExpr:
[age])
* +--childPlan(id, name, age)
@@ -166,6 +157,7 @@ public class AggregateDisassembleTest {
* Aggregate(phase: [GLOBAL], outputExpr: [SUM(b) as c], groupByExpr:
[a])
* +--Aggregate(phase: [LOCAL], outputExpr: [age as a, SUM(id) as b],
groupByExpr: [age])
* +--childPlan(id, name, age)
+ * </pre>
*/
@Test
public void groupExpressionNotInOutput() {
@@ -173,45 +165,40 @@ public class AggregateDisassembleTest {
rStudent.getOutput().get(2).toSlot());
List<NamedExpression> outputExpressionList = Lists.newArrayList(
new Alias(new Sum(rStudent.getOutput().get(0).toSlot()),
"sum"));
- Plan root = new LogicalAggregate(groupExpressionList,
outputExpressionList, rStudent);
-
- Plan after = rewrite(root);
-
- Assertions.assertTrue(after instanceof LogicalUnary);
- Assertions.assertTrue(after instanceof LogicalAggregate);
- Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
- LogicalAggregate<Plan> global = (LogicalAggregate) after;
- LogicalAggregate<Plan> local = (LogicalAggregate) after.child(0);
- Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
- Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
+ Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
Expression localOutput0 = rStudent.getOutput().get(2).toSlot();
Expression localOutput1 = new
Sum(rStudent.getOutput().get(0).toSlot());
Expression localGroupBy = rStudent.getOutput().get(2).toSlot();
- Assertions.assertEquals(2, local.getOutputExpressions().size());
- Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof
SlotReference);
- Assertions.assertEquals(localOutput0,
local.getOutputExpressions().get(0));
- Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof
Alias);
- Assertions.assertEquals(localOutput1,
local.getOutputExpressions().get(1).child(0));
- Assertions.assertEquals(1, local.getGroupByExpressions().size());
- Assertions.assertEquals(localGroupBy,
local.getGroupByExpressions().get(0));
-
- Expression globalOutput0 = new
Sum(local.getOutputExpressions().get(1).toSlot());
- Expression globalGroupBy =
local.getOutputExpressions().get(0).toSlot();
-
- Assertions.assertEquals(1, global.getOutputExpressions().size());
- Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof
Alias);
- Assertions.assertEquals(globalOutput0,
global.getOutputExpressions().get(0).child(0));
- Assertions.assertEquals(1, global.getGroupByExpressions().size());
- Assertions.assertEquals(globalGroupBy,
global.getGroupByExpressions().get(0));
-
- // check id:
- Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
- global.getOutputExpressions().get(0).getExprId());
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new AggregateDisassemble())
+ .printlnTree()
+ .matchesFromRoot(
+ logicalAggregate(
+ logicalAggregate()
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
+ .when(agg ->
agg.getOutputExpressions().size() == 2)
+ .when(agg ->
agg.getOutputExpressions().get(0).equals(localOutput0))
+ .when(agg ->
agg.getOutputExpressions().get(1).child(0).equals(localOutput1))
+ .when(agg ->
agg.getGroupByExpressions().size() == 1)
+ .when(agg ->
agg.getGroupByExpressions().get(0).equals(localGroupBy))
+ ).when(agg ->
agg.getAggPhase().equals(AggPhase.GLOBAL))
+ .when(agg -> agg.getOutputExpressions().size()
== 1)
+ .when(agg -> agg.getOutputExpressions().get(0)
instanceof Alias)
+ .when(agg ->
agg.getOutputExpressions().get(0).child(0)
+ .equals(new
Sum(agg.child().getOutputExpressions().get(1).toSlot())))
+ .when(agg ->
agg.getGroupByExpressions().size() == 1)
+ .when(agg -> agg.getGroupByExpressions().get(0)
+
.equals(agg.child().getOutputExpressions().get(0).toSlot()))
+ // check id:
+ .when(agg ->
agg.getOutputExpressions().get(0).getExprId()
+
.equals(outputExpressionList.get(0).getExprId()))
+ );
}
/**
+ * <pre>
* the initial plan is:
* Aggregate(phase: [GLOBAL], outputExpr: [(COUNT(distinct age) + 2) as
c], groupByExpr: [id])
* +-- childPlan(id, name, age)
@@ -220,6 +207,7 @@ public class AggregateDisassembleTest {
* +-- Aggregate(phase: [GLOBAL], outputExpr: [id, age], groupByExpr:
[id, age])
* +-- Aggregate(phase: [LOCAL], outputExpr: [id, age], groupByExpr:
[id, age])
* +-- childPlan(id, name, age)
+ * </pre>
*/
@Test
public void distinctAggregateWithGroupBy() {
@@ -229,68 +217,44 @@ public class AggregateDisassembleTest {
new IntegerLiteral(2)), "c"));
Plan root = new LogicalAggregate<>(groupExpressionList,
outputExpressionList, rStudent);
- Plan after = rewrite(root);
-
- Assertions.assertTrue(after instanceof LogicalUnary);
- Assertions.assertTrue(after instanceof LogicalAggregate);
- Assertions.assertTrue(after.child(0) instanceof LogicalUnary);
- LogicalAggregate<Plan> distinctLocal = (LogicalAggregate) after;
- LogicalAggregate<Plan> global = (LogicalAggregate) after.child(0);
- LogicalAggregate<Plan> local = (LogicalAggregate)
after.child(0).child(0);
- Assertions.assertEquals(AggPhase.DISTINCT_LOCAL,
distinctLocal.getAggPhase());
- Assertions.assertEquals(AggPhase.GLOBAL, global.getAggPhase());
- Assertions.assertEquals(AggPhase.LOCAL, local.getAggPhase());
// check local:
// id
- Expression localOutput0 = rStudent.getOutput().get(0).toSlot();
+ Expression localOutput0 = rStudent.getOutput().get(0);
// age
- Expression localOutput1 = rStudent.getOutput().get(2).toSlot();
+ Expression localOutput1 = rStudent.getOutput().get(2);
// id
- Expression localGroupBy0 = rStudent.getOutput().get(0).toSlot();
+ Expression localGroupBy0 = rStudent.getOutput().get(0);
// age
- Expression localGroupBy1 = rStudent.getOutput().get(2).toSlot();
-
- Assertions.assertEquals(2, local.getOutputExpressions().size());
- Assertions.assertTrue(local.getOutputExpressions().get(0) instanceof
SlotReference);
- Assertions.assertEquals(localOutput0,
local.getOutputExpressions().get(0));
- Assertions.assertTrue(local.getOutputExpressions().get(1) instanceof
SlotReference);
- Assertions.assertEquals(localOutput1,
local.getOutputExpressions().get(1));
- Assertions.assertEquals(2, local.getGroupByExpressions().size());
- Assertions.assertEquals(localGroupBy0,
local.getGroupByExpressions().get(0));
- Assertions.assertEquals(localGroupBy1,
local.getGroupByExpressions().get(1));
-
- // check global:
- Expression globalOutput0 =
local.getOutputExpressions().get(0).toSlot();
- Expression globalOutput1 =
local.getOutputExpressions().get(1).toSlot();
- Expression globalGroupBy0 =
local.getOutputExpressions().get(0).toSlot();
- Expression globalGroupBy1 =
local.getOutputExpressions().get(1).toSlot();
-
- Assertions.assertEquals(2, global.getOutputExpressions().size());
- Assertions.assertTrue(global.getOutputExpressions().get(0) instanceof
SlotReference);
- Assertions.assertEquals(globalOutput0,
global.getOutputExpressions().get(0));
- Assertions.assertTrue(global.getOutputExpressions().get(1) instanceof
SlotReference);
- Assertions.assertEquals(globalOutput1,
global.getOutputExpressions().get(1));
- Assertions.assertEquals(2, global.getGroupByExpressions().size());
- Assertions.assertEquals(globalGroupBy0,
global.getGroupByExpressions().get(0));
- Assertions.assertEquals(globalGroupBy1,
global.getGroupByExpressions().get(1));
-
- // check distinct local:
- Expression distinctLocalOutput = new Add(new
Count(local.getOutputExpressions().get(1).toSlot(), true),
- new IntegerLiteral(2));
- Expression distinctLocalGroupBy =
local.getOutputExpressions().get(0).toSlot();
-
- Assertions.assertEquals(1,
distinctLocal.getOutputExpressions().size());
- Assertions.assertTrue(distinctLocal.getOutputExpressions().get(0)
instanceof Alias);
- Assertions.assertEquals(distinctLocalOutput,
distinctLocal.getOutputExpressions().get(0).child(0));
- Assertions.assertEquals(1,
distinctLocal.getGroupByExpressions().size());
- Assertions.assertEquals(distinctLocalGroupBy,
distinctLocal.getGroupByExpressions().get(0));
-
- // check id:
- Assertions.assertEquals(outputExpressionList.get(0).getExprId(),
- distinctLocal.getOutputExpressions().get(0).getExprId());
- }
-
- private Plan rewrite(Plan input) {
- return PlanRewriter.topDownRewrite(input, new ConnectContext(), new
AggregateDisassemble());
+ Expression localGroupBy1 = rStudent.getOutput().get(2);
+
+ PlanChecker.from(MemoTestUtils.createConnectContext(), root)
+ .applyTopDown(new AggregateDisassemble())
+ .matchesFromRoot(
+ logicalAggregate(
+ logicalAggregate(
+ logicalAggregate()
+ .when(agg ->
agg.getAggPhase().equals(AggPhase.LOCAL))
+ .when(agg ->
agg.getOutputExpressions().get(0).equals(localOutput0))
+ .when(agg ->
agg.getOutputExpressions().get(1).equals(localOutput1))
+ .when(agg ->
agg.getGroupByExpressions().get(0).equals(localGroupBy0))
+ .when(agg ->
agg.getGroupByExpressions().get(1).equals(localGroupBy1))
+ ).when(agg ->
agg.getAggPhase().equals(AggPhase.GLOBAL))
+ .when(agg ->
agg.getOutputExpressions().get(0)
+
.equals(agg.child().getOutputExpressions().get(0)))
+ .when(agg ->
agg.getOutputExpressions().get(1)
+
.equals(agg.child().getOutputExpressions().get(1)))
+ .when(agg ->
agg.getGroupByExpressions().get(0)
+
.equals(agg.child().getOutputExpressions().get(0)))
+ .when(agg ->
agg.getGroupByExpressions().get(1)
+
.equals(agg.child().getOutputExpressions().get(1)))
+ ).when(agg ->
agg.getAggPhase().equals(AggPhase.DISTINCT_LOCAL))
+ .when(agg -> agg.getOutputExpressions().size()
== 1)
+ .when(agg -> agg.getOutputExpressions().get(0)
instanceof Alias)
+ .when(agg ->
agg.getOutputExpressions().get(0).child(0) instanceof Add)
+ .when(agg -> agg.getGroupByExpressions().get(0)
+
.equals(agg.child().child().getOutputExpressions().get(0)))
+ .when(agg ->
agg.getOutputExpressions().get(0).getExprId() == outputExpressionList.get(
+ 0).getExprId())
+ );
}
}
diff --git
a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java
index b9fa35a335..3b8e2be8e1 100644
---
a/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java
+++
b/fe/fe-core/src/test/java/org/apache/doris/nereids/stats/StatsCalculatorTest.java
@@ -130,16 +130,14 @@ public class StatsCalculatorTest {
childGroup.setStatistics(childStats);
LogicalFilter<GroupPlan> logicalFilter = new LogicalFilter<>(and,
groupPlan);
- GroupExpression groupExpression = new GroupExpression(logicalFilter);
- groupExpression.addChild(childGroup);
+ GroupExpression groupExpression = new GroupExpression(logicalFilter,
ImmutableList.of(childGroup));
Group ownerGroup = new Group();
groupExpression.setOwnerGroup(ownerGroup);
StatsCalculator.estimate(groupExpression);
Assertions.assertEquals((long) (10000 * 0.1 * 0.05),
ownerGroup.getStatistics().getRowCount(), 0.001);
LogicalFilter<GroupPlan> logicalFilterOr = new LogicalFilter<>(or,
groupPlan);
- GroupExpression groupExpressionOr = new
GroupExpression(logicalFilterOr);
- groupExpressionOr.addChild(childGroup);
+ GroupExpression groupExpressionOr = new
GroupExpression(logicalFilterOr, ImmutableList.of(childGroup));
Group ownerGroupOr = new Group();
groupExpressionOr.setOwnerGroup(ownerGroupOr);
StatsCalculator.estimate(groupExpressionOr);
@@ -243,8 +241,7 @@ public class StatsCalculatorTest {
childGroup.setStatistics(childStats);
LogicalLimit<GroupPlan> logicalLimit = new LogicalLimit<>(1, 2,
groupPlan);
- GroupExpression groupExpression = new GroupExpression(logicalLimit);
- groupExpression.addChild(childGroup);
+ GroupExpression groupExpression = new GroupExpression(logicalLimit,
ImmutableList.of(childGroup));
Group ownerGroup = new Group();
ownerGroup.addGroupExpression(groupExpression);
StatsCalculator.estimate(groupExpression);
@@ -274,8 +271,7 @@ public class StatsCalculatorTest {
childGroup.setStatistics(childStats);
LogicalTopN<GroupPlan> logicalTopN = new
LogicalTopN<>(Collections.emptyList(), 1, 2, groupPlan);
- GroupExpression groupExpression = new GroupExpression(logicalTopN);
- groupExpression.addChild(childGroup);
+ GroupExpression groupExpression = new GroupExpression(logicalTopN,
ImmutableList.of(childGroup));
Group ownerGroup = new Group();
ownerGroup.addGroupExpression(groupExpression);
StatsCalculator.estimate(groupExpression);
---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscribe@doris.apache.org
For additional commands, e-mail: commits-help@doris.apache.org