This is an automated email from the ASF dual-hosted git repository. starocean999 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new ecadd4b392 [feature](Nereids): add OuterJoinAssoc rule (#16676) ecadd4b392 is described below commit ecadd4b39284da222990352330ac3b25037d4dd6 Author: jakevin <jakevin...@gmail.com> AuthorDate: Wed Feb 15 19:19:28 2023 +0800 [feature](Nereids): add OuterJoinAssoc rule (#16676) * move isIntersecting. * [feature](Nereids): add OuterJoinAssoc rule * fix bug * fix --- .../org/apache/doris/nereids/rules/RuleType.java | 2 + .../rules/exploration/join/InnerJoinLAsscom.java | 4 +- .../exploration/join/InnerJoinLAsscomProject.java | 74 ++++---------- .../rules/exploration/join/JoinReorderHelper.java | 99 ++++++++++++++++++ .../rules/exploration/join/JoinReorderUtils.java | 44 +++++++- .../{OuterJoinLAsscom.java => OuterJoinAssoc.java} | 89 ++++++++--------- .../exploration/join/OuterJoinAssocProject.java | 111 +++++++++++++++++++++ .../rules/exploration/join/OuterJoinLAsscom.java | 4 +- .../exploration/join/OuterJoinLAsscomProject.java | 98 +++++------------- .../join/SemiJoinLogicalJoinTranspose.java | 8 +- .../join/SemiJoinLogicalJoinTransposeProject.java | 8 +- .../nereids/rules/rewrite/logical/ReorderJoin.java | 5 +- .../apache/doris/nereids/util/ExpressionUtils.java | 24 ----- .../org/apache/doris/nereids/util/JoinUtils.java | 36 +------ .../java/org/apache/doris/nereids/util/Utils.java | 24 +++++ .../rules/exploration/join/OuterJoinAssocTest.java | 72 +++++++++++++ .../join/OuterJoinLAsscomProjectTest.java | 28 +++--- 17 files changed, 460 insertions(+), 270 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java index 543291af89..4dfbf928fc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java @@ -214,6 +214,8 @@ public enum RuleType { LOGICAL_INNER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_OUTER_JOIN_LASSCOM(RuleTypeClass.EXPLORATION), LOGICAL_OUTER_JOIN_LASSCOM_PROJECT(RuleTypeClass.EXPLORATION), + LOGICAL_OUTER_JOIN_ASSOC(RuleTypeClass.EXPLORATION), + LOGICAL_OUTER_JOIN_ASSOC_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE(RuleTypeClass.EXPLORATION), LOGICAL_SEMI_JOIN_LOGICAL_JOIN_TRANSPOSE_PROJECT(RuleTypeClass.EXPLORATION), LOGICAL_SEMI_JOIN_SEMI_JOIN_TRANPOSE(RuleTypeClass.EXPLORATION), diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java index 459b5f2bdb..977ef42afa 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscom.java @@ -27,7 +27,7 @@ import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; @@ -111,7 +111,7 @@ public class InnerJoinLAsscom extends OneExplorationRuleFactory { .collect(Collectors.partitioningBy(topHashOn -> { Set<ExprId> usedExprIdSet = topHashOn.getInputSlotExprIds(); Set<ExprId> bOutputExprIdSet = bottomJoin.right().getOutputExprIdSet(); - return ExpressionUtils.isIntersecting(bOutputExprIdSet, usedExprIdSet); + return Utils.isIntersecting(bOutputExprIdSet, usedExprIdSet); })); // * don't include B, just include (A C) // we add it into newBottomJoin HashJoinConjuncts. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java index 8a3fe670b0..aac195cb3a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/InnerJoinLAsscomProject.java @@ -20,28 +20,23 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; -import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; -import org.apache.doris.nereids.util.JoinUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; import java.util.ArrayList; -import java.util.HashMap; import java.util.List; import java.util.Map; import java.util.Set; import java.util.stream.Collectors; -import java.util.stream.Stream; /** * Rule for change inner join LAsscom (associative and commutive). @@ -65,7 +60,6 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) .when(join -> JoinReorderUtils.checkProject(join.left())) .then(topJoin -> { - /* ********** init ********** */ List<NamedExpression> projects = topJoin.left().getProjects(); LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child(); @@ -73,15 +67,17 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { GroupPlan b = bottomJoin.right(); GroupPlan c = topJoin.right(); Set<Slot> cOutputSet = c.getOutputSet(); - Set<ExprId> cOutputExprIdSet = c.getOutputExprIdSet(); /* ********** Split projects ********** */ Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(projects, b); - List<NamedExpression> newLeftProjects = map.get(false); - List<NamedExpression> newRightProjects = map.get(true); - Set<ExprId> bExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(b, newRightProjects); + List<NamedExpression> aProjects = map.get(false); + List<NamedExpression> bProjects = map.get(true); + if (aProjects.isEmpty()) { + return null; + } /* ********** split HashConjuncts ********** */ + Set<ExprId> bExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(b, bProjects); Map<Boolean, List<Expression>> splitHashConjuncts = splitConjunctsWithAlias( topJoin.getHashJoinConjuncts(), bottomJoin.getHashJoinConjuncts(), bExprIdSet); List<Expression> newTopHashConjuncts = splitHashConjuncts.get(true); @@ -98,61 +94,27 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { List<Expression> newTopOtherConjuncts = splitOtherConjuncts.get(true); List<Expression> newBottomOtherConjuncts = splitOtherConjuncts.get(false); - /* ********** replace Conjuncts by projects ********** */ - Map<ExprId, Slot> inputToOutput = new HashMap<>(); - Map<ExprId, Slot> outputToInput = new HashMap<>(); - for (NamedExpression expr : projects) { - if (expr instanceof Alias) { - Alias alias = (Alias) expr; - Slot outputSlot = alias.toSlot(); - Expression child = alias.child(); - Preconditions.checkState(child instanceof Slot); - Slot inputSlot = (Slot) child; - inputToOutput.put(inputSlot.getExprId(), outputSlot); - outputToInput.put(outputSlot.getExprId(), inputSlot); - } - } - // replace hashConjuncts - newBottomHashConjuncts = JoinUtils.replaceJoinConjuncts(newBottomHashConjuncts, outputToInput); - newTopHashConjuncts = JoinUtils.replaceJoinConjuncts(newTopHashConjuncts, inputToOutput); - - // replace otherConjuncts - newBottomOtherConjuncts = JoinUtils.replaceJoinConjuncts(newBottomOtherConjuncts, outputToInput); - newTopOtherConjuncts = JoinUtils.replaceJoinConjuncts(newTopOtherConjuncts, inputToOutput); + JoinReorderHelper helper = new JoinReorderHelper(newTopHashConjuncts, newTopOtherConjuncts, + newBottomHashConjuncts, newBottomOtherConjuncts, projects, aProjects, bProjects); // Add all slots used by OnCondition when projects not empty. - Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat( - newTopHashConjuncts.stream(), - newTopOtherConjuncts.stream()) - .flatMap(onExpr -> { - Set<Slot> usedSlotRefs = onExpr.collect(SlotReference.class::isInstance); - return usedSlotRefs.stream(); - }) - .filter(slot -> !cOutputExprIdSet.contains(slot.getExprId())) - .collect(Collectors.partitioningBy( - slot -> bExprIdSet.contains(slot.getExprId()), Collectors.toSet())); - Set<Slot> aUsedSlots = abOnUsedSlots.get(false); - Set<Slot> bUsedSlots = abOnUsedSlots.get(true); - - JoinUtils.addSlotsUsedByOn(bUsedSlots, newRightProjects); - JoinUtils.addSlotsUsedByOn(aUsedSlots, newLeftProjects); - - if (!newLeftProjects.isEmpty()) { - newLeftProjects.addAll(cOutputSet); - } + helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects), + c.getOutputExprIdSet()); + + aProjects.addAll(cOutputSet); /* ********** new Plan ********** */ LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - newBottomHashConjuncts, newBottomOtherConjuncts, JoinHint.NONE, + helper.newBottomHashConjuncts, helper.newBottomOtherConjuncts, JoinHint.NONE, a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); - Plan left = JoinReorderUtils.projectOrSelf(newLeftProjects, newBottomJoin); - Plan right = JoinReorderUtils.projectOrSelf(newRightProjects, b); + Plan left = JoinReorderUtils.projectOrSelf(aProjects, newBottomJoin); + Plan right = JoinReorderUtils.projectOrSelf(bProjects, b); LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - newTopHashConjuncts, newTopOtherConjuncts, JoinHint.NONE, + helper.newTopHashConjuncts, helper.newTopOtherConjuncts, JoinHint.NONE, left, right, topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); @@ -172,7 +134,7 @@ public class InnerJoinLAsscomProject extends OneExplorationRuleFactory { Map<Boolean, List<Expression>> splitOn = topConjuncts.stream() .collect(Collectors.partitioningBy(topHashOn -> { Set<ExprId> usedExprIds = topHashOn.getInputSlotExprIds(); - return ExpressionUtils.isIntersecting(bExprIdSet, usedExprIds); + return Utils.isIntersecting(bExprIdSet, usedExprIds); })); // * don't include B, just include (A C) // we add it into newBottomJoin HashConjuncts. diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderHelper.java new file mode 100644 index 0000000000..600c88b388 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderHelper.java @@ -0,0 +1,99 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; + +import com.google.common.base.Preconditions; + +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * Helper class for three left deep tree ( original plan tree is (a b) c )create new join. + */ +public class JoinReorderHelper { + public List<Expression> newTopHashConjuncts; + public List<Expression> newTopOtherConjuncts; + public List<Expression> newBottomHashConjuncts; + public List<Expression> newBottomOtherConjuncts; + + public List<NamedExpression> oldProjects; + public List<NamedExpression> newLeftProjects; + public List<NamedExpression> newRightProjects; + + /** + * Constructor. + */ + public JoinReorderHelper(List<Expression> newTopHashConjuncts, List<Expression> newTopOtherConjuncts, + List<Expression> newBottomHashConjuncts, List<Expression> newBottomOtherConjuncts, + List<NamedExpression> oldProjects, List<NamedExpression> newLeftProjects, + List<NamedExpression> newRightProjects) { + this.newTopHashConjuncts = newTopHashConjuncts; + this.newTopOtherConjuncts = newTopOtherConjuncts; + this.newBottomHashConjuncts = newBottomHashConjuncts; + this.newBottomOtherConjuncts = newBottomOtherConjuncts; + this.oldProjects = oldProjects; + this.newLeftProjects = newLeftProjects; + this.newRightProjects = newRightProjects; + replaceConjuncts(oldProjects); + } + + private void replaceConjuncts(List<NamedExpression> projects) { + Map<ExprId, Slot> inputToOutput = new HashMap<>(); + Map<ExprId, Slot> outputToInput = new HashMap<>(); + for (NamedExpression expr : projects) { + Slot outputSlot = expr.toSlot(); + Set<Slot> usedSlots = expr.getInputSlots(); + Preconditions.checkState(usedSlots.size() == 1); + Slot inputSlot = (Slot) usedSlots.toArray()[0]; + inputToOutput.put(inputSlot.getExprId(), outputSlot); + outputToInput.put(outputSlot.getExprId(), inputSlot); + } + + newBottomHashConjuncts = JoinReorderUtils.replaceJoinConjuncts(newBottomHashConjuncts, outputToInput); + newTopHashConjuncts = JoinReorderUtils.replaceJoinConjuncts(newTopHashConjuncts, inputToOutput); + newBottomOtherConjuncts = JoinReorderUtils.replaceJoinConjuncts(newBottomOtherConjuncts, outputToInput); + newTopOtherConjuncts = JoinReorderUtils.replaceJoinConjuncts(newTopOtherConjuncts, inputToOutput); + } + + /** + * Add all slots used by OnCondition when projects not empty. + * @param cOutputExprIdSet we want to get abOnUsedSlots, we need filter cOutputExprIdSet. + */ + public void addSlotsUsedByOn(Set<ExprId> splitIds, Set<ExprId> cOutputExprIdSet) { + Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat( + newTopHashConjuncts.stream(), + newTopOtherConjuncts.stream()) + .flatMap(onExpr -> onExpr.getInputSlots().stream()) + .filter(slot -> !cOutputExprIdSet.contains(slot.getExprId())) + .collect(Collectors.partitioningBy(slot -> splitIds.contains(slot.getExprId()), Collectors.toSet())); + Set<Slot> aUsedSlots = abOnUsedSlots.get(true); + Set<Slot> bUsedSlots = abOnUsedSlots.get(false); + + JoinReorderUtils.addSlotsUsedByOn(aUsedSlots, newLeftProjects); + JoinReorderUtils.addSlotsUsedByOn(bUsedSlots, newRightProjects); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java index b115d03d5a..645e74c7f3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderUtils.java @@ -19,6 +19,7 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.GroupPlan; @@ -26,6 +27,8 @@ import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import com.google.common.collect.ImmutableList; + import java.util.List; import java.util.Map; import java.util.Set; @@ -56,17 +59,14 @@ class JoinReorderUtils { }); } - static Map<Boolean, List<NamedExpression>> splitProjection( - List<NamedExpression> projects, Plan splitChild) { + static Map<Boolean, List<NamedExpression>> splitProjection(List<NamedExpression> projects, Plan splitChild) { Set<ExprId> splitExprIds = splitChild.getOutputExprIdSet(); - Map<Boolean, List<NamedExpression>> projectExprsMap = projects.stream() + return projects.stream() .collect(Collectors.partitioningBy(projectExpr -> { Set<ExprId> usedExprIds = projectExpr.getInputSlotExprIds(); return splitExprIds.containsAll(usedExprIds); })); - - return projectExprsMap; } public static Set<ExprId> combineProjectAndChildExprId(Plan b, List<NamedExpression> bProject) { @@ -85,4 +85,38 @@ class JoinReorderUtils { } return new LogicalProject<>(projectExprs, plan); } + + /** + * replace JoinConjuncts by using slots map. + */ + public static List<Expression> replaceJoinConjuncts(List<Expression> joinConjuncts, + Map<ExprId, Slot> replaceMaps) { + return joinConjuncts.stream() + .map(expr -> + expr.rewriteUp(e -> { + if (e instanceof Slot && replaceMaps.containsKey(((Slot) e).getExprId())) { + return replaceMaps.get(((Slot) e).getExprId()); + } else { + return e; + } + }) + ).collect(ImmutableList.toImmutableList()); + } + + /** + * When project not empty, we add all slots used by hashOnCondition into projects. + */ + public static void addSlotsUsedByOn(Set<Slot> usedSlots, List<NamedExpression> projects) { + if (projects.isEmpty()) { + return; + } + Set<ExprId> projectExprIdSet = projects.stream() + .map(NamedExpression::getExprId) + .collect(Collectors.toSet()); + usedSlots.forEach(slot -> { + if (!projectExprIdSet.contains(slot.getExprId())) { + projects.add(slot); + } + }); + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java similarity index 50% copy from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java copy to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java index b1124965e8..05229d2bcb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssoc.java @@ -21,103 +21,92 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableSet; import java.util.Set; -import java.util.stream.Collectors; import java.util.stream.Stream; /** - * Rule for change inner join LAsscom (associative and commutive). - * TODO Future: - * LeftOuter-LeftOuter can allow topHashConjunct (A B) and (AC) + * OuterJoinAssoc. */ -public class OuterJoinLAsscom extends OneExplorationRuleFactory { - public static final OuterJoinLAsscom INSTANCE = new OuterJoinLAsscom(); +public class OuterJoinAssoc extends OneExplorationRuleFactory { + /* + * topJoin newTopJoin + * / \ / \ + * bottomJoin C -> A newBottomJoin + * / \ / \ + * A B B C + */ + public static final OuterJoinAssoc INSTANCE = new OuterJoinAssoc(); - // Pair<bottomJoin, topJoin> - // newBottomJoin Type = topJoin Type, newTopJoin Type = bottomJoin Type public static Set<Pair<JoinType, JoinType>> VALID_TYPE_PAIR_SET = ImmutableSet.of( - Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.INNER_JOIN), Pair.of(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN), Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_OUTER_JOIN)); - /* - * topJoin newTopJoin - * / \ / \ - * bottomJoin C --> newBottomJoin B - * / \ / \ - * A B A C - */ @Override public Rule build() { return logicalJoin(logicalJoin(), group()) .when(join -> VALID_TYPE_PAIR_SET.contains(Pair.of(join.left().getJoinType(), join.getJoinType()))) - .when(topJoin -> checkReorder(topJoin, topJoin.left())) - .when(topJoin -> checkCondition(topJoin, topJoin.left().right().getOutputExprIdSet())) - .whenNot(join -> join.hasJoinHint() || join.left().hasJoinHint()) + .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left())) + .when(topJoin -> checkCondition(topJoin, topJoin.left().left().getOutputSet())) .then(topJoin -> { LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left(); GroupPlan a = bottomJoin.left(); GroupPlan b = bottomJoin.right(); GroupPlan c = topJoin.right(); + /* TODO: + * p23 need to reject nulls on A(e2) (Eqv. 1) + * see paper `On the Correct and Complete Enumeration of the Core Search Space`. + * But because we have added eliminate_outer_rule, we don't need to consider this. + */ + LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), topJoin.getHashJoinConjuncts(), topJoin.getOtherJoinConjuncts(), JoinHint.NONE, - a, c, bottomJoin.getJoinReorderContext()); - newBottomJoin.getJoinReorderContext().setHasLAsscom(false); - newBottomJoin.getJoinReorderContext().setHasCommute(false); - - LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan> newTopJoin = new LogicalJoin<>( - bottomJoin.getJoinType(), + b, c); + LogicalJoin<GroupPlan, LogicalJoin<GroupPlan, GroupPlan>> newTopJoin + = new LogicalJoin<>(bottomJoin.getJoinType(), bottomJoin.getHashJoinConjuncts(), bottomJoin.getOtherJoinConjuncts(), JoinHint.NONE, - newBottomJoin, b, topJoin.getJoinReorderContext()); - newTopJoin.getJoinReorderContext().setHasLAsscom(true); - + a, newBottomJoin, bottomJoin.getJoinReorderContext()); + setReorderContext(newTopJoin, newBottomJoin); return newTopJoin; - }).toRule(RuleType.LOGICAL_OUTER_JOIN_LASSCOM); + }).toRule(RuleType.LOGICAL_OUTER_JOIN_ASSOC); } /** - * topHashConjunct possibility: (A B) (A C) (B C) (A B C). - * (A B) is forbidden, because it should be in bottom join. - * (B C) (A B C) check failed, because it contains B. - * So, just allow: top (A C), bottom (A B), we can exchange HashConjunct directly. + * just allow: top (B C), bottom (A B), we can exchange HashConjunct directly. * <p> * Same with OtherJoinConjunct. */ - private boolean checkCondition(LogicalJoin<? extends Plan, GroupPlan> topJoin, Set<ExprId> bOutputExprIdSet) { + public static boolean checkCondition(LogicalJoin<? extends Plan, GroupPlan> topJoin, Set<Slot> aOutputSet) { return Stream.concat( topJoin.getHashJoinConjuncts().stream(), topJoin.getOtherJoinConjuncts().stream()) .allMatch(expr -> { - Set<ExprId> usedExprIdSet = expr.<Set<SlotReference>>collect(SlotReference.class::isInstance) - .stream() - .map(SlotReference::getExprId) - .collect(Collectors.toSet()); - return !ExpressionUtils.isIntersecting(usedExprIdSet, bOutputExprIdSet); + Set<Slot> usedSlot = expr.collect(SlotReference.class::isInstance); + return !Utils.isIntersecting(usedSlot, aOutputSet); }); } /** - * check join reorder masks. + * Set the reorder context for the new join. */ - public static boolean checkReorder(LogicalJoin<? extends Plan, GroupPlan> topJoin, - LogicalJoin<GroupPlan, GroupPlan> bottomJoin) { - // hasCommute will cause to lack of OuterJoinAssocRule:Left - return !topJoin.getJoinReorderContext().hasLAsscom() - && !topJoin.getJoinReorderContext().hasLeftAssociate() - && !topJoin.getJoinReorderContext().hasRightAssociate() - && !topJoin.getJoinReorderContext().hasExchange() - && !bottomJoin.getJoinReorderContext().hasCommute(); + public static void setReorderContext(LogicalJoin topJoin, LogicalJoin bottomJoin) { + bottomJoin.getJoinReorderContext().setHasCommute(false); + bottomJoin.getJoinReorderContext().setHasRightAssociate(false); + bottomJoin.getJoinReorderContext().setHasLeftAssociate(false); + bottomJoin.getJoinReorderContext().setHasExchange(false); + + topJoin.getJoinReorderContext().setHasRightAssociate(true); + topJoin.getJoinReorderContext().setHasCommute(false); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java new file mode 100644 index 0000000000..7faafbbac5 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocProject.java @@ -0,0 +1,111 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.ExprId; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.JoinHint; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.Utils; + +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +/** + * OuterJoinAssocProject. + */ +public class OuterJoinAssocProject extends OneExplorationRuleFactory { + /* + * topJoin newTopJoin + * / \ / \ + * bottomJoin C -> A newBottomJoin + * / \ / \ + * A B B C + */ + public static final OuterJoinAssocProject INSTANCE = new OuterJoinAssocProject(); + + @Override + public Rule build() { + return logicalJoin(logicalProject(logicalJoin()), group()) + .when(join -> OuterJoinAssoc.VALID_TYPE_PAIR_SET.contains( + Pair.of(join.left().child().getJoinType(), join.getJoinType()))) + .when(topJoin -> OuterJoinLAsscom.checkReorder(topJoin, topJoin.left().child())) + .whenNot(join -> join.hasJoinHint() || join.left().child().hasJoinHint()) + .when(join -> OuterJoinAssoc.checkCondition(join, join.left().child().left().getOutputSet())) + .then(topJoin -> { + /* ********** init ********** */ + List<NamedExpression> projects = topJoin.left().getProjects(); + LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child(); + GroupPlan a = bottomJoin.left(); + GroupPlan b = bottomJoin.right(); + GroupPlan c = topJoin.right(); + + /* ********** Split projects ********** */ + Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(projects, a); + List<NamedExpression> aProjects = map.get(true); + List<NamedExpression> bProjects = map.get(false); + if (bProjects.isEmpty()) { + return null; + } + Set<ExprId> aProjectsExprIds = aProjects.stream().map(NamedExpression::getExprId) + .collect(Collectors.toSet()); + + // topJoin condition can't contain aProject. just can (B C) + if (Stream.concat(topJoin.getHashJoinConjuncts().stream(), topJoin.getOtherJoinConjuncts().stream()) + .anyMatch(expr -> Utils.isIntersecting(expr.getInputSlotExprIds(), aProjectsExprIds))) { + return null; + } + + // topJoin condition -> newBottomJoin condition, bottomJoin condition -> newTopJoin condition + JoinReorderHelper helper = new JoinReorderHelper(bottomJoin.getHashJoinConjuncts(), + bottomJoin.getOtherJoinConjuncts(), topJoin.getHashJoinConjuncts(), + topJoin.getOtherJoinConjuncts(), projects, aProjects, bProjects); + + // Add all slots used by OnCondition when projects not empty. + helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects), + Collections.EMPTY_SET); + + bProjects.addAll(OuterJoinLAsscomProject.forceToNullable(c.getOutputSet())); + /* ********** new Plan ********** */ + LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), + helper.newBottomHashConjuncts, helper.newBottomOtherConjuncts, JoinHint.NONE, + b, c, bottomJoin.getJoinReorderContext()); + + Plan left = JoinReorderUtils.projectOrSelf(aProjects, a); + Plan right = JoinReorderUtils.projectOrSelf(bProjects, newBottomJoin); + + LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), + helper.newTopHashConjuncts, helper.newTopOtherConjuncts, JoinHint.NONE, + left, right, topJoin.getJoinReorderContext()); + OuterJoinAssoc.setReorderContext(newTopJoin, newBottomJoin); + + return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); + }).toRule(RuleType.LOGICAL_OUTER_JOIN_ASSOC_PROJECT); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java index b1124965e8..dda781f327 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscom.java @@ -28,7 +28,7 @@ import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.collect.ImmutableSet; @@ -104,7 +104,7 @@ public class OuterJoinLAsscom extends OneExplorationRuleFactory { .stream() .map(SlotReference::getExprId) .collect(Collectors.toSet()); - return !ExpressionUtils.isIntersecting(usedExprIdSet, bOutputExprIdSet); + return !Utils.isIntersecting(usedExprIdSet, bOutputExprIdSet); }); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java index cfb94e9b46..6bfa65b353 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProject.java @@ -21,9 +21,7 @@ import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.trees.expressions.Alias; import org.apache.doris.nereids.trees.expressions.ExprId; -import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; @@ -31,13 +29,10 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.JoinUtils; - -import com.google.common.base.Preconditions; -import com.google.common.collect.ImmutableSet; +import org.apache.doris.nereids.util.Utils; import java.util.ArrayList; -import java.util.HashMap; +import java.util.Collections; import java.util.List; import java.util.Map; import java.util.Set; @@ -74,95 +69,56 @@ public class OuterJoinLAsscomProject extends OneExplorationRuleFactory { GroupPlan a = bottomJoin.left(); GroupPlan b = bottomJoin.right(); GroupPlan c = topJoin.right(); - Set<Slot> cOutputSet = c.getOutputSet(); - Set<ExprId> cOutputExprIdSet = c.getOutputExprIdSet(); /* ********** Split projects ********** */ Map<Boolean, List<NamedExpression>> map = JoinReorderUtils.splitProjection(projects, a); - List<NamedExpression> newLeftProjects = map.get(true); - List<NamedExpression> newRightProjects = map.get(false); - Set<ExprId> aExprIdSet = JoinReorderUtils.combineProjectAndChildExprId(a, newLeftProjects); - - /* ********** Conjuncts ********** */ - List<Expression> newTopHashConjuncts = bottomJoin.getHashJoinConjuncts(); - List<Expression> newBottomHashConjuncts = topJoin.getHashJoinConjuncts(); - List<Expression> newTopOtherConjuncts = bottomJoin.getOtherJoinConjuncts(); - List<Expression> newBottomOtherConjuncts = topJoin.getOtherJoinConjuncts(); - - /* ********** replace Conjuncts by projects ********** */ - Map<ExprId, Slot> inputToOutput = new HashMap<>(); - Map<ExprId, Slot> outputToInput = new HashMap<>(); - for (NamedExpression expr : projects) { - if (expr instanceof Alias) { - Alias alias = (Alias) expr; - Slot outputSlot = alias.toSlot(); - Expression child = alias.child(); - // checkProject already confirmed. - Preconditions.checkState(child instanceof Slot); - Slot inputSlot = (Slot) child; - inputToOutput.put(inputSlot.getExprId(), outputSlot); - outputToInput.put(outputSlot.getExprId(), inputSlot); - } + List<NamedExpression> aProjects = map.get(true); + if (aProjects.isEmpty()) { + return null; } - // replace hashConjuncts - newBottomHashConjuncts = JoinUtils.replaceJoinConjuncts(newBottomHashConjuncts, outputToInput); - newTopHashConjuncts = JoinUtils.replaceJoinConjuncts(newTopHashConjuncts, inputToOutput); - // replace otherConjuncts - newBottomOtherConjuncts = JoinUtils.replaceJoinConjuncts(newBottomOtherConjuncts, outputToInput); - newTopOtherConjuncts = JoinUtils.replaceJoinConjuncts(newTopOtherConjuncts, inputToOutput); + List<NamedExpression> bProjects = map.get(false); + Set<ExprId> bProjectsExprIds = bProjects.stream().map(NamedExpression::getExprId) + .collect(Collectors.toSet()); - /* ********** check ********** */ - Set<Slot> acOutputSet = ImmutableSet.<Slot>builder().addAll(a.getOutputSet()) - .addAll(c.getOutputSet()).build(); - if (!Stream.concat(newBottomHashConjuncts.stream(), newBottomOtherConjuncts.stream()) - .allMatch(expr -> { - Set<Slot> inputSlots = expr.getInputSlots(); - return acOutputSet.containsAll(inputSlots); - })) { + // topJoin condition can't contain bProject output. just can (A C) + if (Stream.concat(topJoin.getHashJoinConjuncts().stream(), topJoin.getOtherJoinConjuncts().stream()) + .anyMatch(expr -> Utils.isIntersecting(expr.getInputSlotExprIds(), bProjectsExprIds))) { return null; } - // Add all slots used by OnCondition when projects not empty. - Map<Boolean, Set<Slot>> abOnUsedSlots = Stream.concat( - newTopHashConjuncts.stream(), - newTopOtherConjuncts.stream()) - .flatMap(onExpr -> { - Set<Slot> usedSlotRefs = onExpr.collect(SlotReference.class::isInstance); - return usedSlotRefs.stream(); - }) - .filter(slot -> !cOutputExprIdSet.contains(slot.getExprId())) - .collect(Collectors.partitioningBy( - slot -> aExprIdSet.contains(slot.getExprId()), Collectors.toSet())); - Set<Slot> aUsedSlots = abOnUsedSlots.get(true); - Set<Slot> bUsedSlots = abOnUsedSlots.get(false); + // topJoin condition -> newBottomJoin condition, bottomJoin condition -> newTopJoin condition + JoinReorderHelper helper = new JoinReorderHelper(bottomJoin.getHashJoinConjuncts(), + bottomJoin.getOtherJoinConjuncts(), topJoin.getHashJoinConjuncts(), + topJoin.getOtherJoinConjuncts(), projects, aProjects, bProjects); - JoinUtils.addSlotsUsedByOn(bUsedSlots, newRightProjects); - JoinUtils.addSlotsUsedByOn(aUsedSlots, newLeftProjects); + // Add all slots used by OnCondition when projects not empty. + helper.addSlotsUsedByOn(JoinReorderUtils.combineProjectAndChildExprId(a, helper.newLeftProjects), + Collections.EMPTY_SET); - if (!newLeftProjects.isEmpty()) { - Set<Slot> nullableCOutputSet = forceToNullable(cOutputSet); - newLeftProjects.addAll(nullableCOutputSet); - } + aProjects.addAll(forceToNullable(c.getOutputSet())); /* ********** new Plan ********** */ LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), - newBottomHashConjuncts, newBottomOtherConjuncts, JoinHint.NONE, + helper.newBottomHashConjuncts, helper.newBottomOtherConjuncts, JoinHint.NONE, a, c, bottomJoin.getJoinReorderContext()); newBottomJoin.getJoinReorderContext().setHasLAsscom(false); newBottomJoin.getJoinReorderContext().setHasCommute(false); - Plan left = JoinReorderUtils.projectOrSelf(newLeftProjects, newBottomJoin); - Plan right = JoinReorderUtils.projectOrSelf(newRightProjects, b); + Plan left = JoinReorderUtils.projectOrSelf(aProjects, newBottomJoin); + Plan right = JoinReorderUtils.projectOrSelf(bProjects, b); LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), - newTopHashConjuncts, newTopOtherConjuncts, JoinHint.NONE, + helper.newTopHashConjuncts, helper.newTopOtherConjuncts, JoinHint.NONE, left, right, topJoin.getJoinReorderContext()); newTopJoin.getJoinReorderContext().setHasLAsscom(true); return JoinReorderUtils.projectOrSelf(new ArrayList<>(topJoin.getOutput()), newTopJoin); }).toRule(RuleType.LOGICAL_OUTER_JOIN_LASSCOM_PROJECT); } - private Set<Slot> forceToNullable(Set<Slot> slotSet) { + /** + * Force all slots in set to nullable. + */ + public static Set<Slot> forceToNullable(Set<Slot> slotSet) { return slotSet.stream().map(s -> (Slot) s.rewriteUp(e -> { if (e instanceof SlotReference) { return ((SlotReference) e).withNullable(true); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java index 78a354cb36..79b79510b3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTranspose.java @@ -26,7 +26,7 @@ import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.JoinHint; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; @@ -73,7 +73,7 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory { boolean lasscom = false; for (Expression hashJoinConjunct : hashJoinConjuncts) { Set<ExprId> usedSlotExprIds = hashJoinConjunct.getInputSlotExprIds(); - lasscom = ExpressionUtils.isIntersecting(usedSlotExprIds, aOutputExprIdSet) || lasscom; + lasscom = Utils.isIntersecting(usedSlotExprIds, aOutputExprIdSet) || lasscom; } if (lasscom) { @@ -133,8 +133,8 @@ public class SemiJoinLogicalJoinTranspose extends OneExplorationRuleFactory { boolean hashContainsB = false; for (Expression hashJoinConjunct : hashJoinConjuncts) { Set<ExprId> usedSlotExprIds = hashJoinConjunct.getInputSlotExprIds(); - hashContainsA = ExpressionUtils.isIntersecting(usedSlotExprIds, aOutputExprIdSet) || hashContainsA; - hashContainsB = ExpressionUtils.isIntersecting(usedSlotExprIds, bOutputExprIdSet) || hashContainsB; + hashContainsA = Utils.isIntersecting(usedSlotExprIds, aOutputExprIdSet) || hashContainsA; + hashContainsB = Utils.isIntersecting(usedSlotExprIds, bOutputExprIdSet) || hashContainsB; } if (leftDeep && hashContainsB) { return false; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java index 6d2f705ec7..76fda42fe5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/SemiJoinLogicalJoinTransposeProject.java @@ -29,7 +29,7 @@ import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; @@ -79,7 +79,7 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto boolean lasscom = false; for (Expression hashJoinConjunct : hashJoinConjuncts) { Set<ExprId> usedSlotExprIdSet = hashJoinConjunct.getInputSlotExprIds(); - lasscom = ExpressionUtils.isIntersecting(usedSlotExprIdSet, aOutputExprIdSet) || lasscom; + lasscom = Utils.isIntersecting(usedSlotExprIdSet, aOutputExprIdSet) || lasscom; } if (lasscom) { @@ -148,8 +148,8 @@ public class SemiJoinLogicalJoinTransposeProject extends OneExplorationRuleFacto boolean hashContainsB = false; for (Expression hashJoinConjunct : hashJoinConjuncts) { Set<Slot> usedSlot = hashJoinConjunct.collect(Slot.class::isInstance); - hashContainsA = ExpressionUtils.isIntersecting(usedSlot, aOutput) || hashContainsA; - hashContainsB = ExpressionUtils.isIntersecting(usedSlot, bOutput) || hashContainsB; + hashContainsA = Utils.isIntersecting(usedSlot, aOutput) || hashContainsA; + hashContainsB = Utils.isIntersecting(usedSlot, bOutput) || hashContainsB; } if (leftDeep && hashContainsB) { return false; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java index 610acb6835..2abd3876bf 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java @@ -33,6 +33,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.JoinUtils; import org.apache.doris.nereids.util.PlanUtils; +import org.apache.doris.nereids.util.Utils; import com.google.common.base.Preconditions; import com.google.common.collect.ImmutableList; @@ -230,7 +231,7 @@ public class ReorderJoin extends OneRewriteRuleFactory { Set<ExprId> rightOutputExprIdSet = right.getOutputExprIdSet(); Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream() .collect(Collectors.partitioningBy(expr -> - ExpressionUtils.isIntersecting(rightOutputExprIdSet, expr.getInputSlotExprIds()) + Utils.isIntersecting(rightOutputExprIdSet, expr.getInputSlotExprIds()) )); remainingFilter = split.get(true); List<Expression> pushedFilter = split.get(false); @@ -244,7 +245,7 @@ public class ReorderJoin extends OneRewriteRuleFactory { Set<ExprId> leftOutputExprIdSet = left.getOutputExprIdSet(); Map<Boolean, List<Expression>> split = multiJoin.getJoinFilter().stream() .collect(Collectors.partitioningBy(expr -> - ExpressionUtils.isIntersecting(leftOutputExprIdSet, expr.getInputSlotExprIds()) + Utils.isIntersecting(leftOutputExprIdSet, expr.getInputSlotExprIds()) )); remainingFilter = split.get(true); List<Expression> pushedFilter = split.get(false); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index f067c9dba8..6128aa39dc 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -199,30 +199,6 @@ public class ExpressionUtils { .orElse(BooleanLiteral.of(type == And.class)); } - /** - * Check whether lhs and rhs are intersecting. - */ - public static <T> boolean isIntersecting(Set<T> lhs, List<T> rhs) { - for (T rh : rhs) { - if (lhs.contains(rh)) { - return true; - } - } - return false; - } - - /** - * Check whether lhs and rhs are intersecting. - */ - public static <T> boolean isIntersecting(Set<T> lhs, Set<T> rhs) { - for (T rh : rhs) { - if (lhs.contains(rh)) { - return true; - } - } - return false; - } - /** * Choose the minimum slot from input parameter. */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java index 0532f0b26a..7600656c64 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/JoinUtils.java @@ -27,7 +27,6 @@ import org.apache.doris.nereids.properties.DistributionSpecReplicated; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.ExprId; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; @@ -294,40 +293,6 @@ public class JoinUtils { return false; } - /** - * replace JoinConjuncts by using slots map. - */ - public static List<Expression> replaceJoinConjuncts(List<Expression> joinConjuncts, - Map<ExprId, Slot> replaceMaps) { - return joinConjuncts.stream() - .map(expr -> - expr.rewriteUp(e -> { - if (e instanceof Slot && replaceMaps.containsKey(((Slot) e).getExprId())) { - return replaceMaps.get(((Slot) e).getExprId()); - } else { - return e; - } - }) - ).collect(ImmutableList.toImmutableList()); - } - - /** - * When project not empty, we add all slots used by hashOnCondition into projects. - */ - public static void addSlotsUsedByOn(Set<Slot> usedSlots, List<NamedExpression> projects) { - if (projects.isEmpty()) { - return; - } - Set<ExprId> projectExprIdSet = projects.stream() - .map(NamedExpression::getExprId) - .collect(Collectors.toSet()); - usedSlots.forEach(slot -> { - if (!projectExprIdSet.contains(slot.getExprId())) { - projects.add(slot); - } - }); - } - public static Set<ExprId> getJoinOutputExprIdSet(Plan left, Plan right) { Set<ExprId> joinOutputExprIdSet = new HashSet<>(); joinOutputExprIdSet.addAll(left.getOutputExprIdSet()); @@ -342,6 +307,7 @@ public class JoinUtils { /** * calculate the output slot of a join operator according join type and its children + * * @param joinType the type of join operator * @param left left child * @param right right child diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java index 6609a2938a..512a2147e2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/Utils.java @@ -79,6 +79,30 @@ public class Utils { return (R) ans[0]; } + /** + * Check whether lhs and rhs are intersecting. + */ + public static <T> boolean isIntersecting(Set<T> lhs, List<T> rhs) { + for (T rh : rhs) { + if (lhs.contains(rh)) { + return true; + } + } + return false; + } + + /** + * Check whether lhs and rhs are intersecting. + */ + public static <T> boolean isIntersecting(Set<T> lhs, Set<T> rhs) { + for (T rh : rhs) { + if (lhs.contains(rh)) { + return true; + } + } + return false; + } + /** * Wrapper to a function without return value. */ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java new file mode 100644 index 0000000000..49e7a13795 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinAssocTest.java @@ -0,0 +1,72 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalPlan; +import org.apache.doris.nereids.util.LogicalPlanBuilder; +import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PatternMatchSupported; +import org.apache.doris.nereids.util.PlanChecker; +import org.apache.doris.nereids.util.PlanConstructor; + +import org.junit.jupiter.api.Test; + +import java.util.Objects; + +class OuterJoinAssocTest implements PatternMatchSupported { + LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); + + @Test + public void testInnerLeft() { + LogicalPlan join = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.INNER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .join(scan3, JoinType.LEFT_OUTER_JOIN, Pair.of(2, 0)) // t2.id = t3.id + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), join) + .applyExploration(OuterJoinAssoc.INSTANCE.build()) + .matchesExploration( + logicalJoin( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), + logicalJoin() + ).when(top -> Objects.equals(top.getHashJoinConjuncts().toString(), "[(id#0 = id#2)]")) + ); + } + + @Test + public void testLeftLeft() { + LogicalPlan join = new LogicalPlanBuilder(scan1) + .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id = t2.id + .join(scan3, JoinType.LEFT_OUTER_JOIN, Pair.of(2, 0)) // t2.id = t3.id + .build(); + + PlanChecker.from(MemoTestUtils.createConnectContext(), join) + .applyExploration(OuterJoinAssoc.INSTANCE.build()) + .matchesExploration( + logicalJoin( + logicalOlapScan().when(scan -> scan.getTable().getName().equals("t1")), + logicalJoin() + ).when(top -> Objects.equals(top.getHashJoinConjuncts().toString(), "[(id#0 = id#2)]")) + ); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java index b2492e86e9..79d705f231 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/OuterJoinLAsscomProjectTest.java @@ -106,9 +106,7 @@ class OuterJoinLAsscomProjectTest implements PatternMatchSupported { PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyExploration(OuterJoinLAsscomProject.INSTANCE.build()) .printlnOrigin() - .checkMemo(memo -> { - Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()); - }); + .checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size())); } @Test @@ -116,16 +114,14 @@ class OuterJoinLAsscomProjectTest implements PatternMatchSupported { LogicalPlan plan = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.LEFT_OUTER_JOIN, Pair.of(0, 0)) // t1.id=t2.id .alias(ImmutableList.of(0, 2), ImmutableList.of("t1.id", "t2.id")) - // t1.id=t3.id t2.id = t3.id + // t1.id=t3.id t2.id=t3.id .join(scan3, JoinType.INNER_JOIN, ImmutableList.of(Pair.of(0, 0), Pair.of(1, 0))) .build(); // transform failed. PlanChecker.from(MemoTestUtils.createConnectContext(), plan) .applyExploration(OuterJoinLAsscomProject.INSTANCE.build()) - .checkMemo(memo -> { - Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size()); - }); + .checkMemo(memo -> Assertions.assertEquals(1, memo.getRoot().getLogicalExpressions().size())); } @Test @@ -134,16 +130,18 @@ class OuterJoinLAsscomProjectTest implements PatternMatchSupported { new EqualTo(scan1.getOutput().get(0), scan2.getOutput().get(0))); List<Expression> bottomOtherJoinConjunct = ImmutableList.of( new GreaterThan(scan1.getOutput().get(1), scan2.getOutput().get(1))); - List<Expression> topHashJoinConjunct = ImmutableList.of( - new EqualTo(scan1.getOutput().get(0), scan3.getOutput().get(0)), - new EqualTo(scan2.getOutput().get(0), scan3.getOutput().get(0))); - List<Expression> topOtherJoinConjunct = ImmutableList.of( - new GreaterThan(scan1.getOutput().get(1), scan3.getOutput().get(1)), - new GreaterThan(scan2.getOutput().get(1), scan3.getOutput().get(1))); - - LogicalPlan plan = new LogicalPlanBuilder(scan1) + LogicalPlan project = new LogicalPlanBuilder(scan1) .join(scan2, JoinType.LEFT_OUTER_JOIN, bottomHashJoinConjunct, bottomOtherJoinConjunct) .alias(ImmutableList.of(0, 1, 2, 3), ImmutableList.of("t1.id", "t1.name", "t2.id", "t2.name")) + .build(); + + List<Expression> topHashJoinConjunct = ImmutableList.of( + new EqualTo(project.getOutput().get(0), scan3.getOutput().get(0)), + new EqualTo(project.getOutput().get(2), scan3.getOutput().get(0))); + List<Expression> topOtherJoinConjunct = ImmutableList.of( + new GreaterThan(project.getOutput().get(1), scan3.getOutput().get(1)), + new GreaterThan(project.getOutput().get(3), scan3.getOutput().get(1))); + LogicalPlan plan = new LogicalPlanBuilder(project) .join(scan3, JoinType.LEFT_OUTER_JOIN, topHashJoinConjunct, topOtherJoinConjunct) .build(); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org