morrySnow commented on code in PR #11812:
URL: https://github.com/apache/doris/pull/11812#discussion_r949340541


##########
fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java:
##########
@@ -147,6 +148,7 @@ private void analyze() {
     private void rewrite() {
         new NormalizeExpressionRulesJob(cascadesContext).execute();
         new JoinReorderRulesJob(cascadesContext).execute();
+        new FindHashConditionForJoinJob(cascadesContext).execute();

Review Comment:
   I think we should execute this rule at the same time as the predicate 
push-down rule, so we should add this rule to PredicatePushDownRulesJob.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java:
##########
@@ -369,13 +368,16 @@ public PlanFragment 
visitPhysicalHashJoin(PhysicalHashJoin<Plan, Plan> hashJoin,
         if (JoinUtils.shouldNestedLoopJoin(hashJoin)) {
             throw new RuntimeException("Physical hash join could not execute 
without equal join condition.");
         } else {
-            Expression eqJoinExpression = hashJoin.getCondition().get();
-            List<Expr> execEqConjunctList = 
ExpressionUtils.extractConjunction(eqJoinExpression).stream()
-                    .map(EqualTo.class::cast)
-                    .map(e -> swapEqualToForChildrenOrder(e, 
hashJoin.left().getOutput()))
-                    .map(e -> ExpressionTranslator.translate(e, context))
-                    .collect(Collectors.toList());
-
+            //TODO: after we apply rule FindHashConditionForJoin,
+            // we could get execEqConjunctList by hashjoin.getJoinPredicates() 
directly

Review Comment:
   we need to fix this TODO in this PR, and add other conjuncts into join node 
too



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java:
##########
@@ -92,9 +93,10 @@ public List<Rule> buildRules() {
             RuleType.BINDING_JOIN_SLOT.build(
                 logicalJoin().thenApply(ctx -> {
                     LogicalJoin<GroupPlan, GroupPlan> join = ctx.root;
-                    Optional<Expression> cond = join.getCondition()
+                    Optional<Expression> cond = join.getOtherJoinCondition()
                             .map(expr -> bind(expr, join.children(), join, 
ctx.cascadesContext));
-                    return new LogicalJoin<>(join.getJoinType(), cond, 
join.left(), join.right());
+                    return new LogicalJoin<>(join.getJoinType(),
+                            new ArrayList<Expression>(), cond, join.left(), 
join.right());

Review Comment:
   ditto



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java:
##########
@@ -630,6 +630,7 @@ public LogicalPlan visitFromClause(FromClauseContext ctx) {
                 left = (left == null) ? right :
                         new LogicalJoin<>(
                                 JoinType.CROSS_JOIN,
+                                new ArrayList<>(),

Review Comment:
   Use ImmutableList.of() or Collections.emptyList() instead.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java:
##########
@@ -206,6 +209,7 @@ public LogicalJoin newTopJoin() {
 
         return new LogicalJoin(
                 topJoin.getJoinType(),
+                topJoin.getHashJoinPredicates(),

Review Comment:
   ditto



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java:
##########
@@ -814,7 +815,8 @@ private LogicalPlan withJoinRelations(LogicalPlan input, 
RelationContext ctx) {
                 condition = getExpression(joinCriteria.booleanExpression());
             }
 
-            last = new LogicalJoin<>(joinType, Optional.ofNullable(condition), 
last, plan(join.relationPrimary()));
+            last = new LogicalJoin<>(joinType, new ArrayList<Expression>(),

Review Comment:
   ditto



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java:
##########
@@ -191,6 +193,7 @@ public LogicalJoin newProjectTopJoin() {
 
         return new LogicalJoin<>(
                 topJoin.getJoinType(),
+                topJoin.getHashJoinPredicates(),

Review Comment:
   ditto



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java:
##########
@@ -106,15 +107,25 @@ private class JoinExpressionRewrite extends 
OneRewriteRuleFactory {
         @Override
         public Rule build() {
             return logicalJoin().then(join -> {
-                Optional<Expression> condition = join.getCondition();
-                if (!condition.isPresent()) {
+                List<Expression> hashJoinPredicates = 
join.getHashJoinPredicates();
+                Optional<Expression> otherJoinCondition = 
join.getOtherJoinCondition();
+                if (!otherJoinCondition.isPresent() && 
hashJoinPredicates.isEmpty()) {
                     return join;
                 }
-                Expression newCondition = rewriter.rewrite(condition.get());
-                if (newCondition.equals(condition.get())) {
+                List<Expression> rewriteHashJoinPredicates = new ArrayList<>();
+                boolean joinPredicatesChanged = false;
+                for (Expression expr : hashJoinPredicates) {
+                    Expression newExpr = rewriter.rewrite(expr);
+                    joinPredicatesChanged = joinPredicatesChanged || 
newExpr.equals(expr);
+                    rewriteHashJoinPredicates.add(newExpr);
+                }

Review Comment:
   Using the stream API could shorten this code block.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java:
##########
@@ -106,15 +107,25 @@ private class JoinExpressionRewrite extends 
OneRewriteRuleFactory {
         @Override
         public Rule build() {
             return logicalJoin().then(join -> {
-                Optional<Expression> condition = join.getCondition();
-                if (!condition.isPresent()) {
+                List<Expression> hashJoinPredicates = 
join.getHashJoinPredicates();
+                Optional<Expression> otherJoinCondition = 
join.getOtherJoinCondition();
+                if (!otherJoinCondition.isPresent() && 
hashJoinPredicates.isEmpty()) {
                     return join;
                 }
-                Expression newCondition = rewriter.rewrite(condition.get());
-                if (newCondition.equals(condition.get())) {
+                List<Expression> rewriteHashJoinPredicates = new ArrayList<>();

Review Comment:
   ```suggestion
                   List<Expression> rewriteHashJoinPredicates = 
Lists.newArrayList();
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java:
##########
@@ -43,7 +45,9 @@
         extends LogicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> implements 
Join {
 
     private final JoinType joinType;
-    private final Optional<Expression> condition;
+    private final Optional<Expression> otherJoinCondition;
+
+    private final List<Expression> hashJoinPredicates;

Review Comment:
   We need uniform naming, so use either "condition" or "predicates" 
consistently. BTW, remove the blank line at L49 or add a new blank line between L47 and L48.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java:
##########
@@ -111,4 +112,26 @@ public boolean equals(Object o) {
     public int hashCode() {
         return 0;
     }
+
+    /**
+     * get the conjuct list from expr
+     * for example:
+     * a=1 and f(b)=f(c) and (d=e or x=y) => {a=1, f(b)=f(c), (d=e or x=y)}, 
list size = 3
+     * a=1 or b=1 => {a=1 or b=1}, list size = 1.
+     * @return conjuct list
+     */
+    public List<Expression> getConjucts() {

Review Comment:
   Maybe you could use `ExpressionUtils.extractConjunction` directly.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java:
##########
@@ -132,7 +167,10 @@ public List<Slot> computeOutput(Plan leftInput, Plan 
rightInput) {
     @Override
     public String toString() {
         StringBuilder sb = new StringBuilder("LogicalJoin (").append(joinType);
-        condition.ifPresent(expression -> sb.append(", ").append(expression));
+        sb.append(" [");
+        hashJoinPredicates.stream().map(expr -> sb.append(" 
").append(expr)).collect(Collectors.toList());
+        sb.append(" ]");
+        otherJoinCondition.ifPresent(expression -> sb.append(", 
").append(expression));

Review Comment:
   we could print them separately for easy reading



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java:
##########
@@ -78,11 +90,27 @@ public boolean equals(Object o) {
             return false;
         }
         AbstractPhysicalJoin that = (AbstractPhysicalJoin) o;
-        return joinType == that.joinType && Objects.equals(condition, 
that.condition);
+        return joinType == that.joinType && Objects.equals(otherJoinCondition, 
that.otherJoinCondition);

Review Comment:
   Why not compare the equal join condition too? In Nereids, NestedLoopJoin 
could do an equal join; currently we don't do that only because the BE cannot support it.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java:
##########
@@ -145,12 +183,16 @@ public boolean equals(Object o) {
             return false;
         }
         LogicalJoin that = (LogicalJoin) o;
-        return joinType == that.joinType && Objects.equals(condition, 
that.condition);
+
+        return joinType == that.joinType
+                && that.getHashJoinPredicates().containsAll(hashJoinPredicates)
+                && hashJoinPredicates.containsAll(that.getHashJoinPredicates())
+                && Objects.equals(otherJoinCondition, that.otherJoinCondition);

Review Comment:
   Add a TODO to use a semantic comparison rather than calling equals directly.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -51,14 +55,14 @@ public static boolean onlyShuffle(AbstractPhysicalJoin 
join) {
      */
     public static List<EqualTo> getEqualTo(AbstractPhysicalJoin<Plan, Plan> 
join) {
         List<EqualTo> eqConjuncts = Lists.newArrayList();
-        if (!join.getCondition().isPresent()) {
+        if (!join.getOtherJoinCondition().isPresent()) {
             return eqConjuncts;
         }
 
         List<SlotReference> leftSlots = 
Utils.getOutputSlotReference(join.left());
         List<SlotReference> rightSlots = 
Utils.getOutputSlotReference(join.right());
 
-        Expression onCondition = join.getCondition().get();
+        Expression onCondition = join.getOtherJoinCondition().get();

Review Comment:
   ditto



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> 
leftSlots, List<SlotReferen
                 || (leftSlotsSet.containsAll(rightUsed) && 
rightSlotsSet.containsAll(leftUsed));
     }
 
+    private static class JoinSlotCoverageChecker {
+        HashSet<SlotReference> left;
+        HashSet<ExprId> leftExprIds;
+        HashSet<SlotReference> right;
+        HashSet<ExprId> rightExprIds;
+
+        JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> 
right) {
+            this.left = new HashSet<>(left);
+            leftExprIds = new 
HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList()));

Review Comment:
   You could collect to a set directly.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> 
leftSlots, List<SlotReferen
                 || (leftSlotsSet.containsAll(rightUsed) && 
rightSlotsSet.containsAll(leftUsed));
     }
 
+    private static class JoinSlotCoverageChecker {
+        HashSet<SlotReference> left;
+        HashSet<ExprId> leftExprIds;
+        HashSet<SlotReference> right;
+        HashSet<ExprId> rightExprIds;
+
+        JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> 
right) {
+            this.left = new HashSet<>(left);
+            leftExprIds = new 
HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+            this.right = new HashSet<>(right);
+            rightExprIds = new 
HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+        }
+
+        boolean isCoveredByLeftSlots(List<SlotReference> slots) {
+            boolean covered = left.containsAll(slots);
+            if (covered) {
+                return true;
+            }
+            List<ExprId> slotsExprIds = slots.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());

Review Comment:
   don't need to collect to list, use `allMatch(leftExprIds::contains)` instead



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -51,14 +55,14 @@ public static boolean onlyShuffle(AbstractPhysicalJoin 
join) {
      */
     public static List<EqualTo> getEqualTo(AbstractPhysicalJoin<Plan, Plan> 
join) {
         List<EqualTo> eqConjuncts = Lists.newArrayList();
-        if (!join.getCondition().isPresent()) {
+        if (!join.getOtherJoinCondition().isPresent()) {

Review Comment:
   We need to return all equal conditions here, so we should also check 
getHashPredicates here.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -124,6 +240,7 @@ public static Pair<List<SlotReference>, 
List<SlotReference>> getOnClauseUsedSlot
 
     public static boolean shouldNestedLoopJoin(Join join) {
         JoinType joinType = join.getJoinType();
-        return (joinType.isInnerJoin() && !join.getCondition().isPresent()) || 
joinType.isCrossJoin();
+        //return (joinType.isInnerJoin() && 
!join.getOnClauseCondition().isPresent()) || joinType.isCrossJoin();

Review Comment:
   please remove useless comment



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> 
leftSlots, List<SlotReferen
                 || (leftSlotsSet.containsAll(rightUsed) && 
rightSlotsSet.containsAll(leftUsed));
     }
 
+    private static class JoinSlotCoverageChecker {
+        HashSet<SlotReference> left;
+        HashSet<ExprId> leftExprIds;
+        HashSet<SlotReference> right;
+        HashSet<ExprId> rightExprIds;
+
+        JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> 
right) {
+            this.left = new HashSet<>(left);
+            leftExprIds = new 
HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+            this.right = new HashSet<>(right);
+            rightExprIds = new 
HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+        }
+
+        boolean isCoveredByLeftSlots(List<SlotReference> slots) {
+            boolean covered = left.containsAll(slots);
+            if (covered) {
+                return true;
+            }
+            List<ExprId> slotsExprIds = slots.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+            return leftExprIds.containsAll(slotsExprIds);
+        }
+
+        boolean isCoveredByRightSlots(List<SlotReference> slots) {
+            boolean covered = right.containsAll(slots);
+            if (covered) {
+                return true;
+            }
+            List<ExprId> slotsExprIds = slots.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+            return rightExprIds.containsAll(slotsExprIds);
+        }
+
+        /**
+         *  consider following cases:
+         *  1# A=1 => not for hash table
+         *  2# t1.a=t2.a + t2.b => hash table
+         *  3# t1.a=t1.a + t2.b => not for hash table
+         *  4# t1.a=t2.a or t1.b=t2.b not for hash table
+         *  5# t1.a > 1 not for hash table
+         * @param equalTo a conjunct in on clause condition
+         * @return true if the equal can be used as hash join condition
+         */
+        boolean isHashJoinCondition(EqualTo equalTo) {
+            List<SlotReference> equalLeft =  
equalTo.left().collect(SlotReference.class::isInstance);
+            if (equalLeft.isEmpty()) {
+                return false;
+            }
+
+            List<SlotReference> equalRight = 
equalTo.right().collect(SlotReference.class::isInstance);
+            if (equalRight.isEmpty()) {
+                return false;
+            }
+
+            List<ExprId> equalLeftExprIds = equalLeft.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+
+            List<ExprId> equalRightExprIds = equalRight.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+            return leftExprIds.containsAll(equalLeftExprIds) && 
rightExprIds.containsAll(equalRightExprIds)
+                    || left.containsAll(equalLeft) && 
right.containsAll(equalRight)
+                    || leftExprIds.containsAll(equalRightExprIds) && 
rightExprIds.containsAll(equalLeftExprIds)
+                    || right.containsAll(equalLeft) && 
left.containsAll(equalRight);
+        }
+    }
+
+    /**
+     * collect expressions from on clause, which could be used to build hash 
table
+     * @param join join node
+     * @return pair of expressions, for hash table or not.
+     */
+    public static Pair<List<Expression>, List<Expression>> 
extractExpressionForHashTable(LogicalJoin join,
+            Optional<Expression> onConditions) {
+        if (onConditions.isPresent()) {
+            List<Expression> onExprs = 
ExpressionUtils.extractConjunction(onConditions.get());
+            List<SlotReference> leftSlots = 
Utils.getOutputSlotReference((Plan) (join.left()));
+            List<SlotReference> rightSlots = 
Utils.getOutputSlotReference((Plan) (join.right()));

Review Comment:
   Override the left and right methods in the BinaryPlan interface; then you 
could remove the type cast here.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java:
##########
@@ -51,11 +54,27 @@ public class MultiJoin extends PlanVisitor<Void, Void> {
      *    A      B
      */
     public final List<Plan> joinInputs = new ArrayList<>();
-    public final List<Expression> conjuncts = new ArrayList<>();
+    public final List<Expression> conjunctsForAllHashJoins = new ArrayList<>();
 
+    private List<Expression> conjunctsKeepInFilter = new ArrayList<>();

Review Comment:
   final



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -98,19 +213,20 @@ public static Pair<List<SlotReference>, 
List<SlotReference>> getOnClauseUsedSlot
 
         List<SlotReference> leftSlots = 
Utils.getOutputSlotReference(join.left());
         List<SlotReference> rightSlots = 
Utils.getOutputSlotReference(join.right());
-        List<EqualTo> equalToList = getEqualTo(join);
-
+        List<EqualTo> equalToList = join.getHashJoinPredicates().stream()
+                .map(e -> (EqualTo) e).collect(Collectors.toList());
+        JoinSlotCoverageChecker checker = new 
JoinSlotCoverageChecker(leftSlots, rightSlots);
         for (EqualTo equalTo : equalToList) {
             List<SlotReference> leftOnSlots = 
equalTo.left().collect(SlotReference.class::isInstance);
             List<SlotReference> rightOnSlots = 
equalTo.right().collect(SlotReference.class::isInstance);
 
-            if (new HashSet<>(leftSlots).containsAll(leftOnSlots)
-                    && new HashSet<>(rightSlots).containsAll(rightOnSlots)) {
+            if (checker.isCoveredByLeftSlots(leftOnSlots)
+                    && checker.isCoveredByRightSlots(rightOnSlots)) {
                 // TODO: need rethink about `.get(0)`
                 childSlots.first.add(leftOnSlots.get(0));

Review Comment:
   we need addAll



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java:
##########
@@ -74,25 +93,30 @@ public Plan reorderJoinsAccordingToConditions() {
      */
     private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, 
List<Expression> conjuncts) {
         if (joinInputs.size() == 2) {
-            Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), 
joinInputs.get(1));
-            Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput);
-            List<Expression> joinConditions = split.get(true);
-            List<Expression> nonJoinConditions = split.get(false);
-
+            //Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), 
joinInputs.get(1));
+            //Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput);
+            //List<Expression> joinConditions = split.get(true);
+            //List<Expression> nonJoinConditions = split.get(false);
+            Pair<List<Expression>, List<Expression>> pair = 
JoinUtils.extractExpressionForHashTable(
+                    
joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
+                    
joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
+                    conjuncts);
+            List<Expression> joinConditions = pair.first;
+            conjunctsKeepInFilter = pair.second;

Review Comment:
   ```suggestion
               conjunctsKeepInFilter.addAll(pair.second);
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java:
##########
@@ -74,25 +93,30 @@ public Plan reorderJoinsAccordingToConditions() {
      */
     private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, 
List<Expression> conjuncts) {
         if (joinInputs.size() == 2) {
-            Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), 
joinInputs.get(1));
-            Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput);
-            List<Expression> joinConditions = split.get(true);
-            List<Expression> nonJoinConditions = split.get(false);
-
+            //Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), 
joinInputs.get(1));
+            //Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput);
+            //List<Expression> joinConditions = split.get(true);
+            //List<Expression> nonJoinConditions = split.get(false);
+            Pair<List<Expression>, List<Expression>> pair = 
JoinUtils.extractExpressionForHashTable(
+                    
joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
+                    
joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
+                    conjuncts);
+            List<Expression> joinConditions = pair.first;
+            conjunctsKeepInFilter = pair.second;
             LogicalJoin join;
             if (joinConditions.isEmpty()) {
-                join = new LogicalJoin(JoinType.CROSS_JOIN, Optional.empty(), 
joinInputs.get(0), joinInputs.get(1));
+                join = new LogicalJoin(JoinType.CROSS_JOIN,
+                        new ArrayList<>(),
+                        Optional.empty(),
+                        joinInputs.get(0), joinInputs.get(1));
             } else {
                 join = new LogicalJoin(JoinType.INNER_JOIN,
+                        new ArrayList<>(),
                         Optional.of(ExpressionUtils.and(joinConditions)),

Review Comment:
   Why add these to the other condition?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java:
##########
@@ -165,6 +166,7 @@ private Pair<List<NamedExpression>, List<NamedExpression>> 
getProjectExprs() {
     private LogicalJoin<GroupPlan, GroupPlan> newBottomJoin() {
         return new LogicalJoin(
                 bottomJoin.getJoinType(),
+                bottomJoin.getHashJoinPredicates(),

Review Comment:
   We need to generate the appropriate `hashJoinPredicates` according to the 
new join's child nodes, just like `newBottomJoinOnCondition` does.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java:
##########
@@ -185,7 +219,11 @@ public Void visitLogicalJoin(LogicalJoin<Plan, Plan> join, 
Void context) {
         join.left().accept(this, context);
         join.right().accept(this, context);
 
-        join.getCondition().ifPresent(cond -> 
conjuncts.addAll(ExpressionUtils.extractConjunction(cond)));
+        conjunctsForAllHashJoins.addAll(join.getHashJoinPredicates());
+        if (join.getOtherJoinCondition().isPresent()) {
+            
conjunctsForAllHashJoins.addAll(ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get()));

Review Comment:
   this name is not very good since it could have other conditions



##########
fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java:
##########
@@ -190,8 +191,8 @@ public void pushDownPredicateIntoScanTest4() {
         Expression whereCondition = ExpressionUtils.and(whereCondition1, 
whereCondition2, whereCondition3,
                 whereCondition4);
 
-        Plan join = new LogicalJoin(JoinType.INNER_JOIN, Optional.empty(), 
rStudent, rScore);
-        Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, Optional.empty(), 
join, rCourse);
+        Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), 
Optional.empty(), rStudent, rScore);
+        Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), 
Optional.empty(), join, rCourse);

Review Comment:
   use immutableList



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java:
##########
@@ -106,15 +107,25 @@ private class JoinExpressionRewrite extends 
OneRewriteRuleFactory {
         @Override
         public Rule build() {
             return logicalJoin().then(join -> {
-                Optional<Expression> condition = join.getCondition();
-                if (!condition.isPresent()) {
+                List<Expression> hashJoinPredicates = 
join.getHashJoinPredicates();
+                Optional<Expression> otherJoinCondition = 
join.getOtherJoinCondition();
+                if (!otherJoinCondition.isPresent() && 
hashJoinPredicates.isEmpty()) {
                     return join;
                 }
-                Expression newCondition = rewriter.rewrite(condition.get());
-                if (newCondition.equals(condition.get())) {
+                List<Expression> rewriteHashJoinPredicates = new ArrayList<>();
+                boolean joinPredicatesChanged = false;
+                for (Expression expr : hashJoinPredicates) {
+                    Expression newExpr = rewriter.rewrite(expr);
+                    joinPredicatesChanged = joinPredicatesChanged || 
newExpr.equals(expr);

Review Comment:
   ```suggestion
                       joinPredicatesChanged = joinPredicatesChanged || 
!newExpr.equals(expr);
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java:
##########
@@ -124,20 +148,27 @@ private Plan reorderJoinsAccordingToConditions(List<Plan> 
joinInputs, List<Expre
         }).findFirst();
 
         Plan right = rightOpt.orElseGet(() -> candidate.get(1));
-        Set<Slot> joinOutput = getJoinOutput(left, right);
-        Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput);
-        List<Expression> joinConditions = split.get(true);
-        List<Expression> nonJoinConditions = split.get(false);
-
-        Optional<Expression> cond;
+        Pair<List<Expression>, List<Expression>> pair = 
JoinUtils.extractExpressionForHashTable(
+                
joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
+                
joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()),
+                conjuncts);
+        //Set<Slot> joinOutput = getJoinOutput(left, right);
+        //Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput);
+        //List<Expression> joinConditions = split.get(true);
+        //List<Expression> nonJoinConditions = split.get(false);
+        List<Expression> joinConditions = pair.first;
+        List<Expression> nonJoinConditions = pair.second;

Review Comment:
   we lose nonJoinConditions?



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java:
##########
@@ -106,15 +107,25 @@ private class JoinExpressionRewrite extends 
OneRewriteRuleFactory {
         @Override
         public Rule build() {
             return logicalJoin().then(join -> {
-                Optional<Expression> condition = join.getCondition();
-                if (!condition.isPresent()) {
+                List<Expression> hashJoinPredicates = 
join.getHashJoinPredicates();
+                Optional<Expression> otherJoinCondition = 
join.getOtherJoinCondition();
+                if (!otherJoinCondition.isPresent() && 
hashJoinPredicates.isEmpty()) {

Review Comment:
   I think we can let the join override the getExpressions interface and 
return all expressions in one method.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java:
##########
@@ -51,11 +54,27 @@ public class MultiJoin extends PlanVisitor<Void, Void> {
      *    A      B
      */
     public final List<Plan> joinInputs = new ArrayList<>();
-    public final List<Expression> conjuncts = new ArrayList<>();
+    public final List<Expression> conjunctsForAllHashJoins = new ArrayList<>();
 
+    private List<Expression> conjunctsKeepInFilter = new ArrayList<>();
+
+    /**
+     * reorderJoinsAccordingToConditions
+     * @return join or filter
+     */
     public Plan reorderJoinsAccordingToConditions() {
         Preconditions.checkArgument(joinInputs.size() >= 2);
-        return reorderJoinsAccordingToConditions(joinInputs, conjuncts);
+        Plan joinRoot = reorderJoinsAccordingToConditions(joinInputs, 
conjunctsForAllHashJoins);
+        if (!conjunctsKeepInFilter.isEmpty()) {
+            LogicalFilter filter = new LogicalFilter(
+                    ExpressionUtils.and(conjunctsKeepInFilter),
+                    joinRoot
+            );
+            return filter;
+        } else {
+            return joinRoot;
+        }
+

Review Comment:
   ```suggestion
           Plan root = reorderJoinsAccordingToConditions(joinInputs, 
conjunctsForAllHashJoins);
           if (!conjunctsKeepInFilter.isEmpty()) {
               root = new 
LogicalFilter(ExpressionUtils.and(conjunctsKeepInFilter), root);
           }
           return root;
   
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java:
##########
@@ -54,17 +58,21 @@
      * @param joinType logical type for join
      */
     public LogicalJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, 
RIGHT_CHILD_TYPE rightChild) {
-        this(joinType, Optional.empty(), Optional.empty(), Optional.empty(), 
leftChild, rightChild);
+        this(joinType, new ArrayList<Expression>(),

Review Comment:
   use ImmutableList instead of ArrayList in all immutable class



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java:
##########
@@ -28,5 +29,9 @@
 public interface Join {
     JoinType getJoinType();
 
-    Optional<Expression> getCondition();
+    List<Expression> getHashJoinPredicates();
+
+    Optional<Expression> getOtherJoinCondition();
+
+    Optional<Expression> getOnClauseCondition();

Review Comment:
   ```suggestion
       List<Expression> getHashConditions();
   
       Optional<Expression> getOtherConditions();
   
       Optional<Expression> getConditions();
   ```



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java:
##########
@@ -165,11 +196,14 @@ public Void visit(Plan plan, Void context) {
         return null;
     }
 
+    //TODO: add a rule to push filter condition down to join if acceptable.
+    // We can not simply add filter predicates into join conditions for 
outer/anti join.
+    // It is better to add another rule to push down acceptable filter 
conditions to join.

Review Comment:
   I think PushPredicateThroughJoin already does that, but we need a rule to
push the join's other conditions down to its children.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> 
leftSlots, List<SlotReferen
                 || (leftSlotsSet.containsAll(rightUsed) && 
rightSlotsSet.containsAll(leftUsed));
     }
 
+    private static class JoinSlotCoverageChecker {
+        HashSet<SlotReference> left;
+        HashSet<ExprId> leftExprIds;
+        HashSet<SlotReference> right;
+        HashSet<ExprId> rightExprIds;
+
+        JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> 
right) {
+            this.left = new HashSet<>(left);
+            leftExprIds = new 
HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+            this.right = new HashSet<>(right);
+            rightExprIds = new 
HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+        }
+
+        boolean isCoveredByLeftSlots(List<SlotReference> slots) {
+            boolean covered = left.containsAll(slots);

Review Comment:
   I think just using ExprId is better.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> 
leftSlots, List<SlotReferen
                 || (leftSlotsSet.containsAll(rightUsed) && 
rightSlotsSet.containsAll(leftUsed));
     }
 
+    private static class JoinSlotCoverageChecker {
+        HashSet<SlotReference> left;
+        HashSet<ExprId> leftExprIds;
+        HashSet<SlotReference> right;
+        HashSet<ExprId> rightExprIds;
+
+        JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> 
right) {
+            this.left = new HashSet<>(left);
+            leftExprIds = new 
HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+            this.right = new HashSet<>(right);
+            rightExprIds = new 
HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+        }
+
+        boolean isCoveredByLeftSlots(List<SlotReference> slots) {
+            boolean covered = left.containsAll(slots);
+            if (covered) {
+                return true;
+            }
+            List<ExprId> slotsExprIds = slots.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+            return leftExprIds.containsAll(slotsExprIds);
+        }
+
+        boolean isCoveredByRightSlots(List<SlotReference> slots) {
+            boolean covered = right.containsAll(slots);
+            if (covered) {
+                return true;
+            }
+            List<ExprId> slotsExprIds = slots.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+            return rightExprIds.containsAll(slotsExprIds);

Review Comment:
   Ditto: I think just using ExprId is better here as well.



##########
fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java:
##########
@@ -98,14 +99,18 @@ public void testLogicalJoin(@Mocked Plan left, @Mocked Plan 
right) {
                 new SlotReference("a", new BigIntType(), true, 
Lists.newArrayList()),
                 new SlotReference("b", new BigIntType(), true, 
Lists.newArrayList())
         );
-        LogicalJoin innerJoin1 = new LogicalJoin(JoinType.INNER_JOIN, 
Optional.of(condition1), left, right);
+        LogicalJoin innerJoin1 = new LogicalJoin(JoinType.INNER_JOIN, 
Lists.newArrayList(condition1),
+                Optional.empty(), left, right);
+        boolean bool = innerJoin1.equals(innerJoin);
+        System.out.println(bool);

Review Comment:
   remove useless print



##########
fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java:
##########
@@ -73,9 +73,11 @@ public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom(
          */
         Assertions.assertEquals(3, scans.size());
         LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new 
LogicalJoin<>(JoinType.INNER_JOIN,
-                Optional.of(bottomJoinOnCondition), scans.get(0), 
scans.get(1));
+                Lists.newArrayList(bottomJoinOnCondition),
+                Optional.empty(), scans.get(0), scans.get(1));
         LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, 
LogicalOlapScan> topJoin = new LogicalJoin<>(
-                JoinType.INNER_JOIN, Optional.of(topJoinOnCondition), 
bottomJoin, scans.get(2));
+                JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition),
+                Optional.empty(), bottomJoin, scans.get(2));

Review Comment:
   @jackwener please add an assertion for the join's equal condition in another PR.



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java:
##########
@@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> 
leftSlots, List<SlotReferen
                 || (leftSlotsSet.containsAll(rightUsed) && 
rightSlotsSet.containsAll(leftUsed));
     }
 
+    private static class JoinSlotCoverageChecker {
+        HashSet<SlotReference> left;
+        HashSet<ExprId> leftExprIds;
+        HashSet<SlotReference> right;
+        HashSet<ExprId> rightExprIds;
+
+        JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> 
right) {
+            this.left = new HashSet<>(left);
+            leftExprIds = new 
HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+            this.right = new HashSet<>(right);
+            rightExprIds = new 
HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList()));
+        }
+
+        boolean isCoveredByLeftSlots(List<SlotReference> slots) {
+            boolean covered = left.containsAll(slots);
+            if (covered) {
+                return true;
+            }
+            List<ExprId> slotsExprIds = slots.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+            return leftExprIds.containsAll(slotsExprIds);
+        }
+
+        boolean isCoveredByRightSlots(List<SlotReference> slots) {
+            boolean covered = right.containsAll(slots);
+            if (covered) {
+                return true;
+            }
+            List<ExprId> slotsExprIds = slots.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+            return rightExprIds.containsAll(slotsExprIds);
+        }
+
+        /**
+         *  consider following cases:
+         *  1# A=1 => not for hash table
+         *  2# t1.a=t2.a + t2.b => hash table
+         *  3# t1.a=t1.a + t2.b => not for hash table
+         *  4# t1.a=t2.a or t1.b=t2.b not for hash table
+         *  5# t1.a > 1 not for hash table
+         * @param equalTo a conjunct in on clause condition
+         * @return true if the equal can be used as hash join condition
+         */
+        boolean isHashJoinCondition(EqualTo equalTo) {
+            List<SlotReference> equalLeft =  
equalTo.left().collect(SlotReference.class::isInstance);
+            if (equalLeft.isEmpty()) {
+                return false;
+            }
+
+            List<SlotReference> equalRight = 
equalTo.right().collect(SlotReference.class::isInstance);
+            if (equalRight.isEmpty()) {
+                return false;
+            }
+
+            List<ExprId> equalLeftExprIds = equalLeft.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+
+            List<ExprId> equalRightExprIds = equalRight.stream()
+                    
.map(SlotReference::getExprId).collect(Collectors.toList());
+            return leftExprIds.containsAll(equalLeftExprIds) && 
rightExprIds.containsAll(equalRightExprIds)
+                    || left.containsAll(equalLeft) && 
right.containsAll(equalRight)
+                    || leftExprIds.containsAll(equalRightExprIds) && 
rightExprIds.containsAll(equalLeftExprIds)
+                    || right.containsAll(equalLeft) && 
left.containsAll(equalRight);
+        }
+    }
+
+    /**
+     * collect expressions from on clause, which could be used to build hash 
table
+     * @param join join node
+     * @return pair of expressions, for hash table or not.
+     */
+    public static Pair<List<Expression>, List<Expression>> 
extractExpressionForHashTable(LogicalJoin join,
+            Optional<Expression> onConditions) {
+        if (onConditions.isPresent()) {
+            List<Expression> onExprs = 
ExpressionUtils.extractConjunction(onConditions.get());
+            List<SlotReference> leftSlots = 
Utils.getOutputSlotReference((Plan) (join.left()));
+            List<SlotReference> rightSlots = 
Utils.getOutputSlotReference((Plan) (join.right()));
+            return extractExpressionForHashTable(leftSlots, rightSlots, 
onExprs);
+        }
+        return new Pair<>(Lists.newArrayList(), Lists.newArrayList());
+    }
+
+    /**
+     * extract expression
+     * @param leftSlots left child output slots
+     * @param rightSlots right child output slots
+     * @param onConditions conditions to be split
+     * @return pair of hashCondition and otherCondition
+     */
+    public static Pair<List<Expression>, List<Expression>> 
extractExpressionForHashTable(List<SlotReference> leftSlots,
+            List<SlotReference> rightSlots,
+            List<Expression> onConditions) {
+
+        Pair<List<Expression>, List<Expression>> pair = new 
Pair<>(Lists.newArrayList(), Lists.newArrayList());

Review Comment:
   Using a stream group-by may be a better choice.



##########
fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java:
##########
@@ -113,9 +113,11 @@ public void check(Plan plan) {
 
             // check join conditions
             List<String> actualJoinConditions = joins.stream().map(j -> {
-                Optional<Expression> condition = j.getCondition();
+

Review Comment:
   useless blank line



##########
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java:
##########
@@ -74,25 +93,30 @@ public Plan reorderJoinsAccordingToConditions() {
      */
     private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, 
List<Expression> conjuncts) {
         if (joinInputs.size() == 2) {
-            Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), 
joinInputs.get(1));
-            Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput);
-            List<Expression> joinConditions = split.get(true);
-            List<Expression> nonJoinConditions = split.get(false);
-
+            //Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), 
joinInputs.get(1));
+            //Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput);
+            //List<Expression> joinConditions = split.get(true);
+            //List<Expression> nonJoinConditions = split.get(false);

Review Comment:
   remove it



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to