This is an automated email from the ASF dual-hosted git repository. huajianlan pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 846716ac10 [feature](nereids): join reorder (#10479) 846716ac10 is described below commit 846716ac10b701646f8f2b1c73708a4f97dd1c87 Author: jakevin <jakevin...@gmail.com> AuthorDate: Tue Jul 26 15:35:00 2022 +0800 [feature](nereids): join reorder (#10479) Enhance join reorder. Add LAsscom (include with project). Add Commute. Add UT for join reorder --- .../org/apache/doris/analysis/PredicateUtils.java | 2 +- .../glue/translator/PhysicalPlanTranslator.java | 4 +- .../nereids/properties/LogicalProperties.java | 2 +- .../org/apache/doris/nereids/rules/RuleSet.java | 6 +- .../rules/exploration/JoinReorderContext.java | 68 ++++++-- .../rules/exploration/join/JoinCommutative.java | 59 ------- .../rules/exploration/join/JoinCommute.java | 102 ++++++++++++ .../rules/exploration/join/JoinExchange.java | 2 + .../rules/exploration/join/JoinLAsscom.java | 123 ++++++++++++-- .../exploration/join/JoinLeftAssociative.java | 54 ------ .../rules/exploration/join/JoinProjectLAsscom.java | 185 +++++++++++++++++++++ .../rewrite/logical/PushPredicateThroughJoin.java | 54 +++--- .../nereids/rules/rewrite/logical/ReorderJoin.java | 4 +- .../nereids/trees/expressions/SlotReference.java | 1 + .../nereids/trees/plans/logical/LogicalJoin.java | 6 + .../apache/doris/nereids/util/ExpressionUtils.java | 61 ++++++- .../rules/exploration/join/JoinCommuteTest.java | 70 ++++++++ .../rules/exploration/join/JoinLAsscomTest.java | 176 ++++++++++++++++++++ .../exploration/join/JoinProjectLAsscomTest.java | 148 +++++++++++++++++ .../doris/nereids/util/ExpressionUtilsTest.java | 12 +- 20 files changed, 961 insertions(+), 178 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/PredicateUtils.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/PredicateUtils.java index f6d058d450..7d8b10e22a 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/PredicateUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/PredicateUtils.java @@ -30,7 +30,7 @@ public class PredicateUtils { * Some examples: * a or b -> a, b * a or b or c -> a, b, c - * (a and b) or (c or d) -> (a and b), (c and d) + * (a and b) or (c or d) -> (a and b), c, d * (a or b) and c -> (a or b) and c * a -> a */ diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java index b314138116..47f9748263 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/glue/translator/PhysicalPlanTranslator.java @@ -348,7 +348,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla throw new RuntimeException("Physical hash join could not execute without equal join condition."); } else { Expression eqJoinExpression = hashJoin.getCondition().get(); - List<Expr> execEqConjunctList = ExpressionUtils.extractConjunct(eqJoinExpression).stream() + List<Expr> execEqConjunctList = ExpressionUtils.extractConjunctive(eqJoinExpression).stream() .map(EqualTo.class::cast) .map(e -> swapEqualToForChildrenOrder(e, hashJoin.left().getOutput())) .map(e -> ExpressionTranslator.translate(e, context)) @@ -400,7 +400,7 @@ public class PhysicalPlanTranslator extends DefaultPlanVisitor<PlanFragment, Pla PlanFragment inputFragment = filter.child(0).accept(this, context); PlanNode planNode = inputFragment.getPlanRoot(); Expression expression = filter.getPredicates(); - List<Expression> expressionList = ExpressionUtils.extractConjunct(expression); + List<Expression> expressionList = ExpressionUtils.extractConjunctive(expression); expressionList.stream().map(e -> ExpressionTranslator.translate(e, context)).forEach(planNode::addConjunct); return inputFragment; } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java index c7fd2b66b3..443c8448ef 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/LogicalProperties.java @@ -39,7 +39,7 @@ public class LogicalProperties { */ public LogicalProperties(Supplier<List<Slot>> outputSupplier) { this.outputSupplier = Suppliers.memoize( - Objects.requireNonNull(outputSupplier, "outputSupplier can not be null") + Objects.requireNonNull(outputSupplier, "outputSupplier can not be null") ); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java index f9827431df..0ca1eb8ea4 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleSet.java @@ -17,8 +17,7 @@ package org.apache.doris.nereids.rules; -import org.apache.doris.nereids.rules.exploration.join.JoinCommutative; -import org.apache.doris.nereids.rules.exploration.join.JoinLeftAssociative; +import org.apache.doris.nereids.rules.exploration.join.JoinCommute; import org.apache.doris.nereids.rules.implementation.LogicalAggToPhysicalHashAgg; import org.apache.doris.nereids.rules.implementation.LogicalFilterToPhysicalFilter; import org.apache.doris.nereids.rules.implementation.LogicalJoinToHashJoin; @@ -37,8 +36,7 @@ import java.util.List; */ public class RuleSet { public static final List<Rule> EXPLORATION_RULES = planRuleFactories() - .add(new JoinCommutative(false)) - .add(new JoinLeftAssociative()) + .add(new JoinCommute(true)) .build(); public static final List<Rule> REWRITE_RULES = planRuleFactories() diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java index 170c286351..c5ac5d8d18 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/JoinReorderContext.java @@ -20,32 +20,78 @@ package org.apache.doris.nereids.rules.exploration; /** * JoinReorderContext for Duplicate free. - * Paper: Improving Join Reorderability with Compensation Operators + * Paper: + * - Optimizing Join Enumeration in Transformation-based Query Optimizers + * - Improving Join Reorderability with Compensation Operators */ public class JoinReorderContext { // left deep tree private boolean hasCommute = false; - private boolean hasTopPushThrough = false; + + // zig-zag tree + private boolean hasCommuteZigZag = false; + + // bushy tree + private boolean hasExchange = false; + private boolean hasRightAssociate = false; + private boolean hasLeftAssociate = false; public JoinReorderContext() { } - void copyFrom(JoinReorderContext joinReorderContext) { + public void copyFrom(JoinReorderContext joinReorderContext) { this.hasCommute = joinReorderContext.hasCommute; - this.hasTopPushThrough = joinReorderContext.hasTopPushThrough; + this.hasExchange = joinReorderContext.hasExchange; + this.hasLeftAssociate = joinReorderContext.hasLeftAssociate; + this.hasRightAssociate = joinReorderContext.hasRightAssociate; + this.hasCommuteZigZag = joinReorderContext.hasCommuteZigZag; } - JoinReorderContext copy() { - JoinReorderContext joinReorderContext = new JoinReorderContext(); - joinReorderContext.copyFrom(this); - return joinReorderContext; + public void clear() { + hasCommute = false; + hasCommuteZigZag = false; + hasExchange = false; + hasRightAssociate = false; + hasLeftAssociate = false; } - public boolean isHasCommute() { + public boolean hasCommute() { return hasCommute; } - public boolean isHasTopPushThrough() { - return hasTopPushThrough; + public void setHasCommute(boolean hasCommute) { + this.hasCommute = hasCommute; + } + + public boolean hasExchange() { + return hasExchange; + } + + public void setHasExchange(boolean hasExchange) { + this.hasExchange = hasExchange; + } + + public boolean hasRightAssociate() { + return hasRightAssociate; + } + + public void setHasRightAssociate(boolean hasRightAssociate) { + this.hasRightAssociate = hasRightAssociate; + } + + public boolean hasLeftAssociate() { + return hasLeftAssociate; + } + + public void setHasLeftAssociate(boolean hasLeftAssociate) { + this.hasLeftAssociate = hasLeftAssociate; + } + + public boolean hasCommuteZigZag() { + return hasCommuteZigZag; + } + + public void setHasCommuteZigZag(boolean hasCommuteZigZag) { + this.hasCommuteZigZag = hasCommuteZigZag; } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommutative.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommutative.java deleted file mode 100644 index 701a6984e1..0000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommutative.java +++ /dev/null @@ -1,59 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.exploration.join; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; - -/** - * rule factory for exchange inner join's children. - */ -public class JoinCommutative extends OneExplorationRuleFactory { - private boolean justApplyInnerOuterCrossJoin = false; - - private final SwapType swapType; - - /** - * If param is true, just apply rule in inner/full-outer/cross join. - */ - public JoinCommutative(boolean justApplyInnerOuterCrossJoin) { - this.justApplyInnerOuterCrossJoin = justApplyInnerOuterCrossJoin; - this.swapType = SwapType.ALL; - } - - public JoinCommutative(boolean justApplyInnerOuterCrossJoin, SwapType swapType) { - this.justApplyInnerOuterCrossJoin = justApplyInnerOuterCrossJoin; - this.swapType = swapType; - } - - enum SwapType { - BOTTOM_JOIN, ZIG_ZAG, ALL - } - - @Override - public Rule build() { - return innerLogicalJoin().then(join -> new LogicalJoin( - join.getJoinType().swap(), - join.getCondition(), - join.right(), - join.left()) - ).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java new file mode 100644 index 0000000000..776d722293 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinCommute.java @@ -0,0 +1,102 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +/** + * rule factory for exchange inner join's children. + */ +@Developing +public class JoinCommute extends OneExplorationRuleFactory { + private final SwapType swapType; + private final boolean swapOuter; + + public JoinCommute(boolean swapOuter) { + this.swapOuter = swapOuter; + this.swapType = SwapType.ALL; + } + + public JoinCommute(boolean swapOuter, SwapType swapType) { + this.swapOuter = swapOuter; + this.swapType = swapType; + } + + enum SwapType { + BOTTOM_JOIN, ZIG_ZAG, ALL + } + + @Override + public Rule build() { + return innerLogicalJoin(any(), any()).then(join -> { + if (!check(join)) { + return null; + } + boolean isBottomJoin = isBottomJoin(join); + if (swapType == SwapType.BOTTOM_JOIN && !isBottomJoin) { + return null; + } + + LogicalJoin newJoin = new LogicalJoin( + join.getJoinType(), + join.getCondition(), + join.right(), join.left(), + join.getJoinReorderContext() + ); + newJoin.getJoinReorderContext().setHasCommute(true); + if (swapType == SwapType.ZIG_ZAG && !isBottomJoin) { + newJoin.getJoinReorderContext().setHasCommuteZigZag(true); + } + + return newJoin; + }).toRule(RuleType.LOGICAL_JOIN_COMMUTATIVE); + } + + + private boolean check(LogicalJoin join) { + if (join.getJoinReorderContext().hasCommute() || join.getJoinReorderContext().hasExchange()) { + return false; + } + return true; + } + + private boolean isBottomJoin(LogicalJoin join) { + // TODO: wait for tree model of pattern-match. + if (join.left() instanceof LogicalProject) { + LogicalProject project = (LogicalProject) join.left(); + if (project.child() instanceof LogicalJoin) { + return false; + } + } + if (join.right() instanceof LogicalProject) { + LogicalProject project = (LogicalProject) join.left(); + if (project.child() instanceof LogicalJoin) { + return false; + } + } + if (join.left() instanceof LogicalJoin || join.right() instanceof LogicalJoin) { + return false; + } + return true; + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchange.java index 73c3f875c0..90153c8df5 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinExchange.java @@ -17,6 +17,7 @@ package org.apache.doris.nereids.rules.exploration.join; +import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; @@ -28,6 +29,7 @@ import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; /** * Rule for busy-tree, exchange the children node. */ +@Developing public class JoinExchange extends OneExplorationRuleFactory { /* * topJoin newTopJoin diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java index 520d19f12a..031ac28c91 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscom.java @@ -17,35 +17,138 @@ package org.apache.doris.nereids.rules.exploration.join; +import org.apache.doris.nereids.annotation.Developing; import org.apache.doris.nereids.rules.Rule; import org.apache.doris.nereids.rules.RuleType; import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; import org.apache.doris.nereids.trees.plans.GroupPlan; import org.apache.doris.nereids.trees.plans.Plan; import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; /** * Rule for change inner join left associative to right. */ +@Developing public class JoinLAsscom extends OneExplorationRuleFactory { /* - * topJoin newTopJoin - * / \ / \ - * bottomJoin C --> newBottomJoin B - * / \ / \ - * A B A C + * topJoin newTopJoin + * / \ / \ + * bottomJoin C --> newBottomJoin B + * / \ / \ + * A B A C */ @Override public Rule build() { - return innerLogicalJoin(innerLogicalJoin(), any()).then(topJoin -> { + return innerLogicalJoin(innerLogicalJoin(groupPlan(), groupPlan()), groupPlan()).then(topJoin -> { + if (!check(topJoin)) { + return null; + } + LogicalJoin<GroupPlan, GroupPlan> bottomJoin = topJoin.left(); - GroupPlan a = bottomJoin.left(); - GroupPlan b = bottomJoin.right(); + Plan a = bottomJoin.left(); + Plan b = bottomJoin.right(); Plan c = topJoin.right(); - Plan newBottomJoin = new LogicalJoin(bottomJoin.getJoinType(), bottomJoin.getCondition(), a, c); - return new LogicalJoin(bottomJoin.getJoinType(), topJoin.getCondition(), newBottomJoin, b); + Optional<Expression> optTopJoinOnClause = topJoin.getCondition(); + // inner join, onClause can't be empty(). + Preconditions.checkArgument(optTopJoinOnClause.isPresent(), + "bottomJoin in inner join, onClause must be present."); + Expression topJoinOnClause = optTopJoinOnClause.get(); + Optional<Expression> optBottomJoinOnClause = bottomJoin.getCondition(); + Preconditions.checkArgument(optBottomJoinOnClause.isPresent(), + "bottomJoin in inner join, onClause must be present."); + Expression bottomJoinOnClause = optBottomJoinOnClause.get(); + + List<SlotReference> aOutputSlots = a.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + List<SlotReference> bOutputSlots = b.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + List<SlotReference> cOutputSlots = c.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + + // Ignore join with some OnClause like: + // Join C = B + A for above example. + List<Expression> topJoinOnClauseConjuncts = ExpressionUtils.extractConjunctive(topJoinOnClause); + for (Expression topJoinOnClauseConjunct : topJoinOnClauseConjuncts) { + if (ExpressionUtils.isIntersecting(topJoinOnClauseConjunct.collect(SlotReference.class::isInstance), + aOutputSlots) + && ExpressionUtils.isIntersecting( + topJoinOnClauseConjunct.collect(SlotReference.class::isInstance), + bOutputSlots) + && ExpressionUtils.isIntersecting( + topJoinOnClauseConjunct.collect(SlotReference.class::isInstance), + cOutputSlots) + ) { + return null; + } + } + List<Expression> bottomJoinOnClauseConjuncts = ExpressionUtils.extractConjunctive(bottomJoinOnClause); + + List<Expression> allOnCondition = Lists.newArrayList(); + allOnCondition.addAll(topJoinOnClauseConjuncts); + allOnCondition.addAll(bottomJoinOnClauseConjuncts); + + List<SlotReference> newBottomJoinSlots = Lists.newArrayList(); + newBottomJoinSlots.addAll(aOutputSlots); + newBottomJoinSlots.addAll(cOutputSlots); + + List<Expression> newBottomJoinOnCondition = Lists.newArrayList(); + List<Expression> newTopJoinOnCondition = Lists.newArrayList(); + for (Expression onCondition : allOnCondition) { + List<SlotReference> slots = onCondition.collect(SlotReference.class::isInstance); + if (ExpressionUtils.containsAll(newBottomJoinSlots, slots)) { + newBottomJoinOnCondition.add(onCondition); + } else { + newTopJoinOnCondition.add(onCondition); + } + } + + // newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are cross join. + // Example: + // A: col1, col2. B: col2, col3. C: col3, col4 + // (A & B on A.col2=B.col2) & C on B.col3=C.col3. + // If (A & B) & C -> (A & C) & B. + // (A & C) will be cross join (newBottomJoinOnCondition is empty) + if (newBottomJoinOnCondition.isEmpty() || newTopJoinOnCondition.isEmpty()) { + return null; + } + + // new bottom join (a, c) + LogicalJoin newBottomJoin = new LogicalJoin( + bottomJoin.getJoinType(), + Optional.of(ExpressionUtils.and(newBottomJoinOnCondition)), + a, c); + + // TODO: add column map (use project) + // SlotReference bind() may have solved this problem. + // source: | A | B | C | + // target: | A | C | B | + + // new top join: b + LogicalJoin newTopJoin = new LogicalJoin( + topJoin.getJoinType(), + Optional.of(ExpressionUtils.and(newTopJoinOnCondition)), + newBottomJoin, b); + + return newTopJoin; }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); } + + private boolean check(LogicalJoin topJoin) { + if (topJoin.getJoinReorderContext().hasCommute()) { + return false; + } + return true; + } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLeftAssociative.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLeftAssociative.java deleted file mode 100644 index ba033f686f..0000000000 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinLeftAssociative.java +++ /dev/null @@ -1,54 +0,0 @@ -// Licensed to the Apache Software Foundation (ASF) under one -// or more contributor license agreements. See the NOTICE file -// distributed with this work for additional information -// regarding copyright ownership. The ASF licenses this file -// to you under the Apache License, Version 2.0 (the -// "License"); you may not use this file except in compliance -// with the License. You may obtain a copy of the License at -// -// http://www.apache.org/licenses/LICENSE-2.0 -// -// Unless required by applicable law or agreed to in writing, -// software distributed under the License is distributed on an -// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -// KIND, either express or implied. See the License for the -// specific language governing permissions and limitations -// under the License. - -package org.apache.doris.nereids.rules.exploration.join; - -import org.apache.doris.nereids.rules.Rule; -import org.apache.doris.nereids.rules.RuleType; -import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; -import org.apache.doris.nereids.trees.plans.JoinType; -import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; - -/** - * Rule factory for change inner join left associative to right. - */ -public class JoinLeftAssociative extends OneExplorationRuleFactory { - /* - * topJoin newTopJoin - * / \ / \ - * bottomJoin C --> A newBottomJoin - * / \ / \ - * A B B C - */ - @Override - public Rule build() { - return innerLogicalJoin(innerLogicalJoin(), any()).then(root -> { - // fixme, just for example now - return new LogicalJoin( - JoinType.INNER_JOIN, - root.getCondition(), - root.left().left(), - new LogicalJoin( - JoinType.INNER_JOIN, - root.getCondition(), - root.left().right(), - root.right() - ) - ); - }).toRule(RuleType.LOGICAL_LEFT_JOIN_ASSOCIATIVE); - } -} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinProjectLAsscom.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinProjectLAsscom.java new file mode 100644 index 0000000000..7f5500d4cc --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/exploration/join/JoinProjectLAsscom.java @@ -0,0 +1,185 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.nereids.annotation.Developing; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.rules.RuleType; +import org.apache.doris.nereids.rules.exploration.OneExplorationRuleFactory; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.GroupPlan; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.base.Preconditions; +import com.google.common.collect.Lists; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +/** + * Rule for change inner join left associative to right. + */ +@Developing +public class JoinProjectLAsscom extends OneExplorationRuleFactory { + /* + * topJoin newTopJoin + * / \ / \ + * project C newLeftProject newRightProject + * / ──► / \ + * bottomJoin newBottomJoin B + * / \ / \ + * A B A C + */ + @Override + public Rule build() { + return innerLogicalJoin(logicalProject(innerLogicalJoin(groupPlan(), groupPlan())), groupPlan()) + .when(this::check) + .then(topJoin -> { + if (!check(topJoin)) { + return null; + } + + LogicalProject<LogicalJoin<GroupPlan, GroupPlan>> project = topJoin.left(); + LogicalJoin<GroupPlan, GroupPlan> bottomJoin = project.child(); + + Plan a = bottomJoin.left(); + Plan b = bottomJoin.right(); + Plan c = topJoin.right(); + + Optional<Expression> optTopJoinOnClause = topJoin.getCondition(); + // inner join, onClause can't be empty(). + Preconditions.checkArgument(optTopJoinOnClause.isPresent(), + "bottomJoin in inner join, onClause must be present."); + Expression topJoinOnClause = optTopJoinOnClause.get(); + Optional<Expression> optBottomJoinOnClause = bottomJoin.getCondition(); + Preconditions.checkArgument(optBottomJoinOnClause.isPresent(), + "bottomJoin in inner join, onClause must be present."); + Expression bottomJoinOnClause = optBottomJoinOnClause.get(); + + List<SlotReference> aOutputSlots = a.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + List<SlotReference> bOutputSlots = b.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + List<SlotReference> cOutputSlots = c.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + + // Ignore join with some OnClause like: + // Join C = B + A for above example. + List<Expression> topJoinOnClauseConjuncts = ExpressionUtils.extractConjunctive(topJoinOnClause); + for (Expression topJoinOnClauseConjunct : topJoinOnClauseConjuncts) { + if (ExpressionUtils.isIntersecting( + topJoinOnClauseConjunct.collect(SlotReference.class::isInstance), aOutputSlots) + && ExpressionUtils.isIntersecting( + topJoinOnClauseConjunct.collect(SlotReference.class::isInstance), + bOutputSlots) + && ExpressionUtils.isIntersecting( + topJoinOnClauseConjunct.collect(SlotReference.class::isInstance), + cOutputSlots) + ) { + return null; + } + } + List<Expression> bottomJoinOnClauseConjuncts = ExpressionUtils.extractConjunctive( + bottomJoinOnClause); + + List<Expression> allOnCondition = Lists.newArrayList(); + allOnCondition.addAll(topJoinOnClauseConjuncts); + allOnCondition.addAll(bottomJoinOnClauseConjuncts); + + List<SlotReference> newBottomJoinSlots = Lists.newArrayList(); + newBottomJoinSlots.addAll(aOutputSlots); + newBottomJoinSlots.addAll(cOutputSlots); + + List<Expression> newBottomJoinOnCondition = Lists.newArrayList(); + List<Expression> newTopJoinOnCondition = Lists.newArrayList(); + for (Expression onCondition : allOnCondition) { + List<SlotReference> slots = onCondition.collect(SlotReference.class::isInstance); + if (ExpressionUtils.containsAll(newBottomJoinSlots, slots)) { + newBottomJoinOnCondition.add(onCondition); + } else { + newTopJoinOnCondition.add(onCondition); + } + } + + // newBottomJoinOnCondition/newTopJoinOnCondition is empty. They are cross join. + // Example: + // A: col1, col2. B: col2, col3. C: col3, col4 + // (A & B on A.col2=B.col2) & C on B.col3=C.col3. + // If (A & B) & C -> (A & C) & B. + // (A & C) will be cross join (newBottomJoinOnCondition is empty) + if (newBottomJoinOnCondition.isEmpty() || newTopJoinOnCondition.isEmpty()) { + return null; + } + + Plan newBottomJoin = new LogicalJoin( + bottomJoin.getJoinType(), + Optional.of(ExpressionUtils.and(newBottomJoinOnCondition)), + a, c); + + // Handle project. + List<NamedExpression> projectExprs = project.getProjects(); + List<NamedExpression> newRightProjectExprs = Lists.newArrayList(); + List<NamedExpression> newLeftProjectExpr = Lists.newArrayList(); + for (NamedExpression projectExpr : projectExprs) { + List<SlotReference> usedSlotRefs = projectExpr.collect(SlotReference.class::isInstance); + if (ExpressionUtils.containsAll(bOutputSlots, usedSlotRefs)) { + newRightProjectExprs.add(projectExpr); + } else { + newLeftProjectExpr.add(projectExpr); + } + } + + // project include b, add project for right. + if (newRightProjectExprs.size() != 0) { + LogicalProject newRightProject = new LogicalProject<>(newRightProjectExprs, b); + + if (newLeftProjectExpr.size() != 0) { + // project include bottom join, add project for left bottom join. + LogicalProject newLeftProject = new LogicalProject<>(newLeftProjectExpr, newBottomJoin); + return new LogicalJoin( + topJoin.getJoinType(), + Optional.of(ExpressionUtils.and(newTopJoinOnCondition)), + newLeftProject, newRightProject); + } + return new LogicalJoin( + topJoin.getJoinType(), + Optional.of(ExpressionUtils.and(newTopJoinOnCondition)), + newBottomJoin, newRightProject); + } + + return new LogicalJoin( + topJoin.getJoinType(), + Optional.of(ExpressionUtils.and(newTopJoinOnCondition)), + newBottomJoin, b); + }).toRule(RuleType.LOGICAL_JOIN_L_ASSCOM); + } + + private boolean check(LogicalJoin topJoin) { + if (topJoin.getJoinReorderContext().hasCommute()) { + return false; + } + return true; + } + +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java index ff76ac1006..84103b0da7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java @@ -42,28 +42,29 @@ import java.util.Set; /** * Push the predicate in the LogicalFilter or LogicalJoin to the join children. - * For example: - * select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2 - * Logical plan tree: - * project - * | - * filter (a.k1 > 1 and b.k1 > 2) - * | - * join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5) - * / \ - * scan scan - * transformed: - * project - * | - * join (a.k1 = b.k1) - * / \ - * filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5) - * | | - * scan scan * todo: Now, only support eq on condition for inner join, support other case later */ public class PushPredicateThroughJoin extends OneRewriteRuleFactory { - + /* + * For example: + * select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2 + * Logical plan tree: + * project + * | + * filter (a.k1 > 1 and b.k1 > 2) + * | + * join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5) + * / \ + * scan scan + * transformed: + * project + * | + * join (a.k1 = b.k1) + * / \ + * filter(a.k1 > 1 and a.k2 > 2 ) filter(b.k1 > 2 and b.k2 > 5) + * | | + * scan scan + */ @Override public Rule build() { return logicalFilter(innerLogicalJoin()).then(filter -> { @@ -79,13 +80,14 @@ public class PushPredicateThroughJoin extends OneRewriteRuleFactory { List<Slot> leftInput = join.left().getOutput(); List<Slot> rightInput = join.right().getOutput(); - ExpressionUtils.extractConjunct(ExpressionUtils.and(onPredicates, wherePredicates)).forEach(predicate -> { - if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) { - eqConditions.add(predicate); - } else { - otherConditions.add(predicate); - } - }); + ExpressionUtils.extractConjunctive(ExpressionUtils.and(onPredicates, wherePredicates)) + .forEach(predicate -> { + if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) { + eqConditions.add(predicate); + } else { + otherConditions.add(predicate); + } + }); List<Expression> leftPredicates = Lists.newArrayList(); List<Expression> rightPredicates = Lists.newArrayList(); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java index 0405c2975b..7fafe2ec90 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/ReorderJoin.java @@ -191,7 +191,7 @@ public class ReorderJoin extends OneRewriteRuleFactory { public Void visitLogicalFilter(LogicalFilter<Plan> filter, Void context) { Plan child = filter.child(); if (child instanceof LogicalJoin) { - conjuncts.addAll(ExpressionUtils.extractConjunct(filter.getPredicates())); + conjuncts.addAll(ExpressionUtils.extractConjunctive(filter.getPredicates())); } child.accept(this, context); @@ -207,7 +207,7 @@ public class ReorderJoin extends OneRewriteRuleFactory { join.left().accept(this, context); join.right().accept(this, context); - join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunct(cond))); + join.getCondition().ifPresent(cond -> conjuncts.addAll(ExpressionUtils.extractConjunctive(cond))); if (!(join.left() instanceof LogicalJoin)) { joinInputs.add(join.left()); } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java index c9f72c02e7..12cf3acca8 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotReference.java @@ -108,6 +108,7 @@ public class SlotReference extends Slot { } SlotReference that = (SlotReference) o; return nullable == that.nullable + && dataType.equals(that.dataType) && exprId.equals(that.exprId) && dataType.equals(that.dataType) && name.equals(that.name) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java index 05324fd75f..5d6f06458e 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/plans/logical/LogicalJoin.java @@ -61,6 +61,12 @@ public class LogicalJoin<LEFT_CHILD_TYPE extends Plan, RIGHT_CHILD_TYPE extends this(joinType, condition, Optional.empty(), Optional.empty(), leftChild, rightChild); } + public LogicalJoin(JoinType joinType, Optional<Expression> condition, + LEFT_CHILD_TYPE leftChild, RIGHT_CHILD_TYPE rightChild, JoinReorderContext joinReorderContext) { + this(joinType, condition, Optional.empty(), Optional.empty(), leftChild, rightChild); + this.joinReorderContext.copyFrom(joinReorderContext); + } + /** * Constructor for LogicalJoinPlan. * diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 6517a3d91f..3641fd7a4c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.trees.expressions.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.Or; +import org.apache.doris.nereids.trees.expressions.SlotReference; import com.google.common.base.Preconditions; import com.google.common.collect.Lists; @@ -36,14 +37,26 @@ import java.util.Set; */ public class ExpressionUtils { - public static List<Expression> extractConjunct(Expression expr) { + public static List<Expression> extractConjunctive(Expression expr) { return extract(And.class, expr); } - public static List<Expression> extractDisjunct(Expression expr) { + public static List<Expression> extractDisjunctive(Expression expr) { return extract(Or.class, expr); } + /** + * Split predicates with `And/Or` form recursively. + * Some examples for `And`: + * <p> + * a and b -> a, b + * (a and b) and c -> a, b, c + * (a or b) and (c and d) -> (a or b), c , d + * <p> + * Stop recursion when meeting `Or`, so this func will ignore `And` inside `Or`. + * Warning examples: + * (a and b) or c -> (a and b) or c + */ public static List<Expression> extract(CompoundPredicate expr) { return extract(expr.getClass(), expr); } @@ -84,6 +97,12 @@ public class ExpressionUtils { * Use AND/OR to combine expressions together. */ public static Expression combine(Class<? extends Expression> type, List<Expression> expressions) { + /* + * (AB) (CD) E ((AB)(CD)) E (((AB)(CD))E) + * ▲ ▲ ▲ ▲ ▲ ▲ + * │ │ │ │ │ │ + * A B C D E ──► A B C D E ──► (AB) (CD) E ──► ((AB)(CD)) E ──► (((AB)(CD))E) + */ Preconditions.checkArgument(type == And.class || type == Or.class); Objects.requireNonNull(expressions, "expressions is null"); @@ -102,4 +121,42 @@ public class ExpressionUtils { .reduce(type == And.class ? And::new : Or::new) .orElse(new BooleanLiteral(type == And.class)); } + + /** + * Check whether lhs and rhs (both are List of SlotReference) are intersecting. + */ + public static boolean isIntersecting(List<SlotReference> lhs, List<SlotReference> rhs) { + for (SlotReference lSlot : lhs) { + for (SlotReference rSlot : rhs) { + if (lSlot.equals(rSlot)) { + return true; + } + } + } + return false; + } + + /** + * Whether `List of SlotReference` contains a `SlotReference`. + */ + public static boolean contains(List<SlotReference> list, SlotReference item) { + for (SlotReference slotRefInList : list) { + if (item.equals(slotRefInList)) { + return true; + } + } + return false; + } + + /** + * Whether `List of SlotReference` contains all another `List of SlotReference`. + */ + public static boolean containsAll(List<SlotReference> large, List<SlotReference> small) { + for (SlotReference slotRefInSmall : small) { + if (!contains(large, slotRefInSmall)) { + return false; + } + } + return true; + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java new file mode 100644 index 0000000000..8d347335f6 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinCommuteTest.java @@ -0,0 +1,70 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.catalog.AggregateType; +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.Table; +import org.apache.doris.catalog.Type; +import org.apache.doris.nereids.PlannerContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.types.BigIntType; + +import com.google.common.collect.ImmutableList; +import mockit.Mocked; +import org.junit.Assert; +import org.junit.Test; + +import java.util.List; +import java.util.Optional; + +public class JoinCommuteTest { + @Test + public void testInnerJoinCommute(@Mocked PlannerContext plannerContext) { + Table table1 = new Table(0L, "table1", Table.TableType.OLAP, + ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""))); + LogicalOlapScan scan1 = new LogicalOlapScan(table1, ImmutableList.of()); + + Table table2 = new Table(0L, "table2", Table.TableType.OLAP, + ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""))); + LogicalOlapScan scan2 = new LogicalOlapScan(table2, ImmutableList.of()); + + Expression onCondition = new EqualTo( + new SlotReference("id", new BigIntType(), true, ImmutableList.of("table1")), + new SlotReference("id", new BigIntType(), true, ImmutableList.of("table2"))); + LogicalJoin<LogicalOlapScan, LogicalOlapScan> join = new LogicalJoin<>( + JoinType.INNER_JOIN, Optional.of(onCondition), scan1, scan2); + + Rule rule = new JoinCommute(true).build(); + + List<Plan> transform = rule.transform(join, plannerContext); + Assert.assertEquals(1, transform.size()); + Plan newJoin = transform.get(0); + + Assert.assertEquals(newJoin.child(1), join.child(0)); + Assert.assertEquals(newJoin.child(0), join.child(1)); + } + +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java new file mode 100644 index 0000000000..7e0dd28ec0 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinLAsscomTest.java @@ -0,0 +1,176 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.catalog.AggregateType; +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.Table; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.PlannerContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import mockit.Mocked; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +public class JoinLAsscomTest { + + private static List<LogicalOlapScan> scans = Lists.newArrayList(); + private static List<List<SlotReference>> outputs = Lists.newArrayList(); + + @BeforeClass + public static void init() { + Table t1 = new Table(0L, "t1", Table.TableType.OLAP, + ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "0", ""))); + LogicalOlapScan scan1 = new LogicalOlapScan(t1, ImmutableList.of()); + + Table t2 = new Table(0L, "t2", Table.TableType.OLAP, + ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "0", ""))); + LogicalOlapScan scan2 = new LogicalOlapScan(t2, ImmutableList.of()); + + Table t3 = new Table(0L, "t3", Table.TableType.OLAP, + ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "0", ""))); + LogicalOlapScan scan3 = new LogicalOlapScan(t3, ImmutableList.of()); + scans.add(scan1); + scans.add(scan2); + scans.add(scan3); + + List<SlotReference> t1Output = scan1.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + List<SlotReference> t2Output = scan2.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + List<SlotReference> t3Output = scan3.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + outputs.add(t1Output); + outputs.add(t2Output); + outputs.add(t3Output); + } + + public Pair<LogicalJoin, LogicalJoin> testJoinLAsscom(PlannerContext plannerContext, + Expression bottomJoinOnCondition, Expression topJoinOnCondition) { + /* + * topJoin newTopJoin + * / \ / \ + * bottomJoin C --> newBottomJoin B + * / \ / \ + * A B A C + */ + Assert.assertEquals(3, scans.size()); + LogicalJoin<LogicalOlapScan, LogicalOlapScan> bottomJoin = new LogicalJoin<>(JoinType.INNER_JOIN, + Optional.of(bottomJoinOnCondition), scans.get(0), scans.get(1)); + LogicalJoin<LogicalJoin<LogicalOlapScan, LogicalOlapScan>, LogicalOlapScan> topJoin = new LogicalJoin<>( + JoinType.INNER_JOIN, Optional.of(topJoinOnCondition), bottomJoin, scans.get(2)); + + Rule rule = new JoinLAsscom().build(); + List<Plan> transform = rule.transform(topJoin, plannerContext); + Assert.assertEquals(1, transform.size()); + Assert.assertTrue(transform.get(0) instanceof LogicalJoin); + LogicalJoin newTopJoin = (LogicalJoin) transform.get(0); + return new Pair<>(topJoin, newTopJoin); + } + + @Test + public void testStarJoinLAsscom(@Mocked PlannerContext plannerContext) { + /* + * Star-Join + * t1 -- t2 + * | + * t3 + * <p> + * t1.id=t3.id t1.id=t2.id + * topJoin newTopJoin + * / \ / \ + * t1.id=t2.id t3 t1.id=t3.id t2 + * bottomJoin --> newBottomJoin + * / \ / \ + * t1 t2 t1 t3 + */ + + List<SlotReference> t1 = outputs.get(0); + List<SlotReference> t2 = outputs.get(1); + List<SlotReference> t3 = outputs.get(2); + Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0)); + Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1)); + + Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(plannerContext, bottomJoinOnCondition, + topJoinOnCondition); + LogicalJoin oldJoin = pair.first; + LogicalJoin newTopJoin = pair.second; + + // Join reorder successfully. + Assert.assertNotEquals(oldJoin, newTopJoin); + Assert.assertEquals("t1", ((LogicalOlapScan) ((LogicalJoin) newTopJoin.left()).left()).getTable().getName()); + Assert.assertEquals("t3", ((LogicalOlapScan) ((LogicalJoin) newTopJoin.left()).right()).getTable().getName()); + Assert.assertEquals("t2", ((LogicalOlapScan) newTopJoin.right()).getTable().getName()); + } + + @Test + public void testChainJoinLAsscom(@Mocked PlannerContext plannerContext) { + /* + * Chain-Join + * t1 -- t2 -- t3 + * <p> + * t2.id=t3.id t2.id=t3.id + * topJoin newTopJoin + * / \ / \ + * t1.id=t2.id t3 t1.id=t3.id t2 + * bottomJoin --> newBottomJoin + * / \ / \ + * t1 t2 t1 t3 + */ + + List<SlotReference> t1 = outputs.get(0); + List<SlotReference> t2 = outputs.get(1); + List<SlotReference> t3 = outputs.get(2); + Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0)); + Expression topJoinOnCondition = new EqualTo(t2.get(0), t3.get(0)); + + Pair<LogicalJoin, LogicalJoin> pair = testJoinLAsscom(plannerContext, bottomJoinOnCondition, + topJoinOnCondition); + LogicalJoin oldJoin = pair.first; + LogicalJoin newTopJoin = pair.second; + + // Join reorder failed. + // Chain-Join LAsscom directly will be failed. + // After t1 -- t2 -- t3 + // -- join commute --> + // t1 -- t2 + // | + // t3 + // then, we can LAsscom for this star-join. + Assert.assertEquals(oldJoin, newTopJoin); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinProjectLAsscomTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinProjectLAsscomTest.java new file mode 100644 index 0000000000..c41d14c85b --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/exploration/join/JoinProjectLAsscomTest.java @@ -0,0 +1,148 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.rules.exploration.join; + +import org.apache.doris.catalog.AggregateType; +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.Table; +import org.apache.doris.catalog.Type; +import org.apache.doris.common.Pair; +import org.apache.doris.nereids.PlannerContext; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.trees.expressions.Alias; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.NamedExpression; +import org.apache.doris.nereids.trees.expressions.SlotReference; +import org.apache.doris.nereids.trees.plans.JoinType; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.logical.LogicalJoin; +import org.apache.doris.nereids.trees.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.trees.plans.logical.LogicalProject; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import mockit.Mocked; +import org.junit.Assert; +import org.junit.BeforeClass; +import org.junit.Test; + +import java.util.List; +import java.util.Optional; +import java.util.stream.Collectors; + +public class JoinProjectLAsscomTest { + + private static List<LogicalOlapScan> scans = Lists.newArrayList(); + private static List<List<SlotReference>> outputs = Lists.newArrayList(); + + @BeforeClass + public static void init() { + Table t1 = new Table(0L, "t1", Table.TableType.OLAP, + ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "0", ""))); + LogicalOlapScan scan1 = new LogicalOlapScan(t1, ImmutableList.of()); + + Table t2 = new Table(0L, "t2", Table.TableType.OLAP, + ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "0", ""))); + LogicalOlapScan scan2 = new LogicalOlapScan(t2, ImmutableList.of()); + + Table t3 = new Table(0L, "t3", Table.TableType.OLAP, + ImmutableList.of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "0", ""))); + LogicalOlapScan scan3 = new LogicalOlapScan(t3, ImmutableList.of()); + scans.add(scan1); + scans.add(scan2); + scans.add(scan3); + + List<SlotReference> t1Output = scan1.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + List<SlotReference> t2Output = scan2.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + List<SlotReference> t3Output = scan3.getOutput().stream().map(slot -> (SlotReference) slot) + .collect(Collectors.toList()); + outputs.add(t1Output); + outputs.add(t2Output); + outputs.add(t3Output); + } + + private Pair<LogicalJoin, LogicalJoin> testJoinProjectLAsscom(PlannerContext plannerContext, + List<NamedExpression> projects) { + /* + * topJoin newTopJoin + * / \ / \ + * project C newLeftProject newRightProject + * / ──► / \ + * bottomJoin newBottomJoin B + * / \ / \ + * A B A C + */ + + Assert.assertEquals(3, scans.size()); + + List<SlotReference> t1 = outputs.get(0); + List<SlotReference> t2 = outputs.get(1); + List<SlotReference> t3 = outputs.get(2); + Expression bottomJoinOnCondition = new EqualTo(t1.get(0), t2.get(0)); + Expression topJoinOnCondition = new EqualTo(t1.get(1), t3.get(1)); + + LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>> project = new LogicalProject<>( + projects, + new LogicalJoin<>(JoinType.INNER_JOIN, Optional.of(bottomJoinOnCondition), scans.get(0), scans.get(1))); + + LogicalJoin<LogicalProject<LogicalJoin<LogicalOlapScan, LogicalOlapScan>>, LogicalOlapScan> topJoin + = new LogicalJoin<>(JoinType.INNER_JOIN, Optional.of(topJoinOnCondition), project, scans.get(2)); + + Rule rule = new JoinProjectLAsscom().build(); + List<Plan> transform = rule.transform(topJoin, plannerContext); + Assert.assertEquals(1, transform.size()); + Assert.assertTrue(transform.get(0) instanceof LogicalJoin); + LogicalJoin newTopJoin = (LogicalJoin) transform.get(0); + return new Pair<>(topJoin, newTopJoin); + } + + @Test + public void testStarJoinProjectLAsscom(@Mocked PlannerContext plannerContext) { + List<SlotReference> t1 = outputs.get(0); + List<SlotReference> t2 = outputs.get(1); + List<NamedExpression> projects = ImmutableList.of( + new Alias(t2.get(0), "t2.id"), + new Alias(t1.get(0), "t1.id"), + t1.get(1), + t2.get(1) + ); + + Pair<LogicalJoin, LogicalJoin> pair = testJoinProjectLAsscom(plannerContext, projects); + + LogicalJoin oldJoin = pair.first; + LogicalJoin newTopJoin = pair.second; + + // Join reorder successfully. + Assert.assertNotEquals(oldJoin, newTopJoin); + Assert.assertEquals("t1.id", + ((Alias) ((LogicalProject) newTopJoin.left()).getProjects().get(0)).getName()); + Assert.assertEquals("name", + ((SlotReference) ((LogicalProject) newTopJoin.left()).getProjects().get(1)).getName()); + Assert.assertEquals("t2.id", + ((Alias) ((LogicalProject) newTopJoin.right()).getProjects().get(0)).getName()); + Assert.assertEquals("name", + ((SlotReference) ((LogicalProject) newTopJoin.left()).getProjects().get(1)).getName()); + + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionUtilsTest.java index 27d85b5f0b..4fe1b3e2c0 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionUtilsTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionUtilsTest.java @@ -38,7 +38,7 @@ public class ExpressionUtilsTest { Expression expr; expr = PARSER.parseExpression("a"); - expressions = ExpressionUtils.extractConjunct(expr); + expressions = ExpressionUtils.extractConjunctive(expr); Assertions.assertEquals(expressions.size(), 1); Assertions.assertEquals(expressions.get(0), expr); @@ -47,7 +47,7 @@ public class ExpressionUtilsTest { Expression b = PARSER.parseExpression("b"); Expression c = PARSER.parseExpression("c"); - expressions = ExpressionUtils.extractConjunct(expr); + expressions = ExpressionUtils.extractConjunctive(expr); Assertions.assertEquals(expressions.size(), 3); Assertions.assertEquals(expressions.get(0), a); Assertions.assertEquals(expressions.get(1), b); @@ -55,7 +55,7 @@ public class ExpressionUtilsTest { expr = PARSER.parseExpression("(a or b) and c and (e or f)"); - expressions = ExpressionUtils.extractConjunct(expr); + expressions = ExpressionUtils.extractConjunctive(expr); Expression aOrb = PARSER.parseExpression("a or b"); Expression eOrf = PARSER.parseExpression("e or f"); Assertions.assertEquals(expressions.size(), 3); @@ -70,7 +70,7 @@ public class ExpressionUtilsTest { Expression expr; expr = PARSER.parseExpression("a"); - expressions = ExpressionUtils.extractDisjunct(expr); + expressions = ExpressionUtils.extractDisjunctive(expr); Assertions.assertEquals(expressions.size(), 1); Assertions.assertEquals(expressions.get(0), expr); @@ -79,14 +79,14 @@ public class ExpressionUtilsTest { Expression b = PARSER.parseExpression("b"); Expression c = PARSER.parseExpression("c"); - expressions = ExpressionUtils.extractDisjunct(expr); + expressions = ExpressionUtils.extractDisjunctive(expr); Assertions.assertEquals(expressions.size(), 3); Assertions.assertEquals(expressions.get(0), a); Assertions.assertEquals(expressions.get(1), b); Assertions.assertEquals(expressions.get(2), c); expr = PARSER.parseExpression("(a and b) or c or (e and f)"); - expressions = ExpressionUtils.extractDisjunct(expr); + expressions = ExpressionUtils.extractDisjunctive(expr); Expression aAndb = PARSER.parseExpression("a and b"); Expression eAndf = PARSER.parseExpression("e and f"); Assertions.assertEquals(expressions.size(), 3); --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org