This is an automated email from the ASF dual-hosted git repository. jakevin pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 7c7ac86fe8 [feature](Nereids): Left deep tree join order. (#12439) 7c7ac86fe8 is described below commit 7c7ac86fe8e3360e39b96d2c517aa25b61b18426 Author: jakevin <jakevin...@gmail.com> AuthorDate: Thu Sep 8 15:09:22 2022 +0800 [feature](Nereids): Left deep tree join order. (#12439) * [feature](Nereids): Left deep tree join order. --- .../org/apache/doris/nereids/rules/RuleSet.java | 44 +++- .../rules/exploration/join/JoinCommute.java | 74 +++--- .../rules/exploration/join/JoinCommuteProject.java | 66 ------ .../rules/exploration/join/JoinLAsscom.java | 50 +++-- .../rules/exploration/join/JoinLAsscomHelper.java | 250 +++++---------------- .../rules/exploration/join/JoinLAsscomProject.java | 55 +++-- ...inCommuteHelper.java => JoinReorderCommon.java} | 36 ++- .../exploration/{ => join}/JoinReorderContext.java | 19 +- .../rules/exploration/join/ThreeJoinHelper.java | 165 ++++++++++++++ .../nereids/trees/plans/logical/LogicalJoin.java | 2 +- .../apache/doris/nereids/util/ExpressionUtils.java | 22 ++ .../rules/exploration/join/JoinCommuteTest.java | 30 ++- .../exploration/join/JoinLAsscomProjectTest.java | 134 ----------- .../rules/exploration/join/JoinLAsscomTest.java | 156 +++++-------- .../org/apache/doris/nereids/util/PlanChecker.java | 37 ++- 15 files changed, 548 insertions(+), 592 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index fb2696788c..936834605b 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -18,7 +18,8 @@ package org.apache.doris.nereids.rules; import org.apache.doris.nereids.rules.exploration.join.JoinCommute; -import org.apache.doris.nereids.rules.exploration.join.JoinCommuteProject; +import org.apache.doris.nereids.rules.exploration.join.JoinLAsscom; +import org.apache.doris.nereids.rules.exploration.join.JoinLAsscomProject; import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg; import org.apache.doris.nereids.rules.implementation.LogicalAssertNumRowsToPhysicalAssertNumRows; import org.apache.doris.nereids.rules.implementation.LogicalEmptyRelationToPhysicalEmptyRelation; @@ -32,6 +33,7 @@ import org.apache.doris.nereids.rules.implementation.LogicalProjectToPhysicalPro import org.apache.doris.nereids.rules.implementation.LogicalSortToPhysicalQuickSort; import org.apache.doris.nereids.rules.implementation.LogicalTopNToPhysicalTopN; import org.apache.doris.nereids.rules.rewrite.AggregateDisassemble; +import org.apache.doris.nereids.rules.rewrite.logical.MergeConsecutiveProjects; import com.google.common.collect.ImmutableList; import com.google.common.collect.ImmutableList.Builder; @@ -43,8 +45,10 @@ import java.util.List; */ public class RuleSet { public static final List<Rule> EXPLORATION_RULES = planRuleFactories() - .add(JoinCommute.SWAP_OUTER_SWAP_ZIG_ZAG) - .add(JoinCommuteProject.SWAP_OUTER_SWAP_ZIG_ZAG) + .add(JoinCommute.OUTER_LEFT_DEEP) + .add(JoinLAsscom.INNER) + .add(JoinLAsscomProject.INNER) + .add(new MergeConsecutiveProjects()) .build(); public static final List<Rule> REWRITE_RULES = planRuleFactories() @@ -66,6 +70,40 @@ public class RuleSet { .add(new LogicalEmptyRelationToPhysicalEmptyRelation()) .build(); + public static final List<Rule> LEFT_DEEP_TREE_JOIN_REORDER = planRuleFactories() + .add(JoinCommute.OUTER_LEFT_DEEP) + .add(JoinLAsscom.INNER) + .add(JoinLAsscomProject.INNER) + .add(JoinLAsscom.OUTER) + .add(JoinLAsscomProject.OUTER) + // semi join Transpose .... + .build(); + + public static final List<Rule> ZIG_ZAG_TREE_JOIN_REORDER = planRuleFactories() + .add(JoinCommute.OUTER_ZIG_ZAG) + .add(JoinLAsscom.INNER) + .add(JoinLAsscomProject.INNER) + .add(JoinLAsscom.OUTER) + .add(JoinLAsscomProject.OUTER) + // semi join Transpose .... + .build(); + + public static final List<Rule> BUSHY_TREE_JOIN_REORDER = planRuleFactories() + .add(JoinCommute.OUTER_BUSHY) + // TODO: add more rule + // .add(JoinLeftAssociate.INNER) + // .add(JoinLeftAssociateProject.INNER) + // .add(JoinRightAssociate.INNER) + // .add(JoinRightAssociateProject.INNER) + // .add(JoinExchange.INNER) + // .add(JoinExchangeBothProject.INNER) + // .add(JoinExchangeLeftProject.INNER) + // .add(JoinExchangeRightProject.INNER) + // .add(JoinRightAssociate.OUTER) + .add(JoinLAsscom.OUTER) + // semi join Transpose .... + .build(); + public List<Rule> getExplorationRules() { return EXPLORATION_RULES; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java index 64ebe5171f..129b655e5f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -17,54 +17,72 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.Utils; + +import java.util.ArrayList; +import java.util.List; /** * Join Commute */ -@Developing public class JoinCommute extends OneExplorationRuleFactory { - public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN); - public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG); + public static final JoinCommute OUTER_LEFT_DEEP = new JoinCommute(SwapType.LEFT_DEEP); + public static final JoinCommute OUTER_ZIG_ZAG = new JoinCommute(SwapType.ZIG_ZAG); + public static final JoinCommute OUTER_BUSHY = new JoinCommute(SwapType.BUSHY); - private final boolean swapOuter; private final SwapType swapType; - public JoinCommute(boolean swapOuter) { - this.swapOuter = swapOuter; - this.swapType = SwapType.ALL; + public JoinCommute(SwapType swapType) { + this.swapType = swapType; } - public JoinCommute(boolean swapOuter, SwapType swapType) { - this.swapOuter = swapOuter; - this.swapType = swapType; + enum SwapType { + LEFT_DEEP, ZIG_ZAG, BUSHY } @Override public Rule build() { - return innerLogicalJoin().when(JoinCommuteHelper::check).then(join -> { - // TODO: add project for mapping column output. - // List<NamedExpression> newOutput = new ArrayList<>(join.getOutput()); - LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>( - join.getJoinType(), - join.getHashJoinConjuncts(), - join.getOtherJoinCondition(), - join.right(), join.left(), - join.getJoinReorderContext()); - newJoin.getJoinReorderContext().setHasCommute(true); - // if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) { - // newJoin.getJoinReorderContext().setHasCommuteZigZag(true); - // } + return innerLogicalJoin() + .when(this::check) + .then(join -> { + LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>( + join.getJoinType(), + join.getHashJoinConjuncts(), + join.getOtherJoinCondition(), + join.right(), join.left(), + join.getJoinReorderContext()); + newJoin.getJoinReorderContext().setHasCommute(true); + if (swapType == SwapType.ZIG_ZAG && isNotBottomJoin(join)) { + newJoin.getJoinReorderContext().setHasCommuteZigZag(true); + } + + return JoinReorderCommon.project(new ArrayList<>(join.getOutput()), newJoin).get(); + }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); + } + + private boolean check(LogicalJoin<GroupPlan, GroupPlan> join) { + if (swapType == SwapType.LEFT_DEEP && isNotBottomJoin(join)) { + return false; + } + + return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange(); + } + + private boolean isNotBottomJoin(LogicalJoin<GroupPlan, GroupPlan> join) { + // TODO: tmp way to judge bottomJoin + return containJoin(join.left()) || containJoin(join.right()); + } - // LogicalProject<LogicalJoin> project = new LogicalProject<>(newOutput, newJoin); - return newJoin; - }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); + private boolean containJoin(GroupPlan groupPlan) { + // TODO: tmp way to judge containJoin + List<SlotReference> output = Utils.getOutputSlotReference(groupPlan); + return !output.stream().map(SlotReference::getQualifier).allMatch(output.get(0).getQualifier()::equals); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java deleted file mode 100644 index 07464275a1..0000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteProject.java +++ /dev/null @@ -1,66 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.exploration.join; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.rules.exploration.join.JoinCommuteHelper.SwapType; -import org.apache.doris.nereids.trees.plans.GroupPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; - -/** - * Project-Join commute - */ -public class JoinCommuteProject extends OneExplorationRuleFactory { - - public static final JoinCommute SWAP_OUTER_COMMUTE_BOTTOM_JOIN = new JoinCommute(true, SwapType.BOTTOM_JOIN); - public static final JoinCommute SWAP_OUTER_SWAP_ZIG_ZAG = new JoinCommute(true, SwapType.ZIG_ZAG); - - private final SwapType swapType; - private final boolean swapOuter; - - public JoinCommuteProject(boolean swapOuter) { - this.swapOuter = swapOuter; - this.swapType = SwapType.ALL; - } - - public JoinCommuteProject(boolean swapOuter, SwapType swapType) { - this.swapOuter = swapOuter; - this.swapType = swapType; - } - - @Override - public Rule build() { - return logicalProject(innerLogicalJoin()).when(JoinCommuteHelper::check).then(project -> { - LogicalJoin<GroupPlan, GroupPlan> join = project.child(); - LogicalJoin<GroupPlan, GroupPlan> newJoin = new LogicalJoin<>( - join.getJoinType(), - join.getHashJoinConjuncts(), - join.getOtherJoinCondition(), - join.right(), join.left(), - join.getJoinReorderContext()); - newJoin.getJoinReorderContext().setHasCommute(true); - // if (swapType == SwapType.ZIG_ZAG && !isBottomJoin(join)) { - // newJoin.getJoinReorderContext().setHasCommuteZigZag(true); - // } - - return newJoin; - }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java index 65f849cd5d..07d8acaceb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java @@ -17,18 +17,42 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import java.util.function.Predicate; + /** * Rule for change inner join LAsscom (associative and commutive). */ -@Developing public class JoinLAsscom extends OneExplorationRuleFactory { + // for inner-inner + public static final JoinLAsscom INNER = new JoinLAsscom(Type.INNER); + // for inner-leftOuter or leftOuter-leftOuter + public static final JoinLAsscom OUTER = new JoinLAsscom(Type.OUTER); + + private final Predicate<LogicalJoin<LogicalJoin<GroupPlan, GroupPlan>, GroupPlan>> typeChecker; + + private final Type type; + + /** + * Specify join type. + */ + public JoinLAsscom(Type type) { + this.type = type; + if (type == Type.INNER) { + typeChecker = join -> join.getJoinType().isInnerJoin() && join.left().getJoinType().isInnerJoin(); + } else { + typeChecker = join -> JoinLAsscomHelper.outerSet.contains( + Pair.of(join.left().getJoinType(), join.getJoinType())); + } + } + /* * topJoin newTopJoin * / \ / \ @@ -39,18 +63,14 @@ public class JoinLAsscom extends OneExplorationRuleFactory { @Override public Rule build() { return logicalJoin(logicalJoin(), group()) - .when(JoinLAsscomHelper::check) - .when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isLeftOuterJoin() - && (join.left().getJoinType().isInnerJoin() || join.left().getJoinType().isLeftOuterJoin())) - .then(topJoin -> { - - LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left(); - JoinLAsscomHelper helper = JoinLAsscomHelper.of(topJoin, bottomJoin); - if (!helper.initJoinOnCondition()) { - return null; - } - - return helper.newTopJoin(); - }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); + .when(topJoin -> JoinLAsscomHelper.check(type, topJoin, topJoin.left())) + .when(typeChecker) + .then(topJoin -> { + JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left()); + if (!helper.initJoinOnCondition()) { + return null; + } + return helper.newTopJoin(); + }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java index e6fe676406..ac31083bde 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomHelper.java @@ -18,29 +18,27 @@ package org.apache.doris.nereids.rules.exploration.join; import org.apache.doris.common.Pair; -import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type; import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.ExpressionUtils; import org.apache.doris.nereids.util.Utils; -import com.google.common.base.Preconditions; -import com.google.common.collect.Lists; +import com.google.common.collect.ImmutableSet; -import java.util.HashSet; +import java.util.ArrayList; import java.util.List; -import java.util.Optional; import java.util.Set; +import java.util.stream.Collectors; /** * Common function for JoinLAsscom */ -public class JoinLAsscomHelper { +class JoinLAsscomHelper extends ThreeJoinHelper { /* * topJoin newTopJoin * / \ / \ @@ -48,209 +46,79 @@ public class JoinLAsscomHelper { * / \ / \ * A B A C */ - private final LogicalJoin topJoin; - private final LogicalJoin<GroupPlan, GroupPlan> bottomJoin; - private final Plan a; - private final Plan b; - private final Plan c; - private final List<Expression> topHashJoinConjuncts; - private final List<Expression> bottomHashJoinConjuncts; - private final List<Expression> allNonHashJoinConjuncts = Lists.newArrayList(); - private final List<SlotReference> aOutputSlots; - private final List<SlotReference> bOutputSlots; - private final List<SlotReference> cOutputSlots; - - private final List<Expression> newBottomHashJoinConjuncts = Lists.newArrayList(); - private final List<Expression> newBottomNonHashJoinConjuncts = Lists.newArrayList(); - - private final List<Expression> newTopHashJoinConjuncts = Lists.newArrayList(); - private final List<Expression> newTopNonHashJoinConjuncts = Lists.newArrayList(); + // Pair<bottomJoin, topJoin> + // newBottomJoin Type = topJoin Type, newTopJoin Type = bottomJoin Type + public static Set<Pair<JoinType, JoinType>> outerSet = ImmutableSet.of( + Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.INNER_JOIN), + Pair.of(JoinType.INNER_JOIN, JoinType.LEFT_OUTER_JOIN), + Pair.of(JoinType.LEFT_OUTER_JOIN, JoinType.LEFT_OUTER_JOIN)); /** * Init plan and output. */ public JoinLAsscomHelper(LogicalJoin<? extends Plan, GroupPlan> topJoin, LogicalJoin<GroupPlan, GroupPlan> bottomJoin) { - this.topJoin = topJoin; - this.bottomJoin = bottomJoin; - - a = bottomJoin.left(); - b = bottomJoin.right(); - c = topJoin.right(); - - Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(), - "topJoin hashJoinConjuncts must exist."); - topHashJoinConjuncts = topJoin.getHashJoinConjuncts(); - if (topJoin.getOtherJoinCondition().isPresent()) { - allNonHashJoinConjuncts.addAll( - ExpressionUtils.extractConjunction(topJoin.getOtherJoinCondition().get())); - } - Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(), - "bottomJoin onClause must exist."); - bottomHashJoinConjuncts = bottomJoin.getHashJoinConjuncts(); - if (bottomJoin.getOtherJoinCondition().isPresent()) { - allNonHashJoinConjuncts.addAll( - ExpressionUtils.extractConjunction(bottomJoin.getOtherJoinCondition().get())); - } - - aOutputSlots = Utils.getOutputSlotReference(a); - bOutputSlots = Utils.getOutputSlotReference(b); - cOutputSlots = Utils.getOutputSlotReference(c); - } - - public static JoinLAsscomHelper of(LogicalJoin<? extends Plan, GroupPlan> topJoin, - LogicalJoin<GroupPlan, GroupPlan> bottomJoin) { - return new JoinLAsscomHelper(topJoin, bottomJoin); - } - - /** - * Get the onCondition of newTopJoin and newBottomJoin. - */ - public boolean initJoinOnCondition() { - for (Expression topJoinOnClauseConjunct : topHashJoinConjuncts) { - // Ignore join with some OnClause like: - // Join C = B + A for above example. - Set<Slot> topJoinUsedSlot = topJoinOnClauseConjunct.getInputSlots(); - if (topJoinUsedSlot.containsAll(aOutputSlots) - && topJoinUsedSlot.containsAll(bOutputSlots) - && topJoinUsedSlot.containsAll(cOutputSlots)) { - return false; - } - } - - List<Expression> allHashJoinConjuncts = Lists.newArrayList(); - allHashJoinConjuncts.addAll(topHashJoinConjuncts); - allHashJoinConjuncts.addAll(bottomHashJoinConjuncts); - - Set<Slot> newBottomJoinSlots = new HashSet<>(aOutputSlots); - newBottomJoinSlots.addAll(cOutputSlots); - - for (Expression hashConjunct : allHashJoinConjuncts) { - Set<Slot> slots = hashConjunct.getInputSlots(); - if (newBottomJoinSlots.containsAll(slots)) { - newBottomHashJoinConjuncts.add(hashConjunct); - } else { - newTopHashJoinConjuncts.add(hashConjunct); - } - } - for (Expression nonHashConjunct : allNonHashJoinConjuncts) { - Set<SlotReference> slots = nonHashConjunct.collect(SlotReference.class::isInstance); - if (newBottomJoinSlots.containsAll(slots)) { - newBottomNonHashJoinConjuncts.add(nonHashConjunct); - } else { - newTopNonHashJoinConjuncts.add(nonHashConjunct); - } - } - // newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are cross join. - // Example: - // A: col1, col2. B: col2, col3. C: col3, col4 - // (A & B on A.col2=B.col2) & C on B.col3=C.col3. - // (A & B) & C -> (A & C) & B. - // (A & C) will be cross join (newBottomJoinOnCondition is empty) - if (newBottomHashJoinConjuncts.isEmpty() || newTopHashJoinConjuncts.isEmpty()) { - return false; - } - - return true; + super(topJoin, bottomJoin, bottomJoin.left(), bottomJoin.right(), topJoin.right()); } /** - * Get projectExpr of left and right. - * Just for project-inside. + * Create newTopJoin. */ - private Pair<List<NamedExpression>, List<NamedExpression>> getProjectExprs() { - Preconditions.checkArgument(topJoin.left() instanceof LogicalProject); - LogicalProject project = (LogicalProject) topJoin.left(); - - List<NamedExpression> projectExprs = project.getProjects(); - List<NamedExpression> newRightProjectExprs = Lists.newArrayList(); - List<NamedExpression> newLeftProjectExpr = Lists.newArrayList(); - - HashSet<SlotReference> bOutputSlotsSet = new HashSet<>(bOutputSlots); - for (NamedExpression projectExpr : projectExprs) { - Set<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance); - if (bOutputSlotsSet.containsAll(usedSlotRefs)) { - newRightProjectExprs.add(projectExpr); - } else { - newLeftProjectExpr.add(projectExpr); + public Plan newTopJoin() { + Pair<List<NamedExpression>, List<NamedExpression>> projectPair = splitProjectExprs(bOutput); + List<NamedExpression> newLeftProjectExpr = projectPair.second; + List<NamedExpression> newRightProjectExprs = projectPair.first; + + // If add project to B, we should add all slotReference used by hashOnCondition. + // TODO: Does nonHashOnCondition also need to be considered. + Set<SlotReference> onUsedSlotRef = bottomJoin.getHashJoinConjuncts().stream() + .flatMap(expr -> { + Set<SlotReference> usedSlotRefs = expr.collect(SlotReference.class::isInstance); + return usedSlotRefs.stream(); + }).filter(Utils.getOutputSlotReference(bottomJoin)::contains).collect(Collectors.toSet()); + boolean existRightProject = !newRightProjectExprs.isEmpty(); + boolean existLeftProject = !newLeftProjectExpr.isEmpty(); + onUsedSlotRef.forEach(slotRef -> { + if (existRightProject && bOutput.contains(slotRef) && !newRightProjectExprs.contains(slotRef)) { + newRightProjectExprs.add(slotRef); + } else if (existLeftProject && aOutput.contains(slotRef) && !newLeftProjectExpr.contains(slotRef)) { + newLeftProjectExpr.add(slotRef); } - } - - return Pair.of(newLeftProjectExpr, newRightProjectExprs); - } + }); - private LogicalJoin<GroupPlan, GroupPlan> newBottomJoin() { - Optional<Expression> bottomNonHashExpr; - if (newBottomNonHashJoinConjuncts.isEmpty()) { - bottomNonHashExpr = Optional.empty(); - } else { - bottomNonHashExpr = Optional.of(ExpressionUtils.and(newBottomNonHashJoinConjuncts)); + if (existLeftProject) { + newLeftProjectExpr.addAll(cOutput); } - return new LogicalJoin( - bottomJoin.getJoinType(), - newBottomHashJoinConjuncts, - bottomNonHashExpr, - a, c); - } + LogicalJoin<GroupPlan, GroupPlan> newBottomJoin = new LogicalJoin<>(topJoin.getJoinType(), + newBottomHashJoinConjuncts, ExpressionUtils.andByOptional(newBottomNonHashJoinConjuncts), a, c, + bottomJoin.getJoinReorderContext()); + newBottomJoin.getJoinReorderContext().setHasLAsscom(false); + newBottomJoin.getJoinReorderContext().setHasCommute(false); - /** - * Create topJoin for project-inside. - */ - public LogicalJoin newProjectTopJoin() { - Plan left; - Plan right; + Plan left = JoinReorderCommon.project(newLeftProjectExpr, newBottomJoin).orElse(newBottomJoin); + Plan right = JoinReorderCommon.project(newRightProjectExprs, b).orElse(b); - List<NamedExpression> newLeftProjectExpr = getProjectExprs().first; - List<NamedExpression> newRightProjectExprs = getProjectExprs().second; - if (!newLeftProjectExpr.isEmpty()) { - left = new LogicalProject<>(newLeftProjectExpr, newBottomJoin()); - } else { - left = newBottomJoin(); - } - if (!newRightProjectExprs.isEmpty()) { - right = new LogicalProject<>(newRightProjectExprs, b); - } else { - right = b; - } - Optional<Expression> topNonHashExpr; - if (newTopNonHashJoinConjuncts.isEmpty()) { - topNonHashExpr = Optional.empty(); - } else { - topNonHashExpr = Optional.of(ExpressionUtils.and(newTopNonHashJoinConjuncts)); - } - return new LogicalJoin<>( - topJoin.getJoinType(), + LogicalJoin<Plan, Plan> newTopJoin = new LogicalJoin<>(bottomJoin.getJoinType(), newTopHashJoinConjuncts, - topNonHashExpr, - left, right); - } + ExpressionUtils.andByOptional(newTopNonHashJoinConjuncts), left, right, + topJoin.getJoinReorderContext()); + newTopJoin.getJoinReorderContext().setHasLAsscom(true); - /** - * Create topJoin for no-project-inside. - */ - public LogicalJoin newTopJoin() { - // TODO: add column map (use project) - // SlotReference bind() may have solved this problem. - // source: | A | B | C | - // target: | A | C | B | - Optional<Expression> topNonHashExpr; - if (newTopNonHashJoinConjuncts.isEmpty()) { - topNonHashExpr = Optional.empty(); - } else { - topNonHashExpr = Optional.of(ExpressionUtils.and(newTopNonHashJoinConjuncts)); - } - return new LogicalJoin( - topJoin.getJoinType(), - newTopHashJoinConjuncts, - topNonHashExpr, - newBottomJoin(), b); + return JoinReorderCommon.project(new ArrayList<>(topJoin.getOutput()), newTopJoin).get(); } - public static boolean check(LogicalJoin topJoin) { - if (topJoin.getJoinReorderContext().hasCommute()) { - return false; + public static boolean check(Type type, LogicalJoin<? extends Plan, GroupPlan> topJoin, + LogicalJoin<GroupPlan, GroupPlan> bottomJoin) { + if (type == Type.INNER) { + return !bottomJoin.getJoinReorderContext().hasCommuteZigZag() + && !topJoin.getJoinReorderContext().hasLAsscom(); + } else { + // hasCommute will cause to lack of OuterJoinAssocRule:Left + return !topJoin.getJoinReorderContext().hasLeftAssociate() + && !topJoin.getJoinReorderContext().hasRightAssociate() + && !topJoin.getJoinReorderContext().hasExchange() + && !bottomJoin.getJoinReorderContext().hasCommute(); } - return true; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java index 9876ed29fd..5bbd120b52 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProject.java @@ -17,18 +17,43 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.common.Pair; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.rules.exploration.join.JoinReorderCommon.Type; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +import java.util.function.Predicate; /** - * Rule for change inner join left associative to right. + * Rule for change inner join LAsscom (associative and commutive). */ -@Developing public class JoinLAsscomProject extends OneExplorationRuleFactory { + // for inner-inner + public static final JoinLAsscomProject INNER = new JoinLAsscomProject(Type.INNER); + // for inner-leftOuter or leftOuter-leftOuter + public static final JoinLAsscomProject OUTER = new JoinLAsscomProject(Type.OUTER); + + private final Predicate<LogicalJoin<LogicalProject<LogicalJoin<GroupPlan, GroupPlan>>, GroupPlan>> typeChecker; + + private final Type type; + + /** + * Specify join type. + */ + public JoinLAsscomProject(Type type) { + this.type = type; + if (type == Type.INNER) { + typeChecker = join -> join.getJoinType().isInnerJoin() && join.left().child().getJoinType().isInnerJoin(); + } else { + typeChecker = join -> JoinLAsscomHelper.outerSet.contains( + Pair.of(join.left().child().getJoinType(), join.getJoinType())); + } + } + /* * topJoin newTopJoin * / \ / \ @@ -41,19 +66,15 @@ public class JoinLAsscomProject extends OneExplorationRuleFactory { @Override public Rule build() { return logicalJoin(logicalProject(logicalJoin()), group()) - .when(JoinLAsscomHelper::check) - .when(join -> join.getJoinType().isInnerJoin() || join.getJoinType().isLeftOuterJoin() - && (join.left().child().getJoinType().isInnerJoin() || join.left().child().getJoinType() - .isLeftOuterJoin())) - .then(topJoin -> { - LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left().child(); - - JoinLAsscomHelper helper = JoinLAsscomHelper.of(topJoin, bottomJoin); - if (!helper.initJoinOnCondition()) { - return null; - } - - return helper.newProjectTopJoin(); - }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); + .when(topJoin -> JoinLAsscomHelper.check(type, topJoin, topJoin.left().child())) + .when(typeChecker) + .then(topJoin -> { + JoinLAsscomHelper helper = new JoinLAsscomHelper(topJoin, topJoin.left().child()); + helper.initAllProject(topJoin.left()); + if (!helper.initJoinOnCondition()) { + return null; + } + return helper.newTopJoin(); + }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java similarity index 54% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java index 47da5030e1..2a4f81138a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteHelper.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderCommon.java @@ -17,32 +17,24 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.nereids.trees.plans.GroupPlan; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -/** - * Common function for JoinCommute - */ -public class JoinCommuteHelper { +import java.util.List; +import java.util.Optional; - enum SwapType { - BOTTOM_JOIN, ZIG_ZAG, ALL +class JoinReorderCommon { + public enum Type { + INNER, + OUTER } - private final boolean swapOuter; - private final SwapType swapType; - - public JoinCommuteHelper(boolean swapOuter, SwapType swapType) { - this.swapOuter = swapOuter; - this.swapType = swapType; - } - - public static boolean check(LogicalJoin<GroupPlan, GroupPlan> join) { - return !join.getJoinReorderContext().hasCommute() && !join.getJoinReorderContext().hasExchange(); - } - - public static boolean check(LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project) { - return check(project.child()); + public static Optional<Plan> project(List<NamedExpression> projectExprs, Plan plan) { + if (!projectExprs.isEmpty()) { + return Optional.of(new LogicalProject<>(projectExprs, plan)); + } else { + return Optional.empty(); + } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderContext.java similarity index 87% rename from fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java rename to fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderContext.java index 8384934d42..44166b625f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinReorderContext.java @@ -15,7 +15,7 @@ // specific language governing permissions and limitations // under the License. -package org.apache.doris.nereids.rules.exploration; +package org.apache.doris.nereids.rules.exploration.join; /** * JoinReorderContext for Duplicate free. @@ -26,6 +26,7 @@ package org.apache.doris.nereids.rules.exploration; public class JoinReorderContext { // left deep tree private boolean hasCommute = false; + private boolean hasLAsscom = false; // zig-zag tree private boolean hasCommuteZigZag = false; @@ -38,16 +39,24 @@ public class JoinReorderContext { public JoinReorderContext() { } + /** + * copy a JoinReorderContext. + */ public void copyFrom(JoinReorderContext joinReorderContext) { this.hasCommute = joinReorderContext.hasCommute; + this.hasLAsscom = joinReorderContext.hasLAsscom; this.hasExchange = joinReorderContext.hasExchange; this.hasLeftAssociate = joinReorderContext.hasLeftAssociate; this.hasRightAssociate = joinReorderContext.hasRightAssociate; this.hasCommuteZigZag = joinReorderContext.hasCommuteZigZag; } + /** + * clear all. + */ public void clear() { hasCommute = false; + hasLAsscom = false; hasCommuteZigZag = false; hasExchange = false; hasRightAssociate = false; @@ -62,6 +71,14 @@ public class JoinReorderContext { this.hasCommute = hasCommute; } + public boolean hasLAsscom() { + return hasLAsscom; + } + + public void setHasLAsscom(boolean hasLAsscom) { + this.hasLAsscom = hasLAsscom; + } + public boolean hasExchange() { return hasExchange; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java new file mode 100644 index 0000000000..fdf70f2c05 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/ThreeJoinHelper.java @@ -0,0 +1,165 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.Slot; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.Utils; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; + +import java.util.HashSet; +import java.util.List; +import java.util.Set; + +/** + * Common join helper for three-join. + */ +abstract class ThreeJoinHelper { + protected final LogicalJoin<? extends Plan, ? extends Plan> topJoin; + protected final LogicalJoin<GroupPlan, GroupPlan> bottomJoin; + protected final GroupPlan a; + protected final GroupPlan b; + protected final GroupPlan c; + + protected final List<SlotReference> aOutput; + protected final List<SlotReference> bOutput; + protected final List<SlotReference> cOutput; + + protected final List<NamedExpression> allProjects = Lists.newArrayList(); + + protected final List<Expression> allHashJoinConjuncts = Lists.newArrayList(); + protected final List<Expression> allNonHashJoinConjuncts = Lists.newArrayList(); + + protected final List<Expression> newBottomHashJoinConjuncts = Lists.newArrayList(); + protected final List<Expression> newBottomNonHashJoinConjuncts = Lists.newArrayList(); + + protected final List<Expression> newTopHashJoinConjuncts = Lists.newArrayList(); + protected final List<Expression> newTopNonHashJoinConjuncts = Lists.newArrayList(); + + /** + * Init plan and output. + */ + public ThreeJoinHelper(LogicalJoin<? extends Plan, ? extends Plan> topJoin, + LogicalJoin<GroupPlan, GroupPlan> bottomJoin, GroupPlan a, GroupPlan b, GroupPlan c) { + this.topJoin = topJoin; + this.bottomJoin = bottomJoin; + this.a = a; + this.b = b; + this.c = c; + + aOutput = Utils.getOutputSlotReference(a); + bOutput = Utils.getOutputSlotReference(b); + cOutput = Utils.getOutputSlotReference(c); + + Preconditions.checkArgument(!topJoin.getHashJoinConjuncts().isEmpty(), "topJoin hashJoinConjuncts must exist."); + Preconditions.checkArgument(!bottomJoin.getHashJoinConjuncts().isEmpty(), + "bottomJoin hashJoinConjuncts must exist."); + + allHashJoinConjuncts.addAll(topJoin.getHashJoinConjuncts()); + allHashJoinConjuncts.addAll(bottomJoin.getHashJoinConjuncts()); + topJoin.getOtherJoinCondition().ifPresent(otherJoinCondition -> allNonHashJoinConjuncts.addAll( + ExpressionUtils.extractConjunction(otherJoinCondition))); + bottomJoin.getOtherJoinCondition().ifPresent(otherJoinCondition -> allNonHashJoinConjuncts.addAll( + ExpressionUtils.extractConjunction(otherJoinCondition))); + } + + @SafeVarargs + public final void initAllProject(LogicalProject<? extends Plan>... projects) { + for (LogicalProject<? extends Plan> project : projects) { + allProjects.addAll(project.getProjects()); + } + } + + /** + * Get the onCondition of newTopJoin and newBottomJoin. + */ + public boolean initJoinOnCondition() { + // Ignore join with some OnClause like: + // Join C = B + A for above example. + // TODO: also need for otherJoinCondition + for (Expression topJoinOnClauseConjunct : topJoin.getHashJoinConjuncts()) { + Set<SlotReference> topJoinUsedSlot = topJoinOnClauseConjunct.collect(SlotReference.class::isInstance); + if (ExpressionUtils.isIntersecting(topJoinUsedSlot, aOutput) && ExpressionUtils.isIntersecting( + topJoinUsedSlot, bOutput) && ExpressionUtils.isIntersecting(topJoinUsedSlot, cOutput)) { + return false; + } + } + + Set<Slot> newBottomJoinSlots = new HashSet<>(aOutput); + newBottomJoinSlots.addAll(cOutput); + for (Expression hashConjunct : allHashJoinConjuncts) { + Set<SlotReference> slots = hashConjunct.collect(SlotReference.class::isInstance); + if (newBottomJoinSlots.containsAll(slots)) { + newBottomHashJoinConjuncts.add(hashConjunct); + } else { + newTopHashJoinConjuncts.add(hashConjunct); + } + } + for (Expression nonHashConjunct : allNonHashJoinConjuncts) { + Set<SlotReference> slots = nonHashConjunct.collect(SlotReference.class::isInstance); + if (newBottomJoinSlots.containsAll(slots)) { + newBottomNonHashJoinConjuncts.add(nonHashConjunct); + } else { + newTopNonHashJoinConjuncts.add(nonHashConjunct); + } + } + // newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are cross join. + // Example: + // A: col1, col2. B: col2, col3. C: col3, col4 + // (A & B on A.col2=B.col2) & C on B.col3=C.col3. + // (A & B) & C -> (A & C) & B. + // (A & C) will be cross join (newBottomJoinOnCondition is empty) + if (newBottomHashJoinConjuncts.isEmpty() || newTopHashJoinConjuncts.isEmpty()) { + return false; + } + + return true; + } + + /** + * Split inside-project into two part. + * + * @param topJoinChild output of topJoin groupPlan child. + */ + protected Pair<List<NamedExpression>, List<NamedExpression>> splitProjectExprs(List<SlotReference> topJoinChild) { + List<NamedExpression> newTopJoinChildProjectExprs = Lists.newArrayList(); + List<NamedExpression> newBottomJoinProjectExprs = Lists.newArrayList(); + + HashSet<SlotReference> topJoinOutputSlotsSet = new HashSet<>(topJoinChild); + + for (NamedExpression projectExpr : allProjects) { + Set<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance); + if (topJoinOutputSlotsSet.containsAll(usedSlotRefs)) { + newTopJoinChildProjectExprs.add(projectExpr); + } else { + newBottomJoinProjectExprs.add(projectExpr); + } + } + return Pair.of(newTopJoinChildProjectExprs, newBottomJoinProjectExprs); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 12052cbeb5..61f34a84c3 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -19,7 +19,7 @@ package org.apache.doris.nereids.trees.plans.logical; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.properties.LogicalProperties; -import org.apache.doris.nereids.rules.exploration.JoinReorderContext; +import org.apache.doris.nereids.rules.exploration.join.JoinReorderContext; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.plans.JoinType; diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 7f22a4b2fb..a716e3e43f 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.trees.expressions.And; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import com.google.common.base.Preconditions; @@ -29,6 +30,7 @@ import com.google.common.collect.Sets; import java.util.List; import java.util.Objects; +import java.util.Optional; import java.util.Set; /** @@ -76,6 +78,14 @@ public class ExpressionUtils { } } + public static Optional<Expression> andByOptional(List<Expression> expressions) { + if (expressions.isEmpty()) { + return Optional.empty(); + } else { + return Optional.of(ExpressionUtils.and(expressions)); + } + } + public static Expression and(List<Expression> expressions) { return combine(And.class, expressions); } @@ -120,4 +130,16 @@ public class ExpressionUtils { .reduce(type == And.class ? And::new : Or::new) .orElse(new BooleanLiteral(type == And.class)); } + + /** + * Check whether lhs and rhs are intersecting. + */ + public static boolean isIntersecting(Set<SlotReference> lhs, List<SlotReference> rhs) { + for (SlotReference rh : rhs) { + if (lhs.contains(rh)) { + return true; + } + } + return false; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java index 27f7cbeae7..46e31ead85 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java @@ -17,8 +17,8 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.SlotReference; @@ -26,8 +26,10 @@ import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.types.BigIntType; import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import com.google.common.collect.ImmutableList; @@ -35,7 +37,6 @@ import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; import org.junit.jupiter.api.Test; -import java.util.List; import java.util.Optional; public class JoinCommuteTest { @@ -51,14 +52,23 @@ public class JoinCommuteTest { JoinType.INNER_JOIN, Lists.newArrayList(onCondition), Optional.empty(), scan1, scan2); - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(join); - Rule rule = new JoinCommute(true).build(); + PlanChecker.from(MemoTestUtils.createConnectContext(), join) + .transform(JoinCommute.OUTER_LEFT_DEEP.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(2, root.getLogicalExpressions().size()); - List<Plan> transform = rule.transform(join, cascadesContext); - Assertions.assertEquals(1, transform.size()); - Plan newJoin = transform.get(0); + Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin); + Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject); - Assertions.assertEquals(join.child(0), newJoin.child(1)); - Assertions.assertEquals(join.child(1), newJoin.child(0)); + GroupExpression newJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression(); + Plan left = newJoinGroupExpr.child(0).getLogicalExpression().getPlan(); + Plan right = newJoinGroupExpr.child(1).getLogicalExpression().getPlan(); + Assertions.assertTrue(left instanceof LogicalOlapScan); + Assertions.assertTrue(right instanceof LogicalOlapScan); + + Assertions.assertEquals("t2", ((LogicalOlapScan) left).getTable().getName()); + Assertions.assertEquals("t1", ((LogicalOlapScan) right).getTable().getName()); + }); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java deleted file mode 100644 index cb70125c43..0000000000 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomProjectTest.java +++ /dev/null @@ -1,134 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.exploration.join; - -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.trees.expressions.Alias; -import org.apache.doris.nereids.trees.expressions.EqualTo; -import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.NamedExpression; -import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.Plan; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; -import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; -import org.apache.doris.nereids.trees.plans.logical.LogicalProject; -import org.apache.doris.nereids.util.MemoTestUtils; -import org.apache.doris.nereids.util.PlanConstructor; -import org.apache.doris.nereids.util.Utils; - -import com.google.common.collect.ImmutableList; -import com.google.common.collect.Lists; -import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; -import org.junit.jupiter.api.Test; - -import java.util.List; -import java.util.Optional; - -public class JoinLAsscomProjectTest { - - private static final List<LogicalOlapScan> scans = Lists.newArrayList(); - private static final List<List<SlotReference>> outputs = Lists.newArrayList(); - - @BeforeAll - public static void init() { - LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); - LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); - - scans.add(scan1); - scans.add(scan2); - scans.add(scan3); - - List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1); - List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2); - List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3); - - outputs.add(t1Output); - outputs.add(t2Output); - outputs.add(t3Output); - } - - private Pair<LogicalJoin, LogicalJoin> testJoinProjectLAsscom(List<NamedExpression> projects) { - /* - * topJoin newTopJoin - * / \ / \ - * project C newLeftProject newRightProject - * / ──► / \ - * bottomJoin newBottomJoin B - * / \ / \ - * A B A C - */ - - Assertions.assertEquals(3, scans.size()); - - List<SlotReference> t1 = outputs.get(0); - List<SlotReference> t2 = outputs.get(1); - List<SlotReference> t3 = outputs.get(2); - Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0)); - Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1)); - - LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>> project = new LogicalProject<>( - projects, - new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(bottomJoinOnCondition), - Optional.empty(), scans.get(0), scans.get(1))); - - LogicalJoin<LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>>, LogicalOlapScan> topJoin - = new LogicalJoin<>(JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition), - Optional.empty(), project, scans.get(2)); - - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(topJoin); - Rule rule = new JoinLAsscomProject().build(); - List<Plan> transform = rule.transform(topJoin, cascadesContext); - Assertions.assertEquals(1, transform.size()); - Assertions.assertTrue(transform.get(0) instanceof LogicalJoin); - LogicalJoin newTopJoin = (LogicalJoin) transform.get(0); - return Pair.of(topJoin, newTopJoin); - } - - @Test - public void testStarJoinProjectLAsscom() { - List<SlotReference> t1 = outputs.get(0); - List<SlotReference> t2 = outputs.get(1); - List<NamedExpression> projects = ImmutableList.of( - new Alias(t2.get(0), "t2.id"), - new Alias(t1.get(0), "t1.id"), - t1.get(1), - t2.get(1) - ); - - Pair<LogicalJoin, LogicalJoin> pair = testJoinProjectLAsscom(projects); - - LogicalJoin oldJoin = pair.first; - LogicalJoin newTopJoin = pair.second; - - // Join reorder successfully. - Assertions.assertNotEquals(oldJoin, newTopJoin); - Assertions.assertEquals("t1.id", ((Alias) ((LogicalProject) newTopJoin.left()).getProjects().get(0)).getName()); - Assertions.assertEquals("name", - ((SlotReference) ((LogicalProject) newTopJoin.left()).getProjects().get(1)).getName()); - Assertions.assertEquals("t2.id", - ((Alias) ((LogicalProject) newTopJoin.right()).getProjects().get(0)).getName()); - Assertions.assertEquals("name", - ((SlotReference) ((LogicalProject) newTopJoin.left()).getProjects().get(1)).getName()); - - } -} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java index 6ed20fa125..7d7f1d8b05 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java @@ -17,24 +17,23 @@ package org.apache.doris.nereids.rules.exploration.join; -import org.apache.doris.common.Pair; -import org.apache.doris.nereids.CascadesContext; -import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; -import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.JoinType; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; import org.apache.doris.nereids.util.MemoTestUtils; +import org.apache.doris.nereids.util.PlanChecker; import org.apache.doris.nereids.util.PlanConstructor; import org.apache.doris.nereids.util.Utils; import com.google.common.collect.Lists; import org.junit.jupiter.api.Assertions; -import org.junit.jupiter.api.BeforeAll; import org.junit.jupiter.api.Test; import java.util.List; @@ -42,55 +41,13 @@ import java.util.Optional; public class JoinLAsscomTest { - private static List<LogicalOlapScan> scans = Lists.newArrayList(); - private static List<List<SlotReference>> outputs = Lists.newArrayList(); - - @BeforeAll - public static void init() { - LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); - LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); - LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); - - scans.add(scan1); - scans.add(scan2); - scans.add(scan3); - - List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1); - List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2); - List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3); - outputs.add(t1Output); - outputs.add(t2Output); - outputs.add(t3Output); - } + private final LogicalOlapScan scan1 = PlanConstructor.newLogicalOlapScan(0, "t1", 0); + private final LogicalOlapScan scan2 = PlanConstructor.newLogicalOlapScan(1, "t2", 0); + private final LogicalOlapScan scan3 = PlanConstructor.newLogicalOlapScan(2, "t3", 0); - public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom( - Expression bottomJoinOnCondition, - Expression bottomNonHashExpression, - Expression topJoinOnCondition, - Expression topNonHashExpression) { - /* - * topJoin newTopJoin - * / \ / \ - * bottomJoin C --> newBottomJoin B - * / \ / \ - * A B A C - */ - Assertions.assertEquals(3, scans.size()); - LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, - Lists.newArrayList(bottomJoinOnCondition), - Optional.of(bottomNonHashExpression), scans.get(0), scans.get(1)); - LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>( - JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition), - Optional.of(topNonHashExpression), bottomJoin, scans.get(2)); - - CascadesContext cascadesContext = MemoTestUtils.createCascadesContext(topJoin); - Rule rule = new JoinLAsscom().build(); - List<Plan> transform = rule.transform(topJoin, cascadesContext); - Assertions.assertEquals(1, transform.size()); - Assertions.assertTrue(transform.get(0) instanceof LogicalJoin); - LogicalJoin newTopJoin = (LogicalJoin) transform.get(0); - return Pair.of(topJoin, newTopJoin); - } + private final List<SlotReference> t1Output = Utils.getOutputSlotReference(scan1); + private final List<SlotReference> t2Output = Utils.getOutputSlotReference(scan2); + private final List<SlotReference> t3Output = Utils.getOutputSlotReference(scan3); @Test public void testStarJoinLAsscom() { @@ -109,31 +66,35 @@ public class JoinLAsscomTest { * t1 t2 t1 t3 */ - List<SlotReference> t1 = outputs.get(0); - List<SlotReference> t2 = outputs.get(1); - List<SlotReference> t3 = outputs.get(2); - Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0)); - Expression bottomNonHashExpression = new LessThan(t1.get(0), t2.get(0)); - Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1)); - Expression topNonHashCondition = new LessThan(t1.get(1), t3.get(1)); - - Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom( - bottomJoinOnCondition, - bottomNonHashExpression, - topJoinOnCondition, - topNonHashCondition); - LogicalJoin oldJoin = pair.first; - LogicalJoin newTopJoin = pair.second; - - // Join reorder successfully. - Assertions.assertNotEquals(oldJoin, newTopJoin); - Assertions.assertEquals("t1", - ((LogicalOlapScan) ((LogicalJoin) newTopJoin.left()).left()).getTable().getName()); - Assertions.assertEquals("t3", - ((LogicalOlapScan) ((LogicalJoin) newTopJoin.left()).right()).getTable().getName()); - Assertions.assertEquals("t2", ((LogicalOlapScan) newTopJoin.right()).getTable().getName()); - Assertions.assertEquals(newTopJoin.getOtherJoinCondition(), - ((LogicalJoin) oldJoin.child(0)).getOtherJoinCondition()); + Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0)); + Expression topJoinOnCondition = new EqualTo(t1Output.get(1), t3Output.get(1)); + + LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, + Lists.newArrayList(bottomJoinOnCondition), + Optional.empty(), scan1, scan2); + LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>( + JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition), + Optional.empty(), bottomJoin, scan3); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(JoinLAsscom.INNER.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + Assertions.assertEquals(2, root.getLogicalExpressions().size()); + + Assertions.assertTrue(root.logicalExpressionsAt(0).getPlan() instanceof LogicalJoin); + Assertions.assertTrue(root.logicalExpressionsAt(1).getPlan() instanceof LogicalProject); + + GroupExpression newTopJoinGroupExpr = root.logicalExpressionsAt(1).child(0).getLogicalExpression(); + GroupExpression newBottomJoinGroupExpr = newTopJoinGroupExpr.child(0).getLogicalExpression(); + Plan bottomLeft = newBottomJoinGroupExpr.child(0).getLogicalExpression().getPlan(); + Plan bottomRight = newBottomJoinGroupExpr.child(1).getLogicalExpression().getPlan(); + Plan right = newTopJoinGroupExpr.child(1).getLogicalExpression().getPlan(); + + Assertions.assertEquals("t1", ((LogicalOlapScan) bottomLeft).getTable().getName()); + Assertions.assertEquals("t3", ((LogicalOlapScan) bottomRight).getTable().getName()); + Assertions.assertEquals("t2", ((LogicalOlapScan) right).getTable().getName()); + }); } @Test @@ -151,27 +112,22 @@ public class JoinLAsscomTest { * t1 t2 t1 t3 */ - List<SlotReference> t1 = outputs.get(0); - List<SlotReference> t2 = outputs.get(1); - List<SlotReference> t3 = outputs.get(2); - Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0)); - Expression bottomNonHashExpression = new LessThan(t1.get(0), t2.get(0)); - Expression topJoinOnCondition = new EqualTo(t2.get(0), t3.get(0)); - Expression topNonHashExpression = new LessThan(t2.get(0), t3.get(0)); - - Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(bottomJoinOnCondition, bottomNonHashExpression, - topJoinOnCondition, topNonHashExpression); - LogicalJoin oldJoin = pair.first; - LogicalJoin newTopJoin = pair.second; - - // Join reorder failed. - // Chain-Join LAsscom directly will be failed. - // After t1 -- t2 -- t3 - // -- join commute --> - // t1 -- t2 - // | - // t3 - // then, we can LAsscom for this star-join. - Assertions.assertEquals(oldJoin, newTopJoin); + Expression bottomJoinOnCondition = new EqualTo(t1Output.get(0), t2Output.get(0)); + Expression topJoinOnCondition = new EqualTo(t2Output.get(0), t3Output.get(0)); + LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, + Lists.newArrayList(bottomJoinOnCondition), + Optional.empty(), scan1, scan2); + LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>( + JoinType.INNER_JOIN, Lists.newArrayList(topJoinOnCondition), + Optional.empty(), bottomJoin, scan3); + + PlanChecker.from(MemoTestUtils.createConnectContext(), topJoin) + .transform(JoinLAsscom.INNER.build()) + .checkMemo(memo -> { + Group root = memo.getRoot(); + + // TODO: need infer onCondition. + Assertions.assertEquals(1, root.getLogicalExpressions().size()); + }); } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java index 0f43ed09eb..3b03d468a6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/PlanChecker.java @@ -23,7 +23,6 @@ import org.apache.doris.nereids.memo.Group; import org.apache.doris.nereids.memo.GroupExpression; import org.apache.doris.nereids.memo.Memo; import org.apache.doris.nereids.pattern.GroupExpressionMatching; -import org.apache.doris.nereids.pattern.GroupExpressionMatching.GroupExpressionIterator; import org.apache.doris.nereids.pattern.MatchingContext; import org.apache.doris.nereids.pattern.PatternDescriptor; import org.apache.doris.nereids.pattern.PatternMatcher; @@ -145,10 +144,8 @@ public class PlanChecker { public PlanChecker transform(GroupExpression groupExpression, PatternMatcher patternMatcher) { GroupExpressionMatching matchResult = new GroupExpressionMatching(patternMatcher.pattern, groupExpression); - GroupExpressionIterator iterator = matchResult.iterator(); - while (iterator.hasNext()) { - Plan before = iterator.next(); + for (Plan before : matchResult) { Plan after = patternMatcher.matchedAction.apply( new MatchingContext(before, patternMatcher.pattern, cascadesContext)); if (before != after) { @@ -162,6 +159,38 @@ public class PlanChecker { return this; } + public PlanChecker transform(Rule rule) { + return transform(cascadesContext.getMemo().getRoot(), rule); + } + + public PlanChecker transform(Group group, Rule rule) { + // copy groupExpressions can prevent ConcurrentModificationException + for (GroupExpression logicalExpression : Lists.newArrayList(group.getLogicalExpressions())) { + transform(logicalExpression, rule); + } + + for (GroupExpression physicalExpression : Lists.newArrayList(group.getPhysicalExpressions())) { + transform(physicalExpression, rule); + } + return this; + } + + public PlanChecker transform(GroupExpression groupExpression, Rule rule) { + GroupExpressionMatching matchResult = new GroupExpressionMatching(rule.getPattern(), groupExpression); + + for (Plan before : matchResult) { + Plan after = rule.transform(before, cascadesContext).get(0); + if (before != after) { + cascadesContext.getMemo().copyIn(after, before.getGroupExpression().get().getOwnerGroup(), false); + } + } + + for (Group childGroup : groupExpression.children()) { + transform(childGroup, rule); + } + return this; + } + public PlanChecker matchesFromRoot(PatternDescriptor<? extends Plan> patternDesc) { Memo memo = cascadesContext.getMemo(); assertMatches(memo, () -> new GroupExpressionMatching(patternDesc.pattern, --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org