morrySnow commented on code in PR #11812: URL: https://github.com/apache/doris/pull/11812#discussion_r949340541
########## fe/fe-core/src/main/java/org/apache/doris/nereids/NereidsPlanner.java: ########## @@ -147,6 +148,7 @@ private void analyze() { private void rewrite() { new NormalizeExpressionRulesJob(cascadesContext).execute(); new JoinReorderRulesJob(cascadesContext).execute(); + new FindHashConditionForJoinJob(cascadesContext).execute(); Review Comment: i think we should execute this rule with predicate push down rule in same time, so we should add this rule to PredicatePushDownRulesJob ########## fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java: ########## @@ -369,13 +368,16 @@ public PlanFragment visitPhysicalHashJoin(PhysicalHashJoin<Plan, Plan> hashJoin, if (JoinUtils.shouldNestedLoopJoin(hashJoin)) { throw new RuntimeException("Physical hash join could not execute without equal join condition."); } else { - Expression eqJoinExpression = hashJoin.getCondition().get(); - List<Expr> execEqConjunctList = ExpressionUtils.extractConjunction(eqJoinExpression).stream() - .map(EqualTo.class::cast) - .map(e -> swapEqualToForChildrenOrder(e, hashJoin.left().getOutput())) - .map(e -> ExpressionTranslator.translate(e, context)) - .collect(Collectors.toList()); - + //TODO: after we apply rule FindHashConditionForJoin, + // we could get execEqConjunctList by hashjoin.getJoinPredicates() directly Review Comment: we need to fix this TODO in this PR, and add other conjuncts into join node too ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/analysis/BindSlotReference.java: ########## @@ -92,9 +93,10 @@ public List<Rule> buildRules() { RuleType.BINDING_JOIN_SLOT.build( logicalJoin().thenApply(ctx -> { LogicalJoin<GroupPlan, GroupPlan> join = ctx.root; - Optional<Expression> cond = join.getCondition() + Optional<Expression> cond = join.getOtherJoinCondition() .map(expr -> bind(expr, join.children(), join, ctx.cascadesContext)); - return new LogicalJoin<>(join.getJoinType(), cond, join.left(), join.right()); + return new 
LogicalJoin<>(join.getJoinType(), + new ArrayList<Expression>(), cond, join.left(), join.right()); Review Comment: ditto ########## fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java: ########## @@ -630,6 +630,7 @@ public LogicalPlan visitFromClause(FromClauseContext ctx) { left = (left == null) ? right : new LogicalJoin<>( JoinType.CROSS_JOIN, + new ArrayList<>(), Review Comment: use ImmutableList.of() or Collections.emptyList() instead ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java: ########## @@ -206,6 +209,7 @@ public LogicalJoin newTopJoin() { return new LogicalJoin( topJoin.getJoinType(), + topJoin.getHashJoinPredicates(), Review Comment: ditto ########## fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java: ########## @@ -814,7 +815,8 @@ private LogicalPlan withJoinRelations(LogicalPlan input, RelationContext ctx) { condition = getExpression(joinCriteria.booleanExpression()); } - last = new LogicalJoin<>(joinType, Optional.ofNullable(condition), last, plan(join.relationPrimary())); + last = new LogicalJoin<>(joinType, new ArrayList<Expression>(), Review Comment: ditto ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java: ########## @@ -191,6 +193,7 @@ public LogicalJoin newProjectTopJoin() { return new LogicalJoin<>( topJoin.getJoinType(), + topJoin.getHashJoinPredicates(), Review Comment: ditto ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java: ########## @@ -106,15 +107,25 @@ private class JoinExpressionRewrite extends OneRewriteRuleFactory { @Override public Rule build() { return logicalJoin().then(join -> { - Optional<Expression> condition = join.getCondition(); - if (!condition.isPresent()) { + List<Expression> hashJoinPredicates = join.getHashJoinPredicates(); + Optional<Expression> otherJoinCondition = 
join.getOtherJoinCondition(); + if (!otherJoinCondition.isPresent() && hashJoinPredicates.isEmpty()) { return join; } - Expression newCondition = rewriter.rewrite(condition.get()); - if (newCondition.equals(condition.get())) { + List<Expression> rewriteHashJoinPredicates = new ArrayList<>(); + boolean joinPredicatesChanged = false; + for (Expression expr : hashJoinPredicates) { + Expression newExpr = rewriter.rewrite(expr); + joinPredicatesChanged = joinPredicatesChanged || newExpr.equals(expr); + rewriteHashJoinPredicates.add(newExpr); + } Review Comment: using the stream API could shorten this code block ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java: ########## @@ -106,15 +107,25 @@ private class JoinExpressionRewrite extends OneRewriteRuleFactory { @Override public Rule build() { return logicalJoin().then(join -> { - Optional<Expression> condition = join.getCondition(); - if (!condition.isPresent()) { + List<Expression> hashJoinPredicates = join.getHashJoinPredicates(); + Optional<Expression> otherJoinCondition = join.getOtherJoinCondition(); + if (!otherJoinCondition.isPresent() && hashJoinPredicates.isEmpty()) { return join; } - Expression newCondition = rewriter.rewrite(condition.get()); - if (newCondition.equals(condition.get())) { + List<Expression> rewriteHashJoinPredicates = new ArrayList<>(); Review Comment: ```suggestion List<Expression> rewriteHashJoinPredicates = Lists.newArrayList(); ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java: ########## @@ -43,7 +45,9 @@ extends LogicalBinary<LEFT_CHILD_TYPE, RIGHT_CHILD_TYPE> implements Join { private final JoinType joinType; - private final Optional<Expression> condition; + private final Optional<Expression> otherJoinCondition; + + private final List<Expression> hashJoinPredicates; Review Comment: we need a uniform name, so use either "condition" or "predicates" consistently.
and BTW, remove blank line L49 or add a new blank line between L47 and L48 ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Expression.java: ########## @@ -111,4 +112,26 @@ public boolean equals(Object o) { public int hashCode() { return 0; } + + /** + * get the conjuct list from expr + * for example: + * a=1 and f(b)=f(c) and (d=e or x=y) => {a=1, f(b)=f(c), (d=e or x=y)}, list size = 3 + * a=1 or b=1 => {a=1 or b=1}, list size = 1. + * @return conjuct list + */ + public List<Expression> getConjucts() { Review Comment: maybe u could use `ExpressionUtils.extractConjunction` directly ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java: ########## @@ -132,7 +167,10 @@ public List<Slot> computeOutput(Plan leftInput, Plan rightInput) { @Override public String toString() { StringBuilder sb = new StringBuilder("LogicalJoin (").append(joinType); - condition.ifPresent(expression -> sb.append(", ").append(expression)); + sb.append(" ["); + hashJoinPredicates.stream().map(expr -> sb.append(" ").append(expr)).collect(Collectors.toList()); + sb.append(" ]"); + otherJoinCondition.ifPresent(expression -> sb.append(", ").append(expression)); Review Comment: we could print them separately for easy reading ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/physical/AbstractPhysicalJoin.java: ########## @@ -78,11 +90,27 @@ public boolean equals(Object o) { return false; } AbstractPhysicalJoin that = (AbstractPhysicalJoin) o; - return joinType == that.joinType && Objects.equals(condition, that.condition); + return joinType == that.joinType && Objects.equals(otherJoinCondition, that.otherJoinCondition); Review Comment: why not compare equal join condition? 
in Nereids NestedLoopJoin could do equal join, currently we don't do that just because be cannot support it ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java: ########## @@ -145,12 +183,16 @@ public boolean equals(Object o) { return false; } LogicalJoin that = (LogicalJoin) o; - return joinType == that.joinType && Objects.equals(condition, that.condition); + + return joinType == that.joinType + && that.getHashJoinPredicates().containsAll(hashJoinPredicates) + && hashJoinPredicates.containsAll(that.getHashJoinPredicates()) + && Objects.equals(otherJoinCondition, that.otherJoinCondition); Review Comment: add a todo to use semantic rather than equals directly ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -51,14 +55,14 @@ public static boolean onlyShuffle(AbstractPhysicalJoin join) { */ public static List<EqualTo> getEqualTo(AbstractPhysicalJoin<Plan, Plan> join) { List<EqualTo> eqConjuncts = Lists.newArrayList(); - if (!join.getCondition().isPresent()) { + if (!join.getOtherJoinCondition().isPresent()) { return eqConjuncts; } List<SlotReference> leftSlots = Utils.getOutputSlotReference(join.left()); List<SlotReference> rightSlots = Utils.getOutputSlotReference(join.right()); - Expression onCondition = join.getCondition().get(); + Expression onCondition = join.getOtherJoinCondition().get(); Review Comment: ditto ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> leftSlots, List<SlotReferen || (leftSlotsSet.containsAll(rightUsed) && rightSlotsSet.containsAll(leftUsed)); } + private static class JoinSlotCoverageChecker { + HashSet<SlotReference> left; + HashSet<ExprId> leftExprIds; + HashSet<SlotReference> right; + HashSet<ExprId> rightExprIds; + + JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> right) { + this.left = new HashSet<>(left); + 
leftExprIds = new HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList())); Review Comment: u could collect to set directly ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> leftSlots, List<SlotReferen || (leftSlotsSet.containsAll(rightUsed) && rightSlotsSet.containsAll(leftUsed)); } + private static class JoinSlotCoverageChecker { + HashSet<SlotReference> left; + HashSet<ExprId> leftExprIds; + HashSet<SlotReference> right; + HashSet<ExprId> rightExprIds; + + JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> right) { + this.left = new HashSet<>(left); + leftExprIds = new HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + this.right = new HashSet<>(right); + rightExprIds = new HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + } + + boolean isCoveredByLeftSlots(List<SlotReference> slots) { + boolean covered = left.containsAll(slots); + if (covered) { + return true; + } + List<ExprId> slotsExprIds = slots.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); Review Comment: don't need to collect to list, use `allMatch(leftExprIds::contains)` instead ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -51,14 +55,14 @@ public static boolean onlyShuffle(AbstractPhysicalJoin join) { */ public static List<EqualTo> getEqualTo(AbstractPhysicalJoin<Plan, Plan> join) { List<EqualTo> eqConjuncts = Lists.newArrayList(); - if (!join.getCondition().isPresent()) { + if (!join.getOtherJoinCondition().isPresent()) { Review Comment: we need to return all equal conditions here, so we should to check getHashPredicates here ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -124,6 +240,7 @@ public static Pair<List<SlotReference>, List<SlotReference>> 
getOnClauseUsedSlot public static boolean shouldNestedLoopJoin(Join join) { JoinType joinType = join.getJoinType(); - return (joinType.isInnerJoin() && !join.getCondition().isPresent()) || joinType.isCrossJoin(); + //return (joinType.isInnerJoin() && !join.getOnClauseCondition().isPresent()) || joinType.isCrossJoin(); Review Comment: please remove useless comment ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> leftSlots, List<SlotReferen || (leftSlotsSet.containsAll(rightUsed) && rightSlotsSet.containsAll(leftUsed)); } + private static class JoinSlotCoverageChecker { + HashSet<SlotReference> left; + HashSet<ExprId> leftExprIds; + HashSet<SlotReference> right; + HashSet<ExprId> rightExprIds; + + JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> right) { + this.left = new HashSet<>(left); + leftExprIds = new HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + this.right = new HashSet<>(right); + rightExprIds = new HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + } + + boolean isCoveredByLeftSlots(List<SlotReference> slots) { + boolean covered = left.containsAll(slots); + if (covered) { + return true; + } + List<ExprId> slotsExprIds = slots.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + return leftExprIds.containsAll(slotsExprIds); + } + + boolean isCoveredByRightSlots(List<SlotReference> slots) { + boolean covered = right.containsAll(slots); + if (covered) { + return true; + } + List<ExprId> slotsExprIds = slots.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + return rightExprIds.containsAll(slotsExprIds); + } + + /** + * consider following cases: + * 1# A=1 => not for hash table + * 2# t1.a=t2.a + t2.b => hash table + * 3# t1.a=t1.a + t2.b => not for hash table + * 4# t1.a=t2.a or t1.b=t2.b not for hash 
table + * 5# t1.a > 1 not for hash table + * @param equalTo a conjunct in on clause condition + * @return true if the equal can be used as hash join condition + */ + boolean isHashJoinCondition(EqualTo equalTo) { + List<SlotReference> equalLeft = equalTo.left().collect(SlotReference.class::isInstance); + if (equalLeft.isEmpty()) { + return false; + } + + List<SlotReference> equalRight = equalTo.right().collect(SlotReference.class::isInstance); + if (equalRight.isEmpty()) { + return false; + } + + List<ExprId> equalLeftExprIds = equalLeft.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + + List<ExprId> equalRightExprIds = equalRight.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + return leftExprIds.containsAll(equalLeftExprIds) && rightExprIds.containsAll(equalRightExprIds) + || left.containsAll(equalLeft) && right.containsAll(equalRight) + || leftExprIds.containsAll(equalRightExprIds) && rightExprIds.containsAll(equalLeftExprIds) + || right.containsAll(equalLeft) && left.containsAll(equalRight); + } + } + + /** + * collect expressions from on clause, which could be used to build hash table + * @param join join node + * @return pair of expressions, for hash table or not. 
+ */ + public static Pair<List<Expression>, List<Expression>> extractExpressionForHashTable(LogicalJoin join, + Optional<Expression> onConditions) { + if (onConditions.isPresent()) { + List<Expression> onExprs = ExpressionUtils.extractConjunction(onConditions.get()); + List<SlotReference> leftSlots = Utils.getOutputSlotReference((Plan) (join.left())); + List<SlotReference> rightSlots = Utils.getOutputSlotReference((Plan) (join.right())); Review Comment: Override left and right method in interface BinaryPlan and then u could remove type cast here ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java: ########## @@ -51,11 +54,27 @@ public class MultiJoin extends PlanVisitor<Void, Void> { * A B */ public final List<Plan> joinInputs = new ArrayList<>(); - public final List<Expression> conjuncts = new ArrayList<>(); + public final List<Expression> conjunctsForAllHashJoins = new ArrayList<>(); + private List<Expression> conjunctsKeepInFilter = new ArrayList<>(); Review Comment: final ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -98,19 +213,20 @@ public static Pair<List<SlotReference>, List<SlotReference>> getOnClauseUsedSlot List<SlotReference> leftSlots = Utils.getOutputSlotReference(join.left()); List<SlotReference> rightSlots = Utils.getOutputSlotReference(join.right()); - List<EqualTo> equalToList = getEqualTo(join); - + List<EqualTo> equalToList = join.getHashJoinPredicates().stream() + .map(e -> (EqualTo) e).collect(Collectors.toList()); + JoinSlotCoverageChecker checker = new JoinSlotCoverageChecker(leftSlots, rightSlots); for (EqualTo equalTo : equalToList) { List<SlotReference> leftOnSlots = equalTo.left().collect(SlotReference.class::isInstance); List<SlotReference> rightOnSlots = equalTo.right().collect(SlotReference.class::isInstance); - if (new HashSet<>(leftSlots).containsAll(leftOnSlots) - && new HashSet<>(rightSlots).containsAll(rightOnSlots)) { + if 
(checker.isCoveredByLeftSlots(leftOnSlots) + && checker.isCoveredByRightSlots(rightOnSlots)) { // TODO: need rethink about `.get(0)` childSlots.first.add(leftOnSlots.get(0)); Review Comment: we need addAll ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java: ########## @@ -74,25 +93,30 @@ public Plan reorderJoinsAccordingToConditions() { */ private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expression> conjuncts) { if (joinInputs.size() == 2) { - Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1)); - Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput); - List<Expression> joinConditions = split.get(true); - List<Expression> nonJoinConditions = split.get(false); - + //Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1)); + //Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput); + //List<Expression> joinConditions = split.get(true); + //List<Expression> nonJoinConditions = split.get(false); + Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable( + joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), + joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), + conjuncts); + List<Expression> joinConditions = pair.first; + conjunctsKeepInFilter = pair.second; Review Comment: ```suggestion conjunctsKeepInFilter.addAll(pair.second); ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java: ########## @@ -74,25 +93,30 @@ public Plan reorderJoinsAccordingToConditions() { */ private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expression> conjuncts) { if (joinInputs.size() == 2) { - Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1)); - Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, 
joinOutput); - List<Expression> joinConditions = split.get(true); - List<Expression> nonJoinConditions = split.get(false); - + //Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1)); + //Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput); + //List<Expression> joinConditions = split.get(true); + //List<Expression> nonJoinConditions = split.get(false); + Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable( + joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), + joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), + conjuncts); + List<Expression> joinConditions = pair.first; + conjunctsKeepInFilter = pair.second; LogicalJoin join; if (joinConditions.isEmpty()) { - join = new LogicalJoin(JoinType.CROSS_JOIN, Optional.empty(), joinInputs.get(0), joinInputs.get(1)); + join = new LogicalJoin(JoinType.CROSS_JOIN, + new ArrayList<>(), + Optional.empty(), + joinInputs.get(0), joinInputs.get(1)); } else { join = new LogicalJoin(JoinType.INNER_JOIN, + new ArrayList<>(), Optional.of(ExpressionUtils.and(joinConditions)), Review Comment: why add to other conditon? 
########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java: ########## @@ -165,6 +166,7 @@ private Pair<List<NamedExpression>, List<NamedExpression>> getProjectExprs() { private LogicalJoin<GroupPlan, GroupPlan> newBottomJoin() { return new LogicalJoin( bottomJoin.getJoinType(), + bottomJoin.getHashJoinPredicates(), Review Comment: we need to generate appropriate `hashJoinPredicates` according to the new join's children, just like what `newBottomJoinOnCondition` does ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java: ########## @@ -185,7 +219,11 @@ public Void visitLogicalJoin(LogicalJoin<Plan, Plan> join, Void context) { join.left().accept(this, context); join.right().accept(this, context); - join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunction(cond))); + conjunctsForAllHashJoins.addAll(join.getHashJoinPredicates()); + if (join.getOtherJoinCondition().isPresent()) { + conjunctsForAllHashJoins.addAll(ExpressionUtils.extractConjunction(join.getOtherJoinCondition().get())); Review Comment: this name is not very good since it could have other conditions ########## fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java: ########## @@ -190,8 +191,8 @@ public void pushDownPredicateIntoScanTest4() { Expression whereCondition = ExpressionUtils.and(whereCondition1, whereCondition2, whereCondition3, whereCondition4); - Plan join = new LogicalJoin(JoinType.INNER_JOIN, Optional.empty(), rStudent, rScore); - Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, Optional.empty(), join, rCourse); + Plan join = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), Optional.empty(), rStudent, rScore); + Plan join1 = new LogicalJoin(JoinType.INNER_JOIN, new ArrayList<>(), Optional.empty(), join, rCourse); Review Comment: use ImmutableList ########## 
fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java: ########## @@ -106,15 +107,25 @@ private class JoinExpressionRewrite extends OneRewriteRuleFactory { @Override public Rule build() { return logicalJoin().then(join -> { - Optional<Expression> condition = join.getCondition(); - if (!condition.isPresent()) { + List<Expression> hashJoinPredicates = join.getHashJoinPredicates(); + Optional<Expression> otherJoinCondition = join.getOtherJoinCondition(); + if (!otherJoinCondition.isPresent() && hashJoinPredicates.isEmpty()) { return join; } - Expression newCondition = rewriter.rewrite(condition.get()); - if (newCondition.equals(condition.get())) { + List<Expression> rewriteHashJoinPredicates = new ArrayList<>(); + boolean joinPredicatesChanged = false; + for (Expression expr : hashJoinPredicates) { + Expression newExpr = rewriter.rewrite(expr); + joinPredicatesChanged = joinPredicatesChanged || newExpr.equals(expr); Review Comment: ```suggestion joinPredicatesChanged = joinPredicatesChanged || !newExpr.equals(expr); ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java: ########## @@ -124,20 +148,27 @@ private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expre }).findFirst(); Plan right = rightOpt.orElseGet(() -> candidate.get(1)); - Set<Slot> joinOutput = getJoinOutput(left, right); - Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput); - List<Expression> joinConditions = split.get(true); - List<Expression> nonJoinConditions = split.get(false); - - Optional<Expression> cond; + Pair<List<Expression>, List<Expression>> pair = JoinUtils.extractExpressionForHashTable( + joinInputs.get(0).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), + joinInputs.get(1).getOutput().stream().map(SlotReference.class::cast).collect(Collectors.toList()), + conjuncts); + //Set<Slot> joinOutput = getJoinOutput(left, 
right); + //Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput); + //List<Expression> joinConditions = split.get(true); + //List<Expression> nonJoinConditions = split.get(false); + List<Expression> joinConditions = pair.first; + List<Expression> nonJoinConditions = pair.second; Review Comment: we lose nonJoinConditions? ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRewrite.java: ########## @@ -106,15 +107,25 @@ private class JoinExpressionRewrite extends OneRewriteRuleFactory { @Override public Rule build() { return logicalJoin().then(join -> { - Optional<Expression> condition = join.getCondition(); - if (!condition.isPresent()) { + List<Expression> hashJoinPredicates = join.getHashJoinPredicates(); + Optional<Expression> otherJoinCondition = join.getOtherJoinCondition(); + if (!otherJoinCondition.isPresent() && hashJoinPredicates.isEmpty()) { Review Comment: i think we can let join Override getExpressions interface and return all expression in one method ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java: ########## @@ -51,11 +54,27 @@ public class MultiJoin extends PlanVisitor<Void, Void> { * A B */ public final List<Plan> joinInputs = new ArrayList<>(); - public final List<Expression> conjuncts = new ArrayList<>(); + public final List<Expression> conjunctsForAllHashJoins = new ArrayList<>(); + private List<Expression> conjunctsKeepInFilter = new ArrayList<>(); + + /** + * reorderJoinsAccordingToConditions + * @return join or filter + */ public Plan reorderJoinsAccordingToConditions() { Preconditions.checkArgument(joinInputs.size() >= 2); - return reorderJoinsAccordingToConditions(joinInputs, conjuncts); + Plan joinRoot = reorderJoinsAccordingToConditions(joinInputs, conjunctsForAllHashJoins); + if (!conjunctsKeepInFilter.isEmpty()) { + LogicalFilter filter = new LogicalFilter( + ExpressionUtils.and(conjunctsKeepInFilter), + joinRoot + ); + 
return filter; + } else { + return joinRoot; + } + Review Comment: ```suggestion Plan root = reorderJoinsAccordingToConditions(joinInputs, conjunctsForAllHashJoins); if (!conjunctsKeepInFilter.isEmpty()) { root = new LogicalFilter(ExpressionUtils.and(conjunctsKeepInFilter), root); } return root; ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java: ########## @@ -54,17 +58,21 @@ * @param joinType logical type for join */ public LogicalJoin(JoinType joinType, LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild) { - this(joinType, Optional.empty(), Optional.empty(), Optional.empty(), leftChild, rightChild); + this(joinType, new ArrayList<Expression>(), Review Comment: use ImmutableList instead of ArrayList in all immutable class ########## fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/algebra/Join.java: ########## @@ -28,5 +29,9 @@ public interface Join { JoinType getJoinType(); - Optional<Expression> getCondition(); + List<Expression> getHashJoinPredicates(); + + Optional<Expression> getOtherJoinCondition(); + + Optional<Expression> getOnClauseCondition(); Review Comment: ```suggestion List<Expression> getHashConditions(); Optional<Expression> getOtherConditions(); Optional<Expression> getConditions(); ``` ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java: ########## @@ -165,11 +196,14 @@ public Void visit(Plan plan, Void context) { return null; } + //TODO: add a rule to push filter condition down to join if acceptable. + // We can not simply add filter predicates into join conditions for outer/anti join. + // It is better to add another rule to push down acceptable filter conditions to join. 
Review Comment: i think PushPredicateThroughJoin already do that, but we need a rule to push down join other condition to its children ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> leftSlots, List<SlotReferen || (leftSlotsSet.containsAll(rightUsed) && rightSlotsSet.containsAll(leftUsed)); } + private static class JoinSlotCoverageChecker { + HashSet<SlotReference> left; + HashSet<ExprId> leftExprIds; + HashSet<SlotReference> right; + HashSet<ExprId> rightExprIds; + + JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> right) { + this.left = new HashSet<>(left); + leftExprIds = new HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + this.right = new HashSet<>(right); + rightExprIds = new HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + } + + boolean isCoveredByLeftSlots(List<SlotReference> slots) { + boolean covered = left.containsAll(slots); Review Comment: i think just use ExprId is better ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> leftSlots, List<SlotReferen || (leftSlotsSet.containsAll(rightUsed) && rightSlotsSet.containsAll(leftUsed)); } + private static class JoinSlotCoverageChecker { + HashSet<SlotReference> left; + HashSet<ExprId> leftExprIds; + HashSet<SlotReference> right; + HashSet<ExprId> rightExprIds; + + JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> right) { + this.left = new HashSet<>(left); + leftExprIds = new HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + this.right = new HashSet<>(right); + rightExprIds = new HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + } + + boolean isCoveredByLeftSlots(List<SlotReference> slots) { 
+ boolean covered = left.containsAll(slots); + if (covered) { + return true; + } + List<ExprId> slotsExprIds = slots.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + return leftExprIds.containsAll(slotsExprIds); + } + + boolean isCoveredByRightSlots(List<SlotReference> slots) { + boolean covered = right.containsAll(slots); + if (covered) { + return true; + } + List<ExprId> slotsExprIds = slots.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + return rightExprIds.containsAll(slotsExprIds); Review Comment: ditto ########## fe/fe-core/src/test/java/org/apache/doris/nereids/trees/plans/PlanEqualsTest.java: ########## @@ -98,14 +99,18 @@ public void testLogicalJoin(@Mocked Plan left, @Mocked Plan right) { new SlotReference("a", new BigIntType(), true, Lists.newArrayList()), new SlotReference("b", new BigIntType(), true, Lists.newArrayList()) ); - LogicalJoin innerJoin1 = new LogicalJoin(JoinType.INNER_JOIN, Optional.of(condition1), left, right); + LogicalJoin innerJoin1 = new LogicalJoin(JoinType.INNER_JOIN, Lists.newArrayList(condition1), + Optional.empty(), left, right); + boolean bool = innerJoin1.equals(innerJoin); + System.out.println(bool); Review Comment: remove useless print ########## fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java: ########## @@ -73,9 +73,11 @@ public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom( */ Assertions.assertEquals(3, scans.size()); LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, - Optional.of(bottomJoinOnCondition), scans.get(0), scans.get(1)); + Lists.newArrayList(bottomJoinOnCondition), + Optional.empty(), scans.get(0), scans.get(1)); LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>( - JoinType.INNER_JOIN, Optional.of(topJoinOnCondition), bottomJoin, scans.get(2)); + JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition), + 
Optional.empty(), bottomJoin, scans.get(2)); Review Comment: @jackwener please add Assert for join equal condition in another PR ########## fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java: ########## @@ -87,6 +91,117 @@ private static boolean isEqualTo(List<SlotReference> leftSlots, List<SlotReferen || (leftSlotsSet.containsAll(rightUsed) && rightSlotsSet.containsAll(leftUsed)); } + private static class JoinSlotCoverageChecker { + HashSet<SlotReference> left; + HashSet<ExprId> leftExprIds; + HashSet<SlotReference> right; + HashSet<ExprId> rightExprIds; + + JoinSlotCoverageChecker(List<SlotReference> left, List<SlotReference> right) { + this.left = new HashSet<>(left); + leftExprIds = new HashSet<>(left.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + this.right = new HashSet<>(right); + rightExprIds = new HashSet<>(right.stream().map(SlotReference::getExprId).collect(Collectors.toList())); + } + + boolean isCoveredByLeftSlots(List<SlotReference> slots) { + boolean covered = left.containsAll(slots); + if (covered) { + return true; + } + List<ExprId> slotsExprIds = slots.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + return leftExprIds.containsAll(slotsExprIds); + } + + boolean isCoveredByRightSlots(List<SlotReference> slots) { + boolean covered = right.containsAll(slots); + if (covered) { + return true; + } + List<ExprId> slotsExprIds = slots.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + return rightExprIds.containsAll(slotsExprIds); + } + + /** + * consider following cases: + * 1# A=1 => not for hash table + * 2# t1.a=t2.a + t2.b => hash table + * 3# t1.a=t1.a + t2.b => not for hash table + * 4# t1.a=t2.a or t1.b=t2.b not for hash table + * 5# t1.a > 1 not for hash table + * @param equalTo a conjunct in on clause condition + * @return true if the equal can be used as hash join condition + */ + boolean isHashJoinCondition(EqualTo equalTo) { + List<SlotReference> 
equalLeft = equalTo.left().collect(SlotReference.class::isInstance); + if (equalLeft.isEmpty()) { + return false; + } + + List<SlotReference> equalRight = equalTo.right().collect(SlotReference.class::isInstance); + if (equalRight.isEmpty()) { + return false; + } + + List<ExprId> equalLeftExprIds = equalLeft.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + + List<ExprId> equalRightExprIds = equalRight.stream() + .map(SlotReference::getExprId).collect(Collectors.toList()); + return leftExprIds.containsAll(equalLeftExprIds) && rightExprIds.containsAll(equalRightExprIds) + || left.containsAll(equalLeft) && right.containsAll(equalRight) + || leftExprIds.containsAll(equalRightExprIds) && rightExprIds.containsAll(equalLeftExprIds) + || right.containsAll(equalLeft) && left.containsAll(equalRight); + } + } + + /** + * collect expressions from on clause, which could be used to build hash table + * @param join join node + * @return pair of expressions, for hash table or not. 
+ */ + public static Pair<List<Expression>, List<Expression>> extractExpressionForHashTable(LogicalJoin join, + Optional<Expression> onConditions) { + if (onConditions.isPresent()) { + List<Expression> onExprs = ExpressionUtils.extractConjunction(onConditions.get()); + List<SlotReference> leftSlots = Utils.getOutputSlotReference((Plan) (join.left())); + List<SlotReference> rightSlots = Utils.getOutputSlotReference((Plan) (join.right())); + return extractExpressionForHashTable(leftSlots, rightSlots, onExprs); + } + return new Pair<>(Lists.newArrayList(), Lists.newArrayList()); + } + + /** + * extract expression + * @param leftSlots left child output slots + * @param rightSlots right child output slots + * @param onConditions conditions to be split + * @return pair of hashCondition and otherCondition + */ + public static Pair<List<Expression>, List<Expression>> extractExpressionForHashTable(List<SlotReference> leftSlots, + List<SlotReference> rightSlots, + List<Expression> onConditions) { + + Pair<List<Expression>, List<Expression>> pair = new Pair<>(Lists.newArrayList(), Lists.newArrayList()); Review Comment: using a stream group-by may be a better choice ########## fe/fe-core/src/test/java/org/apache/doris/nereids/datasets/ssb/SSBJoinReorderTest.java: ########## @@ -113,9 +113,11 @@ public void check(Plan plan) { // check join conditions List<String> actualJoinConditions = joins.stream().map(j -> { - Optional<Expression> condition = j.getCondition(); + Review Comment: useless blank line ########## fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/MultiJoin.java: ########## @@ -74,25 +93,30 @@ public Plan reorderJoinsAccordingToConditions() { */ private Plan reorderJoinsAccordingToConditions(List<Plan> joinInputs, List<Expression> conjuncts) { if (joinInputs.size() == 2) { - Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1)); - Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput); - List<Expression> 
joinConditions = split.get(true); - List<Expression> nonJoinConditions = split.get(false); - + //Set<Slot> joinOutput = getJoinOutput(joinInputs.get(0), joinInputs.get(1)); + //Map<Boolean, List<Expression>> split = splitConjuncts(conjuncts, joinOutput); + //List<Expression> joinConditions = split.get(true); + //List<Expression> nonJoinConditions = split.get(false); Review Comment: remove it -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org