This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 4ae777bfc5 [fix](Nereids) NPE caused by GroupExpression has null owner group when choosing best plan (#13252) 4ae777bfc5 is described below commit 4ae777bfc55a8b0fa1721011a088f8d9392d9327 Author: Kikyou1997 <33112463+kikyou1...@users.noreply.github.com> AuthorDate: Thu Oct 20 22:23:36 2022 +0800 [fix](Nereids) NPE caused by GroupExpression has null owner group when choosing best plan (#13252) --- .../org/apache/doris/nereids/NereidsPlanner.java | 42 +-- .../nereids/jobs/cascades/DeriveStatsJob.java | 5 +- .../java/org/apache/doris/nereids/memo/Group.java | 3 +- .../apache/doris/nereids/memo/GroupExpression.java | 11 +- .../java/org/apache/doris/nereids/memo/Memo.java | 13 +- .../apache/doris/nereids/trees/plans/FakePlan.java | 104 +++++++ .../apache/doris/nereids/memo/MemoCopyInTest.java | 84 ------ .../apache/doris/nereids/memo/MemoInitTest.java | 182 ------------ .../memo/{MemoRewriteTest.java => MemoTest.java} | 307 ++++++++++++++++++--- .../suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy | 107 +++++++ 10 files changed, 520 insertions(+), 338 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java index ff1ec80890..59cb499900 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java @@ -190,25 +190,31 @@ public class NereidsPlanner extends Planner { private PhysicalPlan chooseBestPlan(Group rootGroup, PhysicalProperties physicalProperties) throws AnalysisException { - GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow( - () -> new AnalysisException("lowestCostPlans with physicalProperties doesn't exist")).second; - List<PhysicalProperties> inputPropertiesList = groupExpression.getInputPropertiesList(physicalProperties); - - List<Plan> 
planChildren = Lists.newArrayList(); - for (int i = 0; i < groupExpression.arity(); i++) { - planChildren.add(chooseBestPlan(groupExpression.child(i), inputPropertiesList.get(i))); - } - - Plan plan = groupExpression.getPlan().withChildren(planChildren); - if (!(plan instanceof PhysicalPlan)) { - throw new AnalysisException("Result plan must be PhysicalPlan"); + try { + GroupExpression groupExpression = rootGroup.getLowestCostPlan(physicalProperties).orElseThrow( + () -> new AnalysisException("lowestCostPlans with physicalProperties doesn't exist")).second; + List<PhysicalProperties> inputPropertiesList = groupExpression.getInputPropertiesList(physicalProperties); + + List<Plan> planChildren = Lists.newArrayList(); + for (int i = 0; i < groupExpression.arity(); i++) { + planChildren.add(chooseBestPlan(groupExpression.child(i), inputPropertiesList.get(i))); + } + + Plan plan = groupExpression.getPlan().withChildren(planChildren); + if (!(plan instanceof PhysicalPlan)) { + throw new AnalysisException("Result plan must be PhysicalPlan"); + } + + // TODO: set (logical and physical)properties/statistics/... for physicalPlan. + PhysicalPlan physicalPlan = ((PhysicalPlan) plan).withPhysicalPropertiesAndStats( + groupExpression.getOutputProperties(physicalProperties), + groupExpression.getOwnerGroup().getStatistics()); + return physicalPlan; + } catch (Exception e) { + String memo = cascadesContext.getMemo().toString(); + LOG.warn("Failed to choose best plan, memo structure:{}", memo, e); + throw new AnalysisException("Failed to choose best plan", e); } - - // TODO: set (logical and physical)properties/statistics/... for physicalPlan. 
- PhysicalPlan physicalPlan = ((PhysicalPlan) plan).withPhysicalPropertiesAndStats( - groupExpression.getOutputProperties(physicalProperties), - groupExpression.getOwnerGroup().getStatistics()); - return physicalPlan; } @Override diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java index 618382438a..8797327311 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/cascades/DeriveStatsJob.java @@ -60,9 +60,8 @@ public class DeriveStatsJob extends Job { deriveChildren = true; pushJob(new DeriveStatsJob(this)); for (Group child : groupExpression.children()) { - GroupExpression childGroupExpr = child.getLogicalExpressions().get(0); - if (!child.getLogicalExpressions().isEmpty() && !childGroupExpr.isStatDerived()) { - pushJob(new DeriveStatsJob(childGroupExpr, context)); + if (!child.getLogicalExpressions().isEmpty()) { + pushJob(new DeriveStatsJob(child.getLogicalExpressions().get(0), context)); } } } else { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java index 6a65f4a373..fbb0a0d32b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Group.java @@ -429,9 +429,8 @@ public class Group { lowestCostPlans.forEach((physicalProperties, costAndGroupExpr) -> { GroupExpression bestGroupExpression = costAndGroupExpr.second; // change into target group. 
- if (bestGroupExpression.getOwnerGroup() == this) { + if (bestGroupExpression.getOwnerGroup() == this || bestGroupExpression.getOwnerGroup() == null) { bestGroupExpression.setOwnerGroup(target); - bestGroupExpression.children().set(0, target); } if (!target.lowestCostPlans.containsKey(physicalProperties)) { target.lowestCostPlans.put(physicalProperties, costAndGroupExpr); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java index a6f555b649..e8de9a6e54 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/GroupExpression.java @@ -229,11 +229,20 @@ public class GroupExpression { public String toString() { StringBuilder builder = new StringBuilder(); builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan).append(") children=["); + if (ownerGroup == null) { + builder.append("OWNER GROUP IS NULL[]"); + } else { + builder.append(ownerGroup.getGroupId()).append("(plan=").append(plan.toString()).append(") children=["); + } for (Group group : children) { builder.append(group.getGroupId()).append(" "); } builder.append("] stats="); - builder.append(ownerGroup.getStatistics()); + if (ownerGroup != null) { + builder.append(ownerGroup.getStatistics()); + } else { + builder.append("NULL"); + } return builder.toString(); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java index a484584a49..abafce3880 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java @@ -47,6 +47,11 @@ public class Memo { private final Map<GroupExpression, GroupExpression> groupExpressions = Maps.newHashMap(); private final Group root; + // FOR TEST ONLY + public Memo() { + root = null; + } + public 
Memo(Plan plan) { root = init(plan); } @@ -325,20 +330,20 @@ public class Memo { * @param destination destination group * @return merged group */ - private Group mergeGroup(Group source, Group destination) { + public Group mergeGroup(Group source, Group destination) { if (source.equals(destination)) { return source; } List<GroupExpression> needReplaceChild = Lists.newArrayList(); - groupExpressions.values().forEach(groupExpression -> { + for (GroupExpression groupExpression : groupExpressions.values()) { if (groupExpression.children().contains(source)) { if (groupExpression.getOwnerGroup().equals(destination)) { // cycle, we should not merge - return; + return null; } needReplaceChild.add(groupExpression); } - }); + } for (GroupExpression groupExpression : needReplaceChild) { groupExpressions.remove(groupExpression); List<Group> children = groupExpression.children(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java new file mode 100644 index 0000000000..2058ad37d2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/FakePlan.java @@ -0,0 +1,104 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.trees.plans; + +import org.apache.doris.nereids.memo.GroupExpression; +import org.apache.doris.nereids.properties.LogicalProperties; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.plans.visitor.PlanVisitor; + +import java.util.ArrayList; +import java.util.List; +import java.util.Optional; + +/** + * Used for unit test only. + */ +public class FakePlan implements Plan { + + @Override + public List<Plan> children() { + return null; + } + + @Override + public Plan child(int index) { + return null; + } + + @Override + public int arity() { + return 0; + } + + @Override + public Plan withChildren(List<Plan> children) { + return null; + } + + @Override + public PlanType getType() { + return null; + } + + @Override + public Optional<GroupExpression> getGroupExpression() { + return Optional.empty(); + } + + @Override + public <R, C> R accept(PlanVisitor<R, C> visitor, C context) { + return null; + } + + @Override + public List<? 
extends Expression> getExpressions() { + return new ArrayList<>(); + } + + @Override + public LogicalProperties getLogicalProperties() { + return new LogicalProperties(ArrayList::new); + } + + @Override + public boolean canBind() { + return false; + } + + @Override + public List<Slot> getOutput() { + return new ArrayList<>(); + } + + @Override + public String treeString() { + return "DUMMY"; + } + + @Override + public Plan withGroupExpression(Optional<GroupExpression> groupExpression) { + return this; + } + + @Override + public Plan withLogicalProperties(Optional<LogicalProperties> logicalProperties) { + return this; + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java deleted file mode 100644 index 5cf28d4528..0000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoCopyInTest.java +++ /dev/null @@ -1,84 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package org.apache.doris.nereids.memo; - -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PatternMatchSupported; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.nereids.util.PlanConstructor; - -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.Test; - -public class MemoCopyInTest implements PatternMatchSupported { - LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN, - PlanConstructor.newLogicalOlapScan(0, "A", 0), PlanConstructor.newLogicalOlapScan(1, "B", 0)); - LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>( - JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0)); - - /** - * Original: - * Group 0: LogicalOlapScan C - * Group 1: LogicalOlapScan B - * Group 2: LogicalOlapScan A - * Group 3: Join(Group 1, Group 2) - * Group 4: Join(Group 0, Group 3) - * <p> - * Then: - * Copy In Join(Group 2, Group 1) into Group 3 - * <p> - * Expected: - * Group 0: LogicalOlapScan C - * Group 1: LogicalOlapScan B - * Group 2: LogicalOlapScan A - * Group 3: Join(Group 1, Group 2), Join(Group 2, Group 1) - * Group 4: Join(Group 0, Group 3) - */ - @Test - public void testInsertSameGroup() { - PlanChecker.from(MemoTestUtils.createConnectContext(), logicalJoinABC) - .transform( - // swap join's children - logicalJoin(logicalOlapScan(), logicalOlapScan()).then(joinBA -> - new LogicalProject<>(Lists.newArrayList(joinBA.getOutput()), - new LogicalJoin<>(JoinType.INNER_JOIN, joinBA.right(), joinBA.left())) - )) - .checkGroupNum(6) - .checkGroupExpressionNum(7) - 
.checkMemo(memo -> { - Group root = memo.getRoot(); - Assertions.assertEquals(1, root.getLogicalExpressions().size()); - GroupExpression joinABC = root.getLogicalExpression(); - Assertions.assertEquals(2, joinABC.child(0).getLogicalExpressions().size()); - Assertions.assertEquals(1, joinABC.child(1).getLogicalExpressions().size()); - GroupExpression joinAB = joinABC.child(0).getLogicalExpressions().get(0); - GroupExpression project = joinABC.child(0).getLogicalExpressions().get(1); - GroupExpression joinBA = project.child(0).getLogicalExpression(); - Assertions.assertTrue(joinAB.getPlan() instanceof LogicalJoin); - Assertions.assertTrue(joinBA.getPlan() instanceof LogicalJoin); - }); - - } - - // TODO: test mergeGroup(). -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java deleted file mode 100644 index 81f13afc80..0000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoInitTest.java +++ /dev/null @@ -1,182 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. 
- -package org.apache.doris.nereids.memo; - -import org.apache.doris.catalog.OlapTable; -import org.apache.doris.common.IdGenerator; -import org.apache.doris.nereids.analyzer.UnboundRelation; -import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.RelationId; -import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PatternMatchSupported; -import org.apache.doris.nereids.util.PlanChecker; -import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.qe.ConnectContext; - -import com.google.common.collect.ImmutableList; -import org.junit.jupiter.api.Test; - -import java.util.Objects; - -public class MemoInitTest implements PatternMatchSupported { - private ConnectContext connectContext = MemoTestUtils.createConnectContext(); - - @Test - public void initByOneLevelPlan() { - OlapTable table = PlanConstructor.newOlapTable(0, "a", 1); - LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table); - - PlanChecker.from(connectContext, scan) - .checkGroupNum(1) - .matches( - logicalOlapScan().when(scan::equals) - ); - } - - @Test - public void initByTwoLevelChainPlan() { - OlapTable table = PlanConstructor.newOlapTable(0, "a", 1); - LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table); - - LogicalProject<LogicalOlapScan> topProject = new LogicalProject<>( - ImmutableList.of(scan.computeOutput().get(0)), scan); - - PlanChecker.from(connectContext, topProject) - .checkGroupNum(2) - .matches( - logicalProject( - 
any().when(child -> Objects.equals(child, scan)) - ).when(root -> Objects.equals(root, topProject)) - ); - } - - @Test - public void initByJoinSameUnboundTable() { - UnboundRelation scanA = new UnboundRelation(ImmutableList.of("a")); - - LogicalJoin<UnboundRelation, UnboundRelation> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA); - - PlanChecker.from(connectContext, topJoin) - .checkGroupNum(3) - .matches( - logicalJoin( - any().when(left -> Objects.equals(left, scanA)), - any().when(right -> Objects.equals(right, scanA)) - ).when(root -> Objects.equals(root, topJoin)) - ); - } - - @Test - public void initByJoinSameLogicalTable() { - IdGenerator<RelationId> generator = RelationId.createGenerator(); - OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1); - LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA); - LogicalOlapScan scanA1 = new LogicalOlapScan(generator.getNextId(), tableA); - - LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA1); - - PlanChecker.from(connectContext, topJoin) - .checkGroupNum(3) - .matches( - logicalJoin( - any().when(left -> Objects.equals(left, scanA)), - any().when(right -> Objects.equals(right, scanA1)) - ).when(root -> Objects.equals(root, topJoin)) - ); - } - - @Test - public void initByTwoLevelJoinPlan() { - IdGenerator<RelationId> generator = RelationId.createGenerator(); - OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1); - OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1); - LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA); - LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB); - - LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanB); - - PlanChecker.from(connectContext, topJoin) - .checkGroupNum(3) - .matches( - logicalJoin( - any().when(left -> Objects.equals(left, scanA)), - any().when(right -> 
Objects.equals(right, scanB)) - ).when(root -> Objects.equals(root, topJoin)) - ); - } - - @Test - public void initByThreeLevelChainPlan() { - OlapTable table = PlanConstructor.newOlapTable(0, "a", 1); - LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table); - - LogicalProject<LogicalOlapScan> project = new LogicalProject<>( - ImmutableList.of(scan.computeOutput().get(0)), scan); - LogicalFilter<LogicalProject<LogicalOlapScan>> filter = new LogicalFilter<>( - new EqualTo(scan.computeOutput().get(0), new IntegerLiteral(1)), project); - - PlanChecker.from(connectContext, filter) - .checkGroupNum(3) - .matches( - logicalFilter( - logicalProject( - any().when(child -> Objects.equals(child, scan)) - ).when(root -> Objects.equals(root, project)) - ).when(root -> Objects.equals(root, filter)) - ); - } - - @Test - public void initByThreeLevelBushyPlan() { - IdGenerator<RelationId> generator = RelationId.createGenerator(); - OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1); - OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1); - OlapTable tableC = PlanConstructor.newOlapTable(0, "c", 1); - OlapTable tableD = PlanConstructor.newOlapTable(0, "d", 1); - LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA); - LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB); - LogicalOlapScan scanC = new LogicalOlapScan(generator.getNextId(), tableC); - LogicalOlapScan scanD = new LogicalOlapScan(generator.getNextId(), tableD); - - LogicalJoin<LogicalOlapScan, LogicalOlapScan> leftJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanA, scanB); - LogicalJoin<LogicalOlapScan, LogicalOlapScan> rightJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanC, scanD); - LogicalJoin topJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, leftJoin, rightJoin); - - PlanChecker.from(connectContext, topJoin) - .checkGroupNum(7) - .matches( - logicalJoin( - logicalJoin( - any().when(child -> Objects.equals(child, 
scanA)), - any().when(child -> Objects.equals(child, scanB)) - ).when(left -> Objects.equals(left, leftJoin)), - - logicalJoin( - any().when(child -> Objects.equals(child, scanC)), - any().when(child -> Objects.equals(child, scanD)) - ).when(right -> Objects.equals(right, rightJoin)) - ).when(root -> Objects.equals(root, topJoin)) - ); - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java similarity index 74% rename from fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java rename to fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java index 7a82c5dcd2..9ad9bf8d16 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/memo/MemoTest.java @@ -17,17 +17,24 @@ package org.apache.doris.nereids.memo; +import org.apache.doris.catalog.OlapTable; +import org.apache.doris.common.IdGenerator; +import org.apache.doris.common.jmockit.Deencapsulation; import org.apache.doris.nereids.analyzer.UnboundRelation; import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.properties.LogicalProperties; +import org.apache.doris.nereids.properties.PhysicalProperties; import org.apache.doris.nereids.properties.UnboundLogicalProperties; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLiteral; +import org.apache.doris.nereids.trees.plans.FakePlan; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.LeafPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.RelationId; +import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import 
org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalLimit; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; @@ -45,13 +52,235 @@ import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; +import java.util.ArrayList; +import java.util.Arrays; import java.util.List; +import java.util.Map; import java.util.Objects; import java.util.Optional; -public class MemoRewriteTest implements PatternMatchSupported { +class MemoTest implements PatternMatchSupported { + private ConnectContext connectContext = MemoTestUtils.createConnectContext(); + private LogicalJoin<LogicalOlapScan, LogicalOlapScan> logicalJoinAB = new LogicalJoin<>(JoinType.INNER_JOIN, + PlanConstructor.newLogicalOlapScan(0, "A", 0), + PlanConstructor.newLogicalOlapScan(1, "B", 0)); + + private LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> logicalJoinABC = new LogicalJoin<>( + JoinType.INNER_JOIN, logicalJoinAB, PlanConstructor.newLogicalOlapScan(2, "C", 0)); + + @Test + void mergeGroup() throws Exception { + Memo memo = new Memo(); + GroupId gid2 = new GroupId(2); + Group srcGroup = new Group(gid2, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new)); + GroupId gid3 = new GroupId(3); + Group dstGroup = new Group(gid3, new GroupExpression(new FakePlan()), new LogicalProperties(ArrayList::new)); + FakePlan d = new FakePlan(); + GroupExpression ge1 = new GroupExpression(d, Arrays.asList(srcGroup)); + GroupId gid0 = new GroupId(0); + Group g1 = new Group(gid0, ge1, new LogicalProperties(ArrayList::new)); + g1.setBestPlan(ge1, Double.MIN_VALUE, PhysicalProperties.ANY); + GroupExpression ge2 = new GroupExpression(d, Arrays.asList(dstGroup)); + GroupId gid1 = new GroupId(1); + Group g2 = new Group(gid1, ge2, new LogicalProperties(ArrayList::new)); + Map<GroupId, Group> groups = (Map<GroupId, Group>) 
Deencapsulation.getField(memo, "groups"); + groups.put(gid2, srcGroup); + groups.put(gid3, dstGroup); + groups.put(gid0, g1); + groups.put(gid1, g2); + Map<GroupExpression, GroupExpression> groupExpressions = + (Map<GroupExpression, GroupExpression>) Deencapsulation.getField(memo, "groupExpressions"); + groupExpressions.put(ge1, ge1); + groupExpressions.put(ge2, ge2); + memo.mergeGroup(srcGroup, dstGroup); + Assertions.assertNull(g1.getBestPlan(PhysicalProperties.ANY)); + Assertions.assertEquals(ge1.getOwnerGroup(), g2); + } + + /** + * Original: + * Group 0: LogicalOlapScan C + * Group 1: LogicalOlapScan B + * Group 2: LogicalOlapScan A + * Group 3: Join(Group 1, Group 2) + * Group 4: Join(Group 0, Group 3) + * <p> + * Then: + * Copy In Join(Group 2, Group 1) into Group 3 + * <p> + * Expected: + * Group 0: LogicalOlapScan C + * Group 1: LogicalOlapScan B + * Group 2: LogicalOlapScan A + * Group 3: Join(Group 1, Group 2), Join(Group 2, Group 1) + * Group 4: Join(Group 0, Group 3) + */ + @Test + public void testInsertSameGroup() { + PlanChecker.from(MemoTestUtils.createConnectContext(), logicalJoinABC) + .transform( + // swap join's children + logicalJoin(logicalOlapScan(), logicalOlapScan()).then(joinBA -> + new LogicalProject<>(Lists.newArrayList(joinBA.getOutput()), + new LogicalJoin<>(JoinType.INNER_JOIN, joinBA.right(), joinBA.left())) + )) + .checkGroupNum(6) + .checkGroupExpressionNum(7) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(1, root.getLogicalExpressions().size()); + GroupExpression joinABC = root.getLogicalExpression(); + Assertions.assertEquals(2, joinABC.child(0).getLogicalExpressions().size()); + Assertions.assertEquals(1, joinABC.child(1).getLogicalExpressions().size()); + GroupExpression joinAB = joinABC.child(0).getLogicalExpressions().get(0); + GroupExpression project = joinABC.child(0).getLogicalExpressions().get(1); + GroupExpression joinBA = project.child(0).getLogicalExpression(); + 
Assertions.assertTrue(joinAB.getPlan() instanceof LogicalJoin); + Assertions.assertTrue(joinBA.getPlan() instanceof LogicalJoin); + }); + + } + + @Test + public void initByOneLevelPlan() { + OlapTable table = PlanConstructor.newOlapTable(0, "a", 1); + LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table); + + PlanChecker.from(connectContext, scan) + .checkGroupNum(1) + .matches( + logicalOlapScan().when(scan::equals) + ); + } + + @Test + public void initByTwoLevelChainPlan() { + OlapTable table = PlanConstructor.newOlapTable(0, "a", 1); + LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table); + + LogicalProject<LogicalOlapScan> topProject = new LogicalProject<>( + ImmutableList.of(scan.computeOutput().get(0)), scan); + + PlanChecker.from(connectContext, topProject) + .checkGroupNum(2) + .matches( + logicalProject( + any().when(child -> Objects.equals(child, scan)) + ).when(root -> Objects.equals(root, topProject)) + ); + } + + @Test + public void initByJoinSameUnboundTable() { + UnboundRelation scanA = new UnboundRelation(ImmutableList.of("a")); + + LogicalJoin<UnboundRelation, UnboundRelation> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA); + + PlanChecker.from(connectContext, topJoin) + .checkGroupNum(3) + .matches( + logicalJoin( + any().when(left -> Objects.equals(left, scanA)), + any().when(right -> Objects.equals(right, scanA)) + ).when(root -> Objects.equals(root, topJoin)) + ); + } + + @Test + public void initByJoinSameLogicalTable() { + IdGenerator<RelationId> generator = RelationId.createGenerator(); + OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1); + LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA); + LogicalOlapScan scanA1 = new LogicalOlapScan(generator.getNextId(), tableA); + + LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanA1); + + PlanChecker.from(connectContext, 
topJoin) + .checkGroupNum(3) + .matches( + logicalJoin( + any().when(left -> Objects.equals(left, scanA)), + any().when(right -> Objects.equals(right, scanA1)) + ).when(root -> Objects.equals(root, topJoin)) + ); + } + + @Test + public void initByTwoLevelJoinPlan() { + IdGenerator<RelationId> generator = RelationId.createGenerator(); + OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1); + OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1); + LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA); + LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB); + + LogicalJoin<LogicalOlapScan, LogicalOlapScan> topJoin = new LogicalJoin<>(JoinType.INNER_JOIN, scanA, scanB); + + PlanChecker.from(connectContext, topJoin) + .checkGroupNum(3) + .matches( + logicalJoin( + any().when(left -> Objects.equals(left, scanA)), + any().when(right -> Objects.equals(right, scanB)) + ).when(root -> Objects.equals(root, topJoin)) + ); + } + + @Test + public void initByThreeLevelChainPlan() { + OlapTable table = PlanConstructor.newOlapTable(0, "a", 1); + LogicalOlapScan scan = new LogicalOlapScan(RelationId.createGenerator().getNextId(), table); + + LogicalProject<LogicalOlapScan> project = new LogicalProject<>( + ImmutableList.of(scan.computeOutput().get(0)), scan); + LogicalFilter<LogicalProject<LogicalOlapScan>> filter = new LogicalFilter<>( + new EqualTo(scan.computeOutput().get(0), new IntegerLiteral(1)), project); + + PlanChecker.from(connectContext, filter) + .checkGroupNum(3) + .matches( + logicalFilter( + logicalProject( + any().when(child -> Objects.equals(child, scan)) + ).when(root -> Objects.equals(root, project)) + ).when(root -> Objects.equals(root, filter)) + ); + } + + @Test + public void initByThreeLevelBushyPlan() { + IdGenerator<RelationId> generator = RelationId.createGenerator(); + OlapTable tableA = PlanConstructor.newOlapTable(0, "a", 1); + OlapTable tableB = PlanConstructor.newOlapTable(0, "b", 1); + OlapTable 
tableC = PlanConstructor.newOlapTable(0, "c", 1); + OlapTable tableD = PlanConstructor.newOlapTable(0, "d", 1); + LogicalOlapScan scanA = new LogicalOlapScan(generator.getNextId(), tableA); + LogicalOlapScan scanB = new LogicalOlapScan(generator.getNextId(), tableB); + LogicalOlapScan scanC = new LogicalOlapScan(generator.getNextId(), tableC); + LogicalOlapScan scanD = new LogicalOlapScan(generator.getNextId(), tableD); + + LogicalJoin<LogicalOlapScan, LogicalOlapScan> leftJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanA, scanB); + LogicalJoin<LogicalOlapScan, LogicalOlapScan> rightJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, scanC, scanD); + LogicalJoin topJoin = new LogicalJoin<>(JoinType.CROSS_JOIN, leftJoin, rightJoin); + + PlanChecker.from(connectContext, topJoin) + .checkGroupNum(7) + .matches( + logicalJoin( + logicalJoin( + any().when(child -> Objects.equals(child, scanA)), + any().when(child -> Objects.equals(child, scanB)) + ).when(left -> Objects.equals(left, leftJoin)), + + logicalJoin( + any().when(child -> Objects.equals(child, scanC)), + any().when(child -> Objects.equals(child, scanD)) + ).when(right -> Objects.equals(right, rightJoin)) + ).when(root -> Objects.equals(root, topJoin)) + ); + } + /* * A -> A: * @@ -204,25 +433,15 @@ public class MemoRewriteTest implements PatternMatchSupported { A a2 = new A(ImmutableList.of("student"), State.ALREADY_REWRITE); LogicalLimit<UnboundRelation> limit = new LogicalLimit<>(1, 0, a2); - PlanChecker.from(connectContext, a) - .applyBottomUp( - unboundRelation() - // 4: add state condition to the pattern's predicates - .when(r -> (r instanceof A) && ((A) r).state == State.NOT_REWRITE) - .then(unboundRelation -> { - // 5: new plan and change state, so this case equal to 'A -> B(C)', which C has - // different state with A - A notRewritePlan = (A) unboundRelation; - return limit.withChildren(notRewritePlan.withState(State.ALREADY_REWRITE)); - } - ) - ) - .checkGroupNum(2) - .matchesFromRoot( - logicalLimit( - 
unboundRelation().when(a2::equals) - ).when(limit::equals) - ); + PlanChecker.from(connectContext, a).applyBottomUp(unboundRelation() + // 4: add state condition to the pattern's predicates + .when(r -> (r instanceof A) && ((A) r).state == State.NOT_REWRITE).then(unboundRelation -> { + // 5: new plan and change state, so this case equal to 'A -> B(C)', which C has + // different state with A + A notRewritePlan = (A) unboundRelation; + return limit.withChildren(notRewritePlan.withState(State.ALREADY_REWRITE)); + })).checkGroupNum(2) + .matchesFromRoot(logicalLimit(unboundRelation().when(a2::equals)).when(limit::equals)); } /* @@ -359,7 +578,7 @@ public class MemoRewriteTest implements PatternMatchSupported { ) .checkGroupNum(1) .matchesFromRoot( - logicalOlapScan().when(student::equals) + logicalOlapScan().when(student::equals) ); } @@ -801,30 +1020,30 @@ public class MemoRewriteTest implements PatternMatchSupported { ) )) .applyTopDown( - logicalLimit(logicalJoin()).then(limit -> { - LogicalJoin<GroupPlan, GroupPlan> join = limit.child(); - switch (join.getJoinType()) { - case LEFT_OUTER_JOIN: - return join.withChildren(limit.withChildren(join.left()), join.right()); - case RIGHT_OUTER_JOIN: - return join.withChildren(join.left(), limit.withChildren(join.right())); - case CROSS_JOIN: - return join.withChildren(limit.withChildren(join.left()), limit.withChildren(join.right())); - case INNER_JOIN: - if (!join.getHashJoinConjuncts().isEmpty()) { - return join.withChildren( - limit.withChildren(join.left()), - limit.withChildren(join.right()) - ); - } else { + logicalLimit(logicalJoin()).then(limit -> { + LogicalJoin<GroupPlan, GroupPlan> join = limit.child(); + switch (join.getJoinType()) { + case LEFT_OUTER_JOIN: + return join.withChildren(limit.withChildren(join.left()), join.right()); + case RIGHT_OUTER_JOIN: + return join.withChildren(join.left(), limit.withChildren(join.right())); + case CROSS_JOIN: + return join.withChildren(limit.withChildren(join.left()), 
limit.withChildren(join.right())); + case INNER_JOIN: + if (!join.getHashJoinConjuncts().isEmpty()) { + return join.withChildren( + limit.withChildren(join.left()), + limit.withChildren(join.right()) + ); + } else { + return limit; + } + case LEFT_ANTI_JOIN: + // todo: support anti join. + default: return limit; - } - case LEFT_ANTI_JOIN: - // todo: support anti join. - default: - return limit; - } - }) + } + }) ) .matchesFromRoot( logicalJoin( diff --git a/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy b/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy new file mode 100644 index 0000000000..0407830d15 --- /dev/null +++ b/regression-test/suites/tpch_sf1_p1/tpch_sf1/nereids/q21.groovy @@ -0,0 +1,107 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ + +suite("tpch_sf1_q21_nereids") { + String realDb = context.config.getDbNameByFile(context.file) + // get parent directory's group + realDb = realDb.substring(0, realDb.lastIndexOf("_")) + + sql "use ${realDb}" + + sql 'set enable_nereids_planner=true' + sql 'set enable_fallback_to_original_planner=false' + + qt_select """ +select + s_name, + count(*) as numwait +from + supplier, + lineitem l1, + orders, + nation +where + s_suppkey = l1.l_suppkey + and o_orderkey = l1.l_orderkey + and o_orderstatus = 'F' + and l1.l_receiptdate > l1.l_commitdate + and exists ( + select + * + from + lineitem l2 + where + l2.l_orderkey = l1.l_orderkey + and l2.l_suppkey <> l1.l_suppkey + ) + and not exists ( + select + * + from + lineitem l3 + where + l3.l_orderkey = l1.l_orderkey + and l3.l_suppkey <> l1.l_suppkey + and l3.l_receiptdate > l3.l_commitdate + ) + and s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' +group by + s_name +order by + numwait desc, + s_name +limit 100; + """ + + qt_select """ +select /*+SET_VAR(exec_mem_limit=8589934592, parallel_fragment_exec_instance_num=16, enable_vectorized_engine=true, batch_size=4096, disable_join_reorder=true, enable_cost_based_join_reorder=true, enable_projection=true) */ +s_name, count(*) as numwait +from orders join +( + select * from + lineitem l2 right semi join + ( + select * from + lineitem l3 right anti join + ( + select * from + lineitem l1 join + ( + select * from + supplier join nation + where s_nationkey = n_nationkey + and n_name = 'SAUDI ARABIA' + ) t1 + where t1.s_suppkey = l1.l_suppkey and l1.l_receiptdate > l1.l_commitdate + ) t2 + on l3.l_orderkey = t2.l_orderkey and l3.l_suppkey <> t2.l_suppkey and l3.l_receiptdate > l3.l_commitdate + ) t3 + on l2.l_orderkey = t3.l_orderkey and l2.l_suppkey <> t3.l_suppkey +) t4 +on o_orderkey = t4.l_orderkey and o_orderstatus = 'F' +group by + t4.s_name +order by + numwait desc, + t4.s_name +limit 100; + """ + +} \ No newline at end of file 
--------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org