This is an automated email from the ASF dual-hosted git repository.

jakevin pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new d4cebb39ba [fix](Nereids): fix SemiJoinLogicalJoinTransposeProject. (#16883)
d4cebb39ba is described below

commit d4cebb39ba3d7c227947daa95a494282b60de857
Author: jakevin <jakevin...@gmail.com>
AuthorDate: Sat Feb 18 23:12:34 2023 +0800

    [fix](Nereids): fix SemiJoinLogicalJoinTransposeProject. (#16883)
---
 .../rules/exploration/join/OuterJoinLAsscom.java   |  2 +-
 .../join/SemiJoinLogicalJoinTransposeProject.java  | 95 +++++++++++++---------
 .../nereids/trees/plans/logical/LogicalJoin.java   |  1 -
 3 files changed, 56 insertions(+), 42 deletions(-)

diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
index dda781f327..a23c3f0015 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java
@@ -63,8 +63,8 @@ public class OuterJoinLAsscom extends 
OneExplorationRuleFactory {
         return logicalJoin(logicalJoin(), group())
                 .when(join -> 
VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), 
join.getJoinType())))
                 .when(topJoin -> checkReorder(topJoin, topJoin.left()))
-                .when(topJoin -> checkCondition(topJoin, 
topJoin.left().right().getOutputExprIdSet()))
                 .whenNot(join -> join.hasJoinHint() || 
join.left().hasJoinHint())
+                .when(topJoin -> checkCondition(topJoin, 
topJoin.left().right().getOutputExprIdSet()))
                 .then(topJoin -> {
                     LogicalJoin<GroupPlan, GroupPlan> bottomJoin = 
topJoin.left();
                     GroupPlan a = bottomJoin.left();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
index 76fda42fe5..89a843fffc 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java
@@ -17,11 +17,13 @@
 
 package org.apache.doris.nereids.rules.exploration.join;
 
+import org.apache.doris.common.Pair;
 import org.apache.doris.nereids.rules.Rule;
 import org.apache.doris.nereids.rules.RuleType;
 import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory;
 import org.apache.doris.nereids.trees.expressions.ExprId;
 import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.NamedExpression;
 import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.plans.GroupPlan;
 import org.apache.doris.nereids.trees.plans.JoinHint;
@@ -33,9 +35,12 @@ import org.apache.doris.nereids.util.Utils;
 
 import com.google.common.base.Preconditions;
 
-import java.util.ArrayList;
+import java.util.HashMap;
 import java.util.List;
+import java.util.Map;
 import java.util.Set;
+import java.util.stream.Collectors;
+import java.util.stream.Stream;
 
 /**
  * <ul>
@@ -64,7 +69,6 @@ public class SemiJoinLogicalJoinTransposeProject extends 
OneExplorationRuleFacto
                 .whenNot(topJoin -> 
topJoin.left().child().getJoinType().isSemiOrAntiJoin())
                 .whenNot(join -> join.hasJoinHint() || 
join.left().child().hasJoinHint())
                 .when(join -> JoinReorderUtils.checkProject(join.left()))
-                .when(this::conditionChecker)
                 .then(topSemiJoin -> {
                     LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project 
= topSemiJoin.left();
                     LogicalJoin<GroupPlan, GroupPlan> bottomJoin = 
project.child();
@@ -72,17 +76,17 @@ public class SemiJoinLogicalJoinTransposeProject extends 
OneExplorationRuleFacto
                     GroupPlan b = bottomJoin.right();
                     GroupPlan c = topSemiJoin.right();
 
-                    Set<ExprId> aOutputExprIdSet = a.getOutputExprIdSet();
-
-                    List<Expression> hashJoinConjuncts = 
topSemiJoin.getHashJoinConjuncts();
-
-                    boolean lasscom = false;
-                    for (Expression hashJoinConjunct : hashJoinConjuncts) {
-                        Set<ExprId> usedSlotExprIdSet = 
hashJoinConjunct.getInputSlotExprIds();
-                        lasscom = Utils.isIntersecting(usedSlotExprIdSet, 
aOutputExprIdSet) || lasscom;
+                    // push topSemiJoin down project, so we need replace 
conjuncts by project.
+                    Pair<List<Expression>, List<Expression>> conjuncts = 
replaceConjuncts(topSemiJoin, project);
+                    Set<ExprId> conjunctsIds = 
Stream.concat(conjuncts.first.stream(), conjuncts.second.stream())
+                            .flatMap(expr -> 
expr.getInputSlotExprIds().stream()).collect(Collectors.toSet());
+                    ContainsType containsType = containsChildren(conjunctsIds, 
a.getOutputExprIdSet(),
+                            b.getOutputExprIdSet());
+                    if (containsType == ContainsType.ALL) {
+                        return null;
                     }
 
-                    if (lasscom) {
+                    if (containsType == ContainsType.LEFT) {
                         /*-
                          *     topSemiJoin                    project
                          *      /     \                         |
@@ -92,22 +96,24 @@ public class SemiJoinLogicalJoinTransposeProject extends 
OneExplorationRuleFacto
                          *   /    \                    /      \
                          *  A      B                  A        C
                          */
+                        // Preconditions.checkState(bottomJoin.getJoinType() 
!= JoinType.RIGHT_OUTER_JOIN);
                         if (bottomJoin.getJoinType() == 
JoinType.RIGHT_OUTER_JOIN) {
                             // when bottom join is right outer join, we change 
it to inner join
                             // if we want to do this trans. However, we do not 
allow different logical properties
                             // in one group. So we need to change it to inner 
join in rewrite step.
-                            return topSemiJoin;
+                            return null;
                         }
                         LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = 
new LogicalJoin<>(
-                                topSemiJoin.getJoinType(), 
topSemiJoin.getHashJoinConjuncts(),
-                                topSemiJoin.getOtherJoinConjuncts(), 
JoinHint.NONE, a, c);
+                                topSemiJoin.getJoinType(), conjuncts.first, 
conjuncts.second, JoinHint.NONE, a, c);
 
                         LogicalJoin<Plan, Plan> newTopJoin = new 
LogicalJoin<>(bottomJoin.getJoinType(),
                                 bottomJoin.getHashJoinConjuncts(), 
bottomJoin.getOtherJoinConjuncts(),
-                                JoinHint.NONE,
-                                newBottomSemiJoin, b);
-                        return JoinReorderUtils.projectOrSelf(new 
ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
+                                JoinHint.NONE, newBottomSemiJoin, b);
+                        return project.withChildren(newTopJoin);
                     } else {
+                        if (leftDeep) {
+                            return null;
+                        }
                         /*-
                          *     topSemiJoin                  project
                          *       /     \                       |
@@ -121,40 +127,49 @@ public class SemiJoinLogicalJoinTransposeProject extends 
OneExplorationRuleFacto
                             // when bottom join is left outer join, we change 
it to inner join
                             // if we want to do this trans. However, we do not 
allow different logical properties
                             // in one group. So we need to change it to inner 
join in rewrite step.
-                            return topSemiJoin;
+                            return null;
                         }
                         LogicalJoin<GroupPlan, GroupPlan> newBottomSemiJoin = 
new LogicalJoin<>(
-                                topSemiJoin.getJoinType(), 
topSemiJoin.getHashJoinConjuncts(),
-                                topSemiJoin.getOtherJoinConjuncts(), 
JoinHint.NONE, b, c);
+                                topSemiJoin.getJoinType(), conjuncts.first, 
conjuncts.second, JoinHint.NONE, b, c);
 
                         LogicalJoin<Plan, Plan> newTopJoin = new 
LogicalJoin<>(bottomJoin.getJoinType(),
                                 bottomJoin.getHashJoinConjuncts(), 
bottomJoin.getOtherJoinConjuncts(),
-                                JoinHint.NONE,
-                                a, newBottomSemiJoin);
-                        return JoinReorderUtils.projectOrSelf(new 
ArrayList<>(topSemiJoin.getOutput()), newTopJoin);
+                                JoinHint.NONE, a, newBottomSemiJoin);
+                        return project.withChildren(newTopJoin);
                     }
                 
}).toRule(RuleType.LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT);
     }
 
-    // project of bottomJoin just return A OR B, else return false.
-    private boolean conditionChecker(
-            LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, 
GroupPlan> topSemiJoin) {
-        List<Expression> hashJoinConjuncts = 
topSemiJoin.getHashJoinConjuncts();
+    private Pair<List<Expression>, List<Expression>> 
replaceConjuncts(LogicalJoin<? extends Plan, ? extends Plan> join,
+            LogicalProject<? extends Plan> project) {
+        Map<ExprId, Slot> outputToInput = new HashMap<>();
+        for (NamedExpression outputExpr : project.getProjects()) {
+            Set<Slot> usedSlots = outputExpr.getInputSlots();
+            Preconditions.checkState(usedSlots.size() == 1);
+            Slot inputSlot = usedSlots.iterator().next();
+            outputToInput.put(outputExpr.getExprId(), inputSlot);
+        }
+        List<Expression> topHashConjuncts =
+                
JoinReorderUtils.replaceJoinConjuncts(join.getHashJoinConjuncts(), 
outputToInput);
+        List<Expression> topOtherConjuncts =
+                
JoinReorderUtils.replaceJoinConjuncts(join.getOtherJoinConjuncts(), 
outputToInput);
+        return Pair.of(topHashConjuncts, topOtherConjuncts);
+    }
 
-        List<Slot> aOutput = topSemiJoin.left().child().left().getOutput();
-        List<Slot> bOutput = topSemiJoin.left().child().right().getOutput();
+    enum ContainsType {
+        LEFT, RIGHT, ALL
+    }
 
-        boolean hashContainsA = false;
-        boolean hashContainsB = false;
-        for (Expression hashJoinConjunct : hashJoinConjuncts) {
-            Set<Slot> usedSlot = 
hashJoinConjunct.collect(Slot.class::isInstance);
-            hashContainsA = Utils.isIntersecting(usedSlot, aOutput) || 
hashContainsA;
-            hashContainsB = Utils.isIntersecting(usedSlot, bOutput) || 
hashContainsB;
-        }
-        if (leftDeep && hashContainsB) {
-            return false;
+    private ContainsType containsChildren(Set<ExprId> conjunctsExprIdSet, 
Set<ExprId> left, Set<ExprId> right) {
+        boolean containsLeft = Utils.isIntersecting(conjunctsExprIdSet, left);
+        boolean containsRight = Utils.isIntersecting(conjunctsExprIdSet, 
right);
+        Preconditions.checkState(containsLeft || containsRight, "join output 
must contain child");
+        if (containsLeft && containsRight) {
+            return ContainsType.ALL;
+        } else if (containsLeft) {
+            return ContainsType.LEFT;
+        } else {
+            return ContainsType.RIGHT;
         }
-        Preconditions.checkState(hashContainsA || hashContainsB, "join output 
must contain child");
-        return !(hashContainsA && hashContainsB);
     }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
index 72018b914b..4c187674ef 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java
@@ -130,7 +130,6 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, 
RIGHT_CHILD_TYPE extends
         return otherJoinConjuncts;
     }
 
-    @Override
     public List<Expression> getHashJoinConjuncts() {
         return hashJoinConjuncts;
     }


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to