This is an automated email from the ASF dual-hosted git repository. starocean999 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 80c1a99ef6 [enhance](Nereids): refactor JoinReorder code. (#16477) 80c1a99ef6 is described below commit 80c1a99ef617a5ea9169ba01dc4c3848a49d82e0 Author: jakevin <jakevin...@gmail.com> AuthorDate: Mon Feb 13 09:08:58 2023 +0800 [enhance](Nereids): refactor JoinReorder code. (#16477) * [enhance](Nereids): refactor JoinReorder code. * apply nullable * checkstyle * set enableDPHypOptimizer default false --- .../exploration/join/InnerJoinLAsscomProject.java | 102 +++++--------- ...oinReorderCommon.java => JoinReorderUtils.java} | 42 +++++- .../exploration/join/OuterJoinLAsscomProject.java | 148 ++++++--------------- .../join/SemiJoinLogicalJoinTransposeProject.java | 12 +- .../join/SemiJoinSemiJoinTransposeProject.java | 19 +-- .../org/apache/doris/nereids/util/JoinUtils.java | 18 +-- .../org/apache/doris/nereids/util/PlanUtils.java | 15 --- .../java/org/apache/doris/qe/SessionVariable.java | 2 +- .../join/InnerJoinLAsscomProjectTest.java | 4 +- .../join/OuterJoinLAsscomProjectTest.java | 1 + .../apache/doris/nereids/util/PlanUtilsTest.java | 19 --- 11 files changed, 132 insertions(+), 250 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java index ca8cb1a130..8a3fe670b0 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java @@ -32,7 +32,6 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; -import org.apache.doris.nereids.util.PlanUtils; import com.google.common.base.Preconditions; @@ -64,7 +63,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { return innerLogicalJoin(logicalProject(innerLogicalJoin()), group()) .when(topJoin -> InnerJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) - .when(join -> JoinReorderCommon.checkProject(join.left())) + .when(join -> JoinReorderUtils.checkProject(join.left())) .then(topJoin -> { /* ********** init ********** */ @@ -74,41 +73,30 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { GroupPlan b = bottomJoin.right(); GroupPlan c = topJoin.right(); Set<Slot> cOutputSet = c.getOutputSet(); - Set<ExprId> bOutputExprIdSet = b.getOutputExprIdSet(); Set<ExprId> cOutputExprIdSet = c.getOutputExprIdSet(); /* ********** Split projects ********** */ - Map<Boolean, List<NamedExpression>> projectExprsMap = projects.stream() - .collect(Collectors.partitioningBy(projectExpr -> { - Set<ExprId> usedExprIds = projectExpr - .<Set<SlotReference>>collect(SlotReference.class::isInstance) - .stream() - .map(SlotReference::getExprId) - .collect(Collectors.toSet()); - return bOutputExprIdSet.containsAll(usedExprIds); - })); - List<NamedExpression> newLeftProjects = projectExprsMap.get(Boolean.FALSE); - List<NamedExpression> newRightProjects = projectExprsMap.get(Boolean.TRUE); - - Set<ExprId> bExprIdSet = getExprIdSetForB(bottomJoin.right(), newRightProjects); + Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(projects, b); + List<NamedExpression> newLeftProjects = map.get(false); + List<NamedExpression> newRightProjects = map.get(true); + Set<ExprId> bExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(b, newRightProjects); /* ********** split HashConjuncts ********** */ - Map<Boolean, List<Expression>> splitHashJoinConjuncts = splitConjunctsWithAlias( + Map<Boolean, List<Expression>> splitHashConjuncts = splitConjunctsWithAlias( topJoin.getHashJoinConjuncts(), bottomJoin.getHashJoinConjuncts(), bExprIdSet); - List<Expression> newTopHashJoinConjuncts = splitHashJoinConjuncts.get(true); - List<Expression> newBottomHashJoinConjuncts = splitHashJoinConjuncts.get(false); - Preconditions.checkState(!newTopHashJoinConjuncts.isEmpty(), - "LAsscom newTopHashJoinConjuncts join can't empty"); - if (newBottomHashJoinConjuncts.size() == 0) { + List<Expression> newTopHashConjuncts = splitHashConjuncts.get(true); + List<Expression> newBottomHashConjuncts = splitHashConjuncts.get(false); + Preconditions.checkState(!newTopHashConjuncts.isEmpty(), "newTopHashConjuncts is empty"); + if (newBottomHashConjuncts.size() == 0) { return null; } /* ********** split OtherConjuncts ********** */ - Map<Boolean, List<Expression>> splitOtherJoinConjuncts = splitConjunctsWithAlias( + Map<Boolean, List<Expression>> splitOtherConjuncts = splitConjunctsWithAlias( topJoin.getOtherJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), bExprIdSet); - List<Expression> newTopOtherJoinConjuncts = splitOtherJoinConjuncts.get(true); - List<Expression> newBottomOtherJoinConjuncts = splitOtherJoinConjuncts.get(false); + List<Expression> newTopOtherConjuncts = splitOtherConjuncts.get(true); + List<Expression> newBottomOtherConjuncts = splitOtherConjuncts.get(false); /* ********** replace Conjuncts by projects ********** */ Map<ExprId, Slot> inputToOutput = new HashMap<>(); @@ -124,22 +112,18 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { outputToInput.put(outputSlot.getExprId(), inputSlot); } } - // replace hashJoinConjuncts - newBottomHashJoinConjuncts = JoinUtils.replaceJoinConjuncts( - newBottomHashJoinConjuncts, outputToInput); - newTopHashJoinConjuncts = JoinUtils.replaceJoinConjuncts( - newTopHashJoinConjuncts, inputToOutput); - - // replace otherJoinConjuncts - newBottomOtherJoinConjuncts = JoinUtils.replaceJoinConjuncts( - newBottomOtherJoinConjuncts, outputToInput); - newTopOtherJoinConjuncts = JoinUtils.replaceJoinConjuncts( - newTopOtherJoinConjuncts, inputToOutput); + // replace hashConjuncts + newBottomHashConjuncts = JoinUtils.replaceJoinConjuncts(newBottomHashConjuncts, outputToInput); + newTopHashConjuncts = JoinUtils.replaceJoinConjuncts(newTopHashConjuncts, inputToOutput); + + // replace otherConjuncts + newBottomOtherConjuncts = JoinUtils.replaceJoinConjuncts(newBottomOtherConjuncts, outputToInput); + newTopOtherConjuncts = JoinUtils.replaceJoinConjuncts(newTopOtherConjuncts, inputToOutput); // Add all slots used by OnCondition when projects not empty. Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat( - newTopHashJoinConjuncts.stream(), - newTopOtherJoinConjuncts.stream()) + newTopHashConjuncts.stream(), + newTopOtherConjuncts.stream()) .flatMap(onExpr -> { Set<Slot> usedSlotRefs = onExpr.collect(SlotReference.class::isInstance); return usedSlotRefs.stream(); @@ -147,8 +131,8 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { .filter(slot -> !cOutputExprIdSet.contains(slot.getExprId())) .collect(Collectors.partitioningBy( slot -> bExprIdSet.contains(slot.getExprId()), Collectors.toSet())); - Set<Slot> aUsedSlots = abOnUsedSlots.get(Boolean.FALSE); - Set<Slot> bUsedSlots = abOnUsedSlots.get(Boolean.TRUE); + Set<Slot> aUsedSlots = abOnUsedSlots.get(false); + Set<Slot> bUsedSlots = abOnUsedSlots.get(true); JoinUtils.addSlotsUsedByOn(bUsedSlots, newRightProjects); JoinUtils.addSlotsUsedByOn(aUsedSlots, newLeftProjects); @@ -159,43 +143,23 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { /* ********** new Plan ********** */ LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - newBottomHashJoinConjuncts, newBottomOtherJoinConjuncts, JoinHint.NONE, + newBottomHashConjuncts, newBottomOtherConjuncts, JoinHint.NONE, a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); - Plan left = newBottomJoin; - if (!newLeftProjects.stream().map(NamedExpression::toSlot) - .map(NamedExpression::getExprId).collect(Collectors.toSet()) - .equals(newBottomJoin.getOutputExprIdSet())) { - left = PlanUtils.projectOrSelf(newLeftProjects, newBottomJoin); - } - Plan right = b; - if (!newRightProjects.stream().map(NamedExpression::toSlot) - .map(NamedExpression::getExprId).collect(Collectors.toSet()) - .equals(b.getOutputExprIdSet())) { - right = PlanUtils.projectOrSelf(newRightProjects, b); - } + Plan left = JoinReorderUtils.projectOrSelf(newLeftProjects, newBottomJoin); + Plan right = JoinReorderUtils.projectOrSelf(newRightProjects, b); LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - newTopHashJoinConjuncts, newTopOtherJoinConjuncts, JoinHint.NONE, + newTopHashConjuncts, newTopOtherConjuncts, JoinHint.NONE, left, right, topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); - if (topJoin.getLogicalProperties().equals(newTopJoin.getLogicalProperties())) { - return newTopJoin; - } - - return PlanUtils.project(new ArrayList<>(topJoin.getOutput()), newTopJoin).get(); + return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_INNER_JOIN_LASSCOM_PROJECT); } - public static Set<ExprId> getExprIdSetForB(GroupPlan b, List<NamedExpression> bProject) { - return Stream.concat( - b.getOutput().stream().map(NamedExpression::getExprId), - bProject.stream().map(NamedExpression::getExprId)).collect(Collectors.toSet()); - } - /** * Split Condition into two part. * True: contains B. @@ -211,11 +175,11 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { return ExpressionUtils.isIntersecting(bExprIdSet, usedExprIds); })); // * don't include B, just include (A C) - // we add it into newBottomJoin HashJoinConjuncts. + // we add it into newBottomJoin HashConjuncts. // * include B, include (A B C) or (A B) - // we add it into newTopJoin HashJoinConjuncts. - List<Expression> newTopHashJoinConjuncts = splitOn.get(true); - newTopHashJoinConjuncts.addAll(bottomConjuncts); + // we add it into newTopJoin HashConjuncts. + List<Expression> newTopHashConjuncts = splitOn.get(true); + newTopHashConjuncts.addAll(bottomConjuncts); return splitOn; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java similarity index 56% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java index 321ab28d3d..b115d03d5a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java @@ -18,18 +18,24 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * Common */ -class JoinReorderCommon { +class JoinReorderUtils { /** * check project inside Join to prevent matching some pattern. * just allow projection is slot or Alias(slot) to prevent reorder when: @@ -44,11 +50,39 @@ class JoinReorderCommon { return true; } if (expr instanceof Alias) { - if (((Alias) expr).child() instanceof Slot) { - return true; - } + return ((Alias) expr).child() instanceof Slot; } return false; }); } + + static Map<Boolean, List<NamedExpression>> splitProjection( + List<NamedExpression> projects, Plan splitChild) { + Set<ExprId> splitExprIds = splitChild.getOutputExprIdSet(); + + Map<Boolean, List<NamedExpression>> projectExprsMap = projects.stream() + .collect(Collectors.partitioningBy(projectExpr -> { + Set<ExprId> usedExprIds = projectExpr.getInputSlotExprIds(); + return splitExprIds.containsAll(usedExprIds); + })); + + return projectExprsMap; + } + + public static Set<ExprId> combineProjectAndChildExprId(Plan b, List<NamedExpression> bProject) { + return Stream.concat( + b.getOutput().stream().map(NamedExpression::getExprId), + bProject.stream().map(NamedExpression::getExprId)).collect(Collectors.toSet()); + } + + /** + * If projectExprs is empty or project output equal plan output, return the original plan. + */ + public static Plan projectOrSelf(List<NamedExpression> projectExprs, Plan plan) { + if (projectExprs.isEmpty() || projectExprs.stream().map(NamedExpression::getExprId).collect(Collectors.toSet()) + .equals(plan.getOutputExprIdSet())) { + return plan; + } + return new LogicalProject<>(projectExprs, plan); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java index e58a227acf..cfb94e9b46 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java @@ -31,12 +31,10 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; -import org.apache.doris.nereids.util.PlanUtils; import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; +import com.google.common.collect.ImmutableSet; import java.util.ArrayList; import java.util.HashMap; @@ -68,9 +66,8 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { Pair.of(join.left().child().getJoinType(), join.getJoinType()))) .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) - .when(join -> JoinReorderCommon.checkProject(join.left())) + .when(join -> JoinReorderUtils.checkProject(join.left())) .then(topJoin -> { - /* ********** init ********** */ List<NamedExpression> projects = topJoin.left().getProjects(); LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child(); @@ -78,51 +75,19 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { GroupPlan b = bottomJoin.right(); GroupPlan c = topJoin.right(); Set<Slot> cOutputSet = c.getOutputSet(); - Set<ExprId> aOutputExprIdSet = a.getOutputExprIdSet(); Set<ExprId> cOutputExprIdSet = c.getOutputExprIdSet(); /* ********** Split projects ********** */ - Map<Boolean, List<NamedExpression>> projectExprsMap = projects.stream() - .collect(Collectors.partitioningBy(projectExpr -> { - Set<ExprId> usedExprIds = projectExpr - .<Set<SlotReference>>collect(SlotReference.class::isInstance) - .stream() - .map(SlotReference::getExprId) - .collect(Collectors.toSet()); - return aOutputExprIdSet.containsAll(usedExprIds); - })); - List<NamedExpression> newLeftProjects = projectExprsMap.get(Boolean.TRUE); - List<NamedExpression> newRightProjects = projectExprsMap.get(Boolean.FALSE); - Set<ExprId> aExprIdSet = getExprIdSetForA(bottomJoin.left(), newLeftProjects); - - /* ********** split Conjuncts ********** */ - Map<Boolean, List<Expression>> newHashJoinConjuncts - = createNewConjunctsWithAlias( - topJoin.getHashJoinConjuncts(), bottomJoin.getHashJoinConjuncts(), aExprIdSet); - List<Expression> newTopHashJoinConjuncts = newHashJoinConjuncts.get(true); - Preconditions.checkState(!newTopHashJoinConjuncts.isEmpty(), - "LAsscom newTopHashJoinConjuncts join can't empty"); - // When newTopHashJoinConjuncts.size() != bottomJoin.getHashJoinConjuncts().size() - // It means that topHashJoinConjuncts contain A, B, C, we shouldn't do LAsscom. - if (topJoin.getJoinType() != bottomJoin.getJoinType() - && newTopHashJoinConjuncts.size() != bottomJoin.getHashJoinConjuncts().size()) { - return null; - } - List<Expression> newBottomHashJoinConjuncts = newHashJoinConjuncts.get(false); - if (newBottomHashJoinConjuncts.size() == 0) { - return null; - } + Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(projects, a); + List<NamedExpression> newLeftProjects = map.get(true); + List<NamedExpression> newRightProjects = map.get(false); + Set<ExprId> aExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(a, newLeftProjects); - Map<Boolean, List<Expression>> newOtherJoinConjuncts - = createNewConjunctsWithAlias( - topJoin.getOtherJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), - aExprIdSet); - List<Expression> newTopOtherJoinConjuncts = newOtherJoinConjuncts.get(true); - List<Expression> newBottomOtherJoinConjuncts = newOtherJoinConjuncts.get(false); - if (newBottomOtherJoinConjuncts.size() != topJoin.getOtherJoinConjuncts().size() - || newTopOtherJoinConjuncts.size() != bottomJoin.getOtherJoinConjuncts().size()) { - return null; - } + /* ********** Conjuncts ********** */ + List<Expression> newTopHashConjuncts = bottomJoin.getHashJoinConjuncts(); + List<Expression> newBottomHashConjuncts = topJoin.getHashJoinConjuncts(); + List<Expression> newTopOtherConjuncts = bottomJoin.getOtherJoinConjuncts(); + List<Expression> newBottomOtherConjuncts = topJoin.getOtherJoinConjuncts(); /* ********** replace Conjuncts by projects ********** */ Map<ExprId, Slot> inputToOutput = new HashMap<>(); @@ -139,21 +104,28 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { outputToInput.put(outputSlot.getExprId(), inputSlot); } } - // replace hashJoinConjuncts - newBottomHashJoinConjuncts = JoinUtils.replaceJoinConjuncts( - newBottomHashJoinConjuncts, outputToInput); - newTopHashJoinConjuncts = JoinUtils.replaceJoinConjuncts( - newTopHashJoinConjuncts, inputToOutput); - // replace otherJoinConjuncts - newBottomOtherJoinConjuncts = JoinUtils.replaceJoinConjuncts( - newBottomOtherJoinConjuncts, outputToInput); - newTopOtherJoinConjuncts = JoinUtils.replaceJoinConjuncts( - newTopOtherJoinConjuncts, inputToOutput); + // replace hashConjuncts + newBottomHashConjuncts = JoinUtils.replaceJoinConjuncts(newBottomHashConjuncts, outputToInput); + newTopHashConjuncts = JoinUtils.replaceJoinConjuncts(newTopHashConjuncts, inputToOutput); + // replace otherConjuncts + newBottomOtherConjuncts = JoinUtils.replaceJoinConjuncts(newBottomOtherConjuncts, outputToInput); + newTopOtherConjuncts = JoinUtils.replaceJoinConjuncts(newTopOtherConjuncts, inputToOutput); + + /* ********** check ********** */ + Set<Slot> acOutputSet = ImmutableSet.<Slot>builder().addAll(a.getOutputSet()) + .addAll(c.getOutputSet()).build(); + if (!Stream.concat(newBottomHashConjuncts.stream(), newBottomOtherConjuncts.stream()) + .allMatch(expr -> { + Set<Slot> inputSlots = expr.getInputSlots(); + return acOutputSet.containsAll(inputSlots); + })) { + return null; + } // Add all slots used by OnCondition when projects not empty. Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat( - newTopHashJoinConjuncts.stream(), - newTopOtherJoinConjuncts.stream()) + newTopHashConjuncts.stream(), + newTopOtherConjuncts.stream()) .flatMap(onExpr -> { Set<Slot> usedSlotRefs = onExpr.collect(SlotReference.class::isInstance); return usedSlotRefs.stream(); @@ -161,8 +133,8 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { .filter(slot -> !cOutputExprIdSet.contains(slot.getExprId())) .collect(Collectors.partitioningBy( slot -> aExprIdSet.contains(slot.getExprId()), Collectors.toSet())); - Set<Slot> aUsedSlots = abOnUsedSlots.get(Boolean.TRUE); - Set<Slot> bUsedSlots = abOnUsedSlots.get(Boolean.FALSE); + Set<Slot> aUsedSlots = abOnUsedSlots.get(true); + Set<Slot> bUsedSlots = abOnUsedSlots.get(false); JoinUtils.addSlotsUsedByOn(bUsedSlots, newRightProjects); JoinUtils.addSlotsUsedByOn(aUsedSlots, newLeftProjects); @@ -174,68 +146,22 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { /* ********** new Plan ********** */ LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - newBottomHashJoinConjuncts, newBottomOtherJoinConjuncts, JoinHint.NONE, + newBottomHashConjuncts, newBottomOtherConjuncts, JoinHint.NONE, a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); - Plan left = newBottomJoin; - if (!newLeftProjects.stream().map(NamedExpression::toSlot) - .map(NamedExpression::getExprId).collect(Collectors.toSet()) - .equals(newBottomJoin.getOutputExprIdSet())) { - left = PlanUtils.projectOrSelf(newLeftProjects, newBottomJoin); - } - Plan right = b; - if (!newRightProjects.stream().map(NamedExpression::toSlot) - .map(NamedExpression::getExprId).collect(Collectors.toSet()) - .equals(b.getOutputExprIdSet())) { - right = PlanUtils.projectOrSelf(newRightProjects, b); - } + Plan left = JoinReorderUtils.projectOrSelf(newLeftProjects, newBottomJoin); + Plan right = JoinReorderUtils.projectOrSelf(newRightProjects, b); LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - newTopHashJoinConjuncts, newTopOtherJoinConjuncts, JoinHint.NONE, + newTopHashConjuncts, newTopOtherConjuncts, JoinHint.NONE, left, right, topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); - - if (topJoin.getLogicalProperties().equals(newTopJoin.getLogicalProperties())) { - return newTopJoin; - } - - return PlanUtils.project(new ArrayList<>(topJoin.getOutput()), newTopJoin).get(); + return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_OUTER_JOIN_LASSCOM_PROJECT); } - private Map<Boolean, List<Expression>> createNewConjunctsWithAlias(List<Expression> topConjuncts, - List<Expression> bottomConjuncts, Set<ExprId> bExprIdSet) { - // if top join's conjuncts are all related to A, we can do reorder - Map<Boolean, List<Expression>> splitOn = new HashMap<>(); - splitOn.put(true, Lists.newArrayList()); - if (topConjuncts.stream().allMatch(topHashOn -> { - Set<ExprId> usedSlotsId = topHashOn.getInputSlots().stream() - .map(NamedExpression::getExprId) - .collect(Collectors.toSet()); - - return ExpressionUtils.isIntersecting(bExprIdSet, usedSlotsId); - })) { - // do reorder, create new bottom join conjuncts - splitOn.put(false, Lists.newArrayList(topConjuncts)); - } else { - // can't reorder, return empty list - splitOn.put(false, Lists.newArrayList()); - } - - List<Expression> newTopHashJoinConjuncts = splitOn.get(true); - newTopHashJoinConjuncts.addAll(bottomConjuncts); - - return splitOn; - } - - private Set<ExprId> getExprIdSetForA(GroupPlan a, List<NamedExpression> aProject) { - return Stream.concat( - a.getOutput().stream().map(NamedExpression::getExprId), - aProject.stream().map(NamedExpression::getExprId)).collect(Collectors.toSet()); - } - private Set<Slot> forceToNullable(Set<Slot> slotSet) { return slotSet.stream().map(s -> (Slot) s.rewriteUp(e -> { if (e instanceof SlotReference) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java index 8cb502c95f..6d2f705ec7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java @@ -59,11 +59,11 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto return logicalJoin(logicalProject(logicalJoin()), group()) .when(topJoin -> (topJoin.getJoinType().isLeftSemiOrAntiJoin() && (topJoin.left().child().getJoinType().isInnerJoin() - || topJoin.left().child().getJoinType().isLeftOuterJoin() - || topJoin.left().child().getJoinType().isRightOuterJoin()))) + || topJoin.left().child().getJoinType().isLeftOuterJoin() + || topJoin.left().child().getJoinType().isRightOuterJoin()))) .whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin()) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) - .when(join -> JoinReorderCommon.checkProject(join.left())) + .when(join -> JoinReorderUtils.checkProject(join.left())) .when(this::conditionChecker) .then(topSemiJoin -> { LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left(); @@ -106,8 +106,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), JoinHint.NONE, newBottomSemiJoin, b); - - return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); + return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); } else { /*- * topSemiJoin project @@ -132,8 +131,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), JoinHint.NONE, a, newBottomSemiJoin); - - return new LogicalProject<>(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); + return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); } }).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java index cfb77fb38d..3e86b5b17e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinSemiJoinTransposeProject.java @@ -55,7 +55,7 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory .when(this::typeChecker) .when(topSemi -> InnerJoinLAsscom.checkReorder(topSemi, topSemi.left().child())) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) - .when(join -> JoinReorderCommon.checkProject(join.left())) + .when(join -> JoinReorderUtils.checkProject(join.left())) .then(topSemi -> { LogicalJoin<GroupPlan, GroupPlan> bottomSemi = topSemi.left().child(); LogicalProject abProject = topSemi.left(); @@ -73,24 +73,17 @@ public class SemiJoinSemiJoinTransposeProject extends OneExplorationRuleFactory } }) ); - LogicalJoin newBottomSemi = new LogicalJoin(topSemi.getJoinType(), topSemi.getHashJoinConjuncts(), + LogicalJoin newBottomSemi = new LogicalJoin<>(topSemi.getJoinType(), topSemi.getHashJoinConjuncts(), topSemi.getOtherJoinConjuncts(), JoinHint.NONE, a, c, bottomSemi.getJoinReorderContext()); newBottomSemi.getJoinReorderContext().setHasCommute(false); newBottomSemi.getJoinReorderContext().setHasLAsscom(false); - LogicalProject acProject = new LogicalProject(Lists.newArrayList(acProjects), - newBottomSemi); - LogicalJoin newTopSemi = new LogicalJoin(bottomSemi.getJoinType(), + LogicalProject acProject = new LogicalProject<>(Lists.newArrayList(acProjects), newBottomSemi); + LogicalJoin newTopSemi = new LogicalJoin<>(bottomSemi.getJoinType(), bottomSemi.getHashJoinConjuncts(), bottomSemi.getOtherJoinConjuncts(), JoinHint.NONE, - acProject, b, - topSemi.getJoinReorderContext()); + acProject, b, topSemi.getJoinReorderContext()); newTopSemi.getJoinReorderContext().setHasLAsscom(true); - //return newTopSemi; - if (topSemi.getLogicalProperties().equals(newTopSemi)) { - return newTopSemi; - } else { - return new LogicalProject<>(new ArrayList<>(topSemi.getOutput()), newTopSemi); - } + return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemi.getOutput()), newTopSemi); }).toRule(RuleType.LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE_PROJECT); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java index 1a995003e6..0532f0b26a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java @@ -335,6 +335,11 @@ public class JoinUtils { return joinOutputExprIdSet; } + private static List<Slot> applyNullable(List<Slot> slots, boolean nullable) { + return slots.stream().map(o -> o.withNullable(nullable)) + .collect(ImmutableList.toImmutableList()); + } + /** * calculate the output slot of a join operator according join type and its children * @param joinType the type of join operator @@ -343,11 +348,6 @@ public class JoinUtils { * @return return the output slots */ public static List<Slot> getJoinOutput(JoinType joinType, Plan left, Plan right) { - List<Slot> newLeftOutput = left.getOutput().stream().map(o -> o.withNullable(true)) - .collect(ImmutableList.toImmutableList()); - List<Slot> newRightOutput = right.getOutput().stream().map(o -> o.withNullable(true)) - .collect(ImmutableList.toImmutableList()); - switch (joinType) { case LEFT_SEMI_JOIN: case LEFT_ANTI_JOIN: @@ -359,17 +359,17 @@ public class JoinUtils { case LEFT_OUTER_JOIN: return ImmutableList.<Slot>builder() .addAll(left.getOutput()) - .addAll(newRightOutput) + .addAll(applyNullable(right.getOutput(), true)) .build(); case RIGHT_OUTER_JOIN: return ImmutableList.<Slot>builder() - .addAll(newLeftOutput) + .addAll(applyNullable(left.getOutput(), true)) .addAll(right.getOutput()) .build(); case FULL_OUTER_JOIN: return ImmutableList.<Slot>builder() - .addAll(newLeftOutput) - .addAll(newRightOutput) + .addAll(applyNullable(left.getOutput(), true)) + .addAll(applyNullable(right.getOutput(), true)) .build(); default: return ImmutableList.<Slot>builder() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java index a5401258cf..1a08fed6dd 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/PlanUtils.java @@ -18,12 +18,9 @@ package org.apache.doris.nereids.util; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import java.util.List; import java.util.Optional; import java.util.Set; @@ -31,18 +28,6 @@ import java.util.Set; * Util for plan */ public class PlanUtils { - public static Optional<LogicalProject<? extends Plan>> project(List<NamedExpression> projectExprs, Plan plan) { - if (projectExprs.isEmpty()) { - return Optional.empty(); - } - - return Optional.of(new LogicalProject<>(projectExprs, plan)); - } - - public static Plan projectOrSelf(List<NamedExpression> projectExprs, Plan plan) { - return project(projectExprs, plan).map(Plan.class::cast).orElse(plan); - } - public static Optional<LogicalFilter<? extends Plan>> filter(Set<Expression> predicates, Plan plan) { if (predicates.isEmpty()) { return Optional.empty(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java index 5be52b1097..17d5273da2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java +++ b/fe/fe-core/src/main/java/org/apache/doris/qe/SessionVariable.java @@ -558,7 +558,7 @@ public class SessionVariable implements Serializable, Writable { private boolean checkOverflowForDecimal = false; @VariableMgr.VarAttr(name = ENABLE_DPHYP_OPTIMIZER) - private boolean enableDPHypOptimizer = true; + private boolean enableDPHypOptimizer = false; /** * as the new optimizer is not mature yet, use this var * to control whether to use new optimizer, remove it when diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProjectTest.java index 3c982a2eb4..c272ce97cf 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProjectTest.java @@ -165,6 +165,7 @@ class InnerJoinLAsscomProjectTest implements PatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) .printlnTree() .applyExploration(InnerJoinLAsscomProject.INSTANCE.build()) + .printlnExploration() .matchesExploration( innerLogicalJoin( logicalProject( @@ -178,8 +179,7 @@ class InnerJoinLAsscomProjectTest implements PatternMatchSupported { "[(t2.id#6 = id#8), (t1.id#4 = t2.id#6)]") && Objects.equals(join.getOtherJoinConjuncts().toString(), "[(t2.name#7 > name#9), (t1.name#5 > t2.name#7)]")) - ) - .printlnExploration(); + ); } /** diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java index 6ac562b800..b2492e86e9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java @@ -51,6 +51,7 @@ class OuterJoinLAsscomProjectTest implements PatternMatchSupported { .build(); PlanChecker.from(MemoTestUtils.createConnectContext(), plan) + .printlnTree() .applyExploration(OuterJoinLAsscomProject.INSTANCE.build()) .printlnExploration() .matchesExploration( diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanUtilsTest.java index 08cb4b087c..a13825a885 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanUtilsTest.java @@ -17,35 +17,16 @@ package org.apache.doris.nereids.util; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalFilter; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import com.google.common.collect.ImmutableSet; -import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.List; - class PlanUtilsTest { - - @Test - void projectOrSelf() { - LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - Plan self = PlanUtils.projectOrSelf(Lists.newArrayList(), scan); - Assertions.assertSame(scan, self); - - NamedExpression slot = scan.getOutput().get(0); - List<NamedExpression> projects = Lists.newArrayList(); - projects.add(slot); - Plan project = PlanUtils.projectOrSelf(projects, scan); - Assertions.assertTrue(project instanceof LogicalProject); - } - @Test void filterOrSelf() { LogicalOlapScan scan = PlanConstructor.newLogicalOlapScan(0, "t1", 0); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org