This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new d4cebb39ba [fix](Nereids): fix SemiJoinLogicalJoinTransposeProject. (#16883) d4cebb39ba is described below commit d4cebb39ba3d7c227947daa95a494282b60de857 Author: jakevin <jakevin...@gmail.com> AuthorDate: Sat Feb 18 23:12:34 2023 +0800 [fix](Nereids): fix SemiJoinLogicalJoinTransposeProject. (#16883) --- .../rules/exploration/join/OuterJoinLAsscom.java | 2 +- .../join/SemiJoinLogicalJoinTransposeProject.java | 95 +++++++++++++--------- .../nereids/trees/plans/logical/LogicalJoin.java | 1 - 3 files changed, 56 insertions(+), 42 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java index dda781f327..a23c3f0015 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java @@ -63,8 +63,8 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory { return logicalJoin(logicalJoin(), group()) .when(join -> VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), join.getJoinType()))) .when(topJoin -> checkReorder(topJoin, topJoin.left())) - .when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet())) .whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint()) + .when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet())) .then(topJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left(); GroupPlan a = bottomJoin.left(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java index 76fda42fe5..89a843fffc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java @@ -17,11 +17,13 @@ package org.apache.doris.nereids.rules.exploration.join; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinHint; @@ -33,9 +35,12 @@ import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; -import java.util.ArrayList; +import java.util.HashMap; import java.util.List; +import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; /** * <ul> @@ -64,7 +69,6 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto .whenNot(topJoin -> topJoin.left().child().getJoinType().isSemiOrAntiJoin()) .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .when(join -> JoinReorderUtils.checkProject(join.left())) - .when(this::conditionChecker) .then(topSemiJoin -> { LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topSemiJoin.left(); LogicalJoin<GroupPlan, GroupPlan> bottomJoin = project.child(); @@ -72,17 +76,17 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto GroupPlan b = bottomJoin.right(); GroupPlan c = topSemiJoin.right(); - Set<ExprId> aOutputExprIdSet = a.getOutputExprIdSet(); - - List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts(); - - boolean lasscom = false; - for (Expression hashJoinConjunct : hashJoinConjuncts) { - Set<ExprId> usedSlotExprIdSet = hashJoinConjunct.getInputSlotExprIds(); - lasscom = Utils.isIntersecting(usedSlotExprIdSet, aOutputExprIdSet) || lasscom; + // push topSemiJoin down project, so we need replace conjuncts by project. + Pair<List<Expression>, List<Expression>> conjuncts = replaceConjuncts(topSemiJoin, project); + Set<ExprId> conjunctsIds = Stream.concat(conjuncts.first.stream(), conjuncts.second.stream()) + .flatMap(expr -> expr.getInputSlotExprIds().stream()).collect(Collectors.toSet()); + ContainsType containsType = containsChildren(conjunctsIds, a.getOutputExprIdSet(), + b.getOutputExprIdSet()); + if (containsType == ContainsType.ALL) { + return null; } - if (lasscom) { + if (containsType == ContainsType.LEFT) { /*- * topSemiJoin project * / \ | @@ -92,22 +96,24 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto * / \ / \ * A B A C */ + // Preconditions.checkState(bottomJoin.getJoinType() != JoinType.RIGHT_OUTER_JOIN); if (bottomJoin.getJoinType() == JoinType.RIGHT_OUTER_JOIN) { // when bottom join is right outer join, we change it to inner join // if we want to do this trans. However, we do not allow different logical properties // in one group. So we need to change it to inner join in rewrite step. - return topSemiJoin; + return null; } LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( - topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(), - topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE, a, c); + topSemiJoin.getJoinType(), conjuncts.first, conjuncts.second, JoinHint.NONE, a, c); LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), - JoinHint.NONE, - newBottomSemiJoin, b); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); + JoinHint.NONE, newBottomSemiJoin, b); + return project.withChildren(newTopJoin); } else { + if (leftDeep) { + return null; + } /*- * topSemiJoin project * / \ | @@ -121,40 +127,49 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto // when bottom join is left outer join, we change it to inner join // if we want to do this trans. However, we do not allow different logical properties // in one group. So we need to change it to inner join in rewrite step. - return topSemiJoin; + return null; } LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = new LogicalJoin<>( - topSemiJoin.getJoinType(), topSemiJoin.getHashJoinConjuncts(), - topSemiJoin.getOtherJoinConjuncts(), JoinHint.NONE, b, c); + topSemiJoin.getJoinType(), conjuncts.first, conjuncts.second, JoinHint.NONE, b, c); LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), - JoinHint.NONE, - a, newBottomSemiJoin); - return JoinReorderUtils.projectOrSelf(new ArrayList<>(topSemiJoin.getOutput()), newTopJoin); + JoinHint.NONE, a, newBottomSemiJoin); + return project.withChildren(newTopJoin); } }).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT); } - // project of bottomJoin just return A OR B, else return false. - private boolean conditionChecker( - LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan> topSemiJoin) { - List<Expression> hashJoinConjuncts = topSemiJoin.getHashJoinConjuncts(); + private Pair<List<Expression>, List<Expression>> replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join, + LogicalProject<? extends Plan> project) { + Map<ExprId, Slot> outputToInput = new HashMap<>(); + for (NamedExpression outputExpr : project.getProjects()) { + Set<Slot> usedSlots = outputExpr.getInputSlots(); + Preconditions.checkState(usedSlots.size() == 1); + Slot inputSlot = usedSlots.iterator().next(); + outputToInput.put(outputExpr.getExprId(), inputSlot); + } + List<Expression> topHashConjuncts = + JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(), outputToInput); + List<Expression> topOtherConjuncts = + JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(), outputToInput); + return Pair.of(topHashConjuncts, topOtherConjuncts); + } - List<Slot> aOutput = topSemiJoin.left().child().left().getOutput(); - List<Slot> bOutput = topSemiJoin.left().child().right().getOutput(); + enum ContainsType { + LEFT, RIGHT, ALL + } - boolean hashContainsA = false; - boolean hashContainsB = false; - for (Expression hashJoinConjunct : hashJoinConjuncts) { - Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance); - hashContainsA = Utils.isIntersecting(usedSlot, aOutput) || hashContainsA; - hashContainsB = Utils.isIntersecting(usedSlot, bOutput) || hashContainsB; - } - if (leftDeep && hashContainsB) { - return false; + private ContainsType containsChildren(Set<ExprId> conjunctsExprIdSet, Set<ExprId> left, Set<ExprId> right) { + boolean containsLeft = Utils.isIntersecting(conjunctsExprIdSet, left); + boolean containsRight = Utils.isIntersecting(conjunctsExprIdSet, right); + Preconditions.checkState(containsLeft || containsRight, "join output must contain child"); + if (containsLeft && containsRight) { + return ContainsType.ALL; + } else if (containsLeft) { + return ContainsType.LEFT; + } else { + return ContainsType.RIGHT; } - Preconditions.checkState(hashContainsA || hashContainsB, "join output must contain child"); - return !(hashContainsA && hashContainsB); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 72018b914b..4c187674ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -130,7 +130,6 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends return otherJoinConjuncts; } - @Override public List<Expression> getHashJoinConjuncts() { return hashJoinConjuncts; } --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org