This is an automated email from the ASF dual-hosted git repository. lingmiao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new f998c0b044 [Enhancement](Nereids) push down predicate through join (#10462) f998c0b044 is described below commit f998c0b044ea81efa5057f1b67a232ad7e40da85 Author: shee <13843187+qz...@users.noreply.github.com> AuthorDate: Fri Jul 1 15:39:01 2022 +0800 [Enhancement](Nereids) push down predicate through join (#10462) Add filter operator to join children according to the predicate of filter and join, in order to achieve predicate push-down. Pattern: ``` filter | join / \ child child ``` Transform: ``` filter | join / \ filter filter | | child child ``` --- .../java/org/apache/doris/nereids/jobs/Job.java | 3 +- .../nereids/jobs/rewrite/RewriteTopDownJob.java | 6 +- .../java/org/apache/doris/nereids/memo/Memo.java | 3 + .../doris/nereids/parser/LogicalPlanBuilder.java | 3 +- .../apache/doris/nereids/properties/OrderKey.java | 5 + .../org/apache/doris/nereids/rules/RuleType.java | 1 + .../expression/rewrite/ExpressionRuleExecutor.java | 8 +- .../rewrite/logical/PushPredicateThroughJoin.java | 165 +++++++++++++ .../doris/nereids/trees/expressions/Add.java | 3 + .../trees/expressions/CompoundPredicate.java | 29 +++ .../trees/expressions/IterationVisitor.java | 161 +++++++++++++ .../doris/nereids/trees/expressions/Literal.java | 6 + .../nereids/trees/expressions/SlotExtractor.java | 68 ++++++ .../apache/doris/nereids/util/ExpressionUtils.java | 140 +++++++++++ .../rewrite/logical/PushDownPredicateTest.java | 256 +++++++++++++++++++++ .../org/apache/doris/nereids/ssb/SSBUtils.java | 14 +- .../doris/nereids/util/ExpressionUtilsTest.java | 101 ++++++++ 17 files changed, 960 insertions(+), 12 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java index 97c2a375e6..3f7ef0200c 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java +++
b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/Job.java
@@ -25,6 +25,7 @@ import org.apache.doris.nereids.rules.RuleSet;
 import org.apache.doris.nereids.trees.TreeNode;
 
 import java.util.List;
+import java.util.Objects;
 import java.util.stream.Collectors;
 
 /**
@@ -57,7 +58,7 @@ public abstract class Job<NODE_TYPE extends TreeNode<NODE_TYPE>> {
 
     public List<Rule<NODE_TYPE>> getValidRules(GroupExpression groupExpression,
             List<Rule<NODE_TYPE>> candidateRules) {
         return candidateRules.stream()
-                .filter(rule -> rule.getPattern().matchOperator(groupExpression.getOperator())
+                .filter(rule -> Objects.nonNull(rule) && rule.getPattern().matchOperator(groupExpression.getOperator())
                         && groupExpression.notApplied(rule)).collect(Collectors.toList());
     }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteTopDownJob.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteTopDownJob.java
index 89c7f70d6b..4504109d68 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteTopDownJob.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/jobs/rewrite/RewriteTopDownJob.java
@@ -72,7 +72,9 @@ public class RewriteTopDownJob extends Job<Plan> {
             Preconditions.checkArgument(afters.size() == 1);
             Plan after = afters.get(0);
             if (after != before) {
-                context.getOptimizerContext().getMemo().copyIn(after, group, rule.isRewrite());
+                GroupExpression expression = context.getOptimizerContext().getMemo()
+                        .copyIn(after, group, rule.isRewrite());
+                expression.setApplied(rule);
                 pushTask(new RewriteTopDownJob(group, rules, context));
                 return;
             }
@@ -80,7 +82,7 @@ public class RewriteTopDownJob extends Job<Plan> {
             logicalExpression.setApplied(rule);
         }
 
-        for (Group childGroup : logicalExpression.children()) {
+        for (Group childGroup : group.getLogicalExpression().children()) {
             pushTask(new RewriteTopDownJob(childGroup, rules, context));
         }
     }
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
index 761ac4c342..b125a85ec8 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/memo/Memo.java
@@ -98,6 +98,9 @@ public class Memo {
             childrenNode.add(groupToTreeNode(child));
         }
         Plan result = logicalExpression.getOperator().toTreeNode(logicalExpression);
+        if (result.children().size() == 0) {
+            return result;
+        }
         return result.withChildren(childrenNode);
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
index f19c8fb9a7..e63601dc65 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/parser/LogicalPlanBuilder.java
@@ -39,6 +39,7 @@ import org.apache.doris.nereids.DorisParser.MultipartIdentifierContext;
 import org.apache.doris.nereids.DorisParser.NamedExpressionContext;
 import org.apache.doris.nereids.DorisParser.NamedExpressionSeqContext;
 import org.apache.doris.nereids.DorisParser.NullLiteralContext;
+import org.apache.doris.nereids.DorisParser.ParenthesizedExpressionContext;
 import org.apache.doris.nereids.DorisParser.PredicateContext;
 import org.apache.doris.nereids.DorisParser.PredicatedContext;
 import org.apache.doris.nereids.DorisParser.QualifiedNameContext;
@@ -402,7 +403,7 @@ public class LogicalPlanBuilder extends DorisParserBaseVisitor<Object> {
     }
 
     @Override
-    public Expression visitParenthesizedExpression(DorisParser.ParenthesizedExpressionContext ctx) {
+    public Expression visitParenthesizedExpression(ParenthesizedExpressionContext ctx) {
         return getExpression(ctx.expression());
     }
 
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/OrderKey.java
b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/OrderKey.java
index 9cbf800d6e..2bd91c6540 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/OrderKey.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/properties/OrderKey.java
@@ -53,4 +53,9 @@ public class OrderKey {
     public boolean isNullFirst() {
         return nullFirst;
     }
+
+    @Override
+    public String toString() {
+        return expr.sql();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
index 2e961f27ee..9cd9b42ccb 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/RuleType.java
@@ -39,6 +39,7 @@ public enum RuleType {
 
     // rewrite rules
     COLUMN_PRUNE_PROJECTION(RuleTypeClass.REWRITE),
+    PUSH_DOWN_PREDICATE_THROUGH_JOIN(RuleTypeClass.REWRITE),
 
     // exploration rules
     LOGICAL_JOIN_COMMUTATIVE(RuleTypeClass.EXPLORATION),
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
index 547266cea5..55c2ef0060 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rewrite/ExpressionRuleExecutor.java
@@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.expression.rewrite.rules.NormalizeExpressi
 import org.apache.doris.nereids.rules.expression.rewrite.rules.SimplifyNotExprRule;
 import org.apache.doris.nereids.trees.expressions.Expression;
 
+import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 
 import java.util.List;
@@ -30,7 +31,7 @@ import java.util.List;
  */
 public class ExpressionRuleExecutor {
 
-    public static final List<ExpressionRewriteRule> REWRITE_RULES = Lists.newArrayList(
+    public static final List<ExpressionRewriteRule> REWRITE_RULES = ImmutableList.of(
             new SimplifyNotExprRule(),
             new NormalizeExpressionRule()
     );
@@ -38,6 +39,11 @@ public class ExpressionRuleExecutor {
     private final ExpressionRewriteContext ctx;
     private final List<ExpressionRewriteRule> rules;
 
+    public ExpressionRuleExecutor() {
+        this.rules = REWRITE_RULES;
+        this.ctx = new ExpressionRewriteContext();
+    }
+
     public ExpressionRuleExecutor(List<ExpressionRewriteRule> rules) {
         this.rules = rules;
         this.ctx = new ExpressionRewriteContext();
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java
new file mode 100644
index 0000000000..429c495623
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/logical/PushPredicateThroughJoin.java
@@ -0,0 +1,165 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.rewrite.logical;
+
+import org.apache.doris.nereids.operators.plans.logical.LogicalFilter;
+import org.apache.doris.nereids.operators.plans.logical.LogicalJoin;
+import org.apache.doris.nereids.rules.Rule;
+import org.apache.doris.nereids.rules.RuleType;
+import org.apache.doris.nereids.rules.expression.rewrite.ExpressionRuleExecutor;
+import org.apache.doris.nereids.rules.rewrite.OneRewriteRuleFactory;
+import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Literal;
+import org.apache.doris.nereids.trees.expressions.Slot;
+import org.apache.doris.nereids.trees.expressions.SlotExtractor;
+import org.apache.doris.nereids.trees.plans.GroupPlan;
+import org.apache.doris.nereids.trees.plans.Plan;
+import org.apache.doris.nereids.trees.plans.logical.LogicalBinaryPlan;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
+import java.util.Set;
+
+/**
+ * Push the predicate in the LogicalFilter or LogicalJoin to the join children.
+ * For example:
+ * select a.k1,b.k1 from a join b on a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5 where a.k1 > 1 and b.k1 > 2
+ * Logical plan tree:
+ *                 project
+ *                   |
+ *                filter (a.k1 > 1 and b.k1 > 2)
+ *                   |
+ *                join (a.k1 = b.k1 and a.k2 > 2 and b.k2 > 5)
+ *                 /   \
+ *              scan    scan
+ * transformed:
+ *                      project
+ *                        |
+ *                join (a.k1 = b.k1)
+ *               /                  \
+ * filter(a.k1 > 1 and a.k2 > 2)   filter(b.k1 > 2 and b.k2 > 5)
+ *             |                       |
+ *            scan                    scan
+ * TODO: only inner join is matched for now; support other join types later.
+ */
+public class PushPredicateThroughJoin extends OneRewriteRuleFactory {
+
+    @Override
+    public Rule<Plan> build() {
+        return logicalFilter(innerLogicalJoin()).then(filter -> {
+
+            LogicalJoin joinOp = filter.child().operator;
+
+            Expression wherePredicates = filter.operator.getPredicates();
+            Expression onPredicates = Literal.TRUE_LITERAL;
+
+            List<Expression> otherConditions = Lists.newArrayList();
+            List<Expression> eqConditions = Lists.newArrayList();
+
+            if (joinOp.getCondition().isPresent()) {
+                onPredicates = joinOp.getCondition().get();
+            }
+
+            List<Slot> leftInput = filter.child().left().getOutput();
+            List<Slot> rightInput = filter.child().right().getOutput();
+
+            ExpressionUtils.extractConjunct(ExpressionUtils.add(onPredicates, wherePredicates)).forEach(predicate -> {
+                if (Objects.nonNull(getJoinCondition(predicate, leftInput, rightInput))) {
+                    eqConditions.add(predicate);
+                } else {
+                    otherConditions.add(predicate);
+                }
+            });
+
+            List<Expression> leftPredicates = Lists.newArrayList();
+            List<Expression> rightPredicates = Lists.newArrayList();
+
+            for (Expression p : otherConditions) {
+                Set<Slot> slots = SlotExtractor.extractSlot(p);
+                // containsAll(emptySet) is true, so a constant predicate (one that references
+                // no slot) satisfies both checks below and is pushed to both children exactly
+                // once; no special case is needed. A predicate that references columns of both
+                // sides satisfies neither check and stays on the join.
+                if (leftInput.containsAll(slots)) {
+                    leftPredicates.add(p);
+                }
+                if (rightInput.containsAll(slots)) {
+                    rightPredicates.add(p);
+                }
+            }
+
+            otherConditions.removeAll(leftPredicates);
+            otherConditions.removeAll(rightPredicates);
+            otherConditions.addAll(eqConditions);
+            Expression joinConditions = ExpressionUtils.add(otherConditions);
+
+            return pushDownPredicate(filter.child(), joinConditions, leftPredicates, rightPredicates);
+        }).toRule(RuleType.PUSH_DOWN_PREDICATE_THROUGH_JOIN);
+    }
+
+    private Plan pushDownPredicate(LogicalBinaryPlan<LogicalJoin, GroupPlan, GroupPlan> joinPlan,
+            Expression joinConditions, List<Expression> leftPredicates, List<Expression> rightPredicates) {
+
+        Expression left = ExpressionUtils.add(leftPredicates);
+        Expression right = ExpressionUtils.add(rightPredicates);
+        //todo expr should optimize again using expr rewrite
+        ExpressionRuleExecutor exprRewriter = new ExpressionRuleExecutor();
+        Plan leftPlan = joinPlan.left();
+        Plan rightPlan = joinPlan.right();
+        if (!left.equals(Literal.TRUE_LITERAL)) {
+            leftPlan = plan(new LogicalFilter(exprRewriter.rewrite(left)), leftPlan);
+        }
+
+        if (!right.equals(Literal.TRUE_LITERAL)) {
+            rightPlan = plan(new LogicalFilter(exprRewriter.rewrite(right)), rightPlan);
+        }
+
+        return plan(new LogicalJoin(joinPlan.operator.getJoinType(), Optional.of(joinConditions)), leftPlan, rightPlan);
+    }
+
+    private Expression getJoinCondition(Expression predicate, List<Slot> leftOutputs, List<Slot> rightOutputs) {
+        if (!(predicate instanceof ComparisonPredicate)) {
+            return null;
+        }
+
+        ComparisonPredicate comparison = (ComparisonPredicate) predicate;
+
+        Set<Slot> leftSlots = SlotExtractor.extractSlot(comparison.left());
+        Set<Slot> rightSlots = SlotExtractor.extractSlot(comparison.right());
+
+        if (!(leftSlots.size() >= 1 && rightSlots.size() >= 1)) {
+            return null;
+        }
+
+        Set<Slot> left = Sets.newLinkedHashSet(leftOutputs);
+        Set<Slot> right = Sets.newLinkedHashSet(rightOutputs);
+
+        if ((left.containsAll(leftSlots) && right.containsAll(rightSlots)) || (left.containsAll(rightSlots)
+                && right.containsAll(leftSlots))) {
+            return predicate;
+        }
+
+        return null;
+    }
+}
diff --git
a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java
index 27bf07a430..dac9d919e2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Add.java
@@ -48,4 +48,8 @@ public class Add<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_TYPE extends Ex
     }
 
 
+    @Override
+    public String toString() {
+        return sql();
+    }
 }
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java
index f09e1ae4f6..39dcc2f4f4 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/CompoundPredicate.java
@@ -19,6 +19,9 @@ package org.apache.doris.nereids.trees.expressions;
 
 import org.apache.doris.nereids.trees.NodeType;
 
+import java.util.List;
+import java.util.Objects;
+
 /**
  * Compound predicate expression.
  * Such as &&,||,AND,OR.
@@ -53,4 +56,38 @@ public class CompoundPredicate<LEFT_CHILD_TYPE extends Expression, RIGHT_CHILD_T
     public <R, C> R accept(ExpressionVisitor<R, C> visitor, C context) {
         return visitor.visitCompoundPredicate(this, context);
     }
+
+    /** Returns OR for an AND predicate and AND for any other type. */
+    public NodeType flip() {
+        if (getType() == NodeType.AND) {
+            return NodeType.OR;
+        }
+        return NodeType.AND;
+    }
+
+    @Override
+    public boolean equals(Object o) {
+        if (this == o) {
+            return true;
+        }
+        if (o == null || getClass() != o.getClass()) {
+            return false;
+        }
+        CompoundPredicate other = (CompoundPredicate) o;
+        return (type == other.getType()) && Objects.equals(left(), other.left())
+                && Objects.equals(right(), other.right());
+    }
+
+    @Override
+    public int hashCode() {
+        // Must agree with equals(): hash-based collections (e.g. the LinkedHashSet
+        // de-duplication in ExpressionUtils.combine) rely on equal objects hashing equally.
+        return Objects.hash(type, left(), right());
+    }
+
+    @Override
+    public Expression withChildren(List<Expression> children) {
+        return new CompoundPredicate<>(getType(), children.get(0), children.get(1));
+    }
 }
+
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/IterationVisitor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/IterationVisitor.java
new file mode 100644
index 0000000000..34c3baac12
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/IterationVisitor.java
@@ -0,0 +1,161 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.
See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import org.apache.doris.nereids.trees.expressions.functions.AggregateFunction;
+import org.apache.doris.nereids.trees.expressions.functions.BoundFunction;
+
+/**
+ * Iterative traversal of an expression.
+ */
+public abstract class IterationVisitor<C> extends DefaultExpressionVisitor<Void, C> {
+
+    @Override
+    public Void visit(Expression expr, C context) {
+        return expr.accept(this, context);
+    }
+
+    @Override
+    public Void visitNot(Not expr, C context) {
+        visit(expr.child(), context);
+        return null;
+    }
+
+    @Override
+    public Void visitCompoundPredicate(CompoundPredicate expr, C context) {
+        visit(expr.left(), context);
+        visit(expr.right(), context);
+        return null;
+    }
+
+    @Override
+    public Void visitLiteral(Literal literal, C context) {
+        return null;
+    }
+
+    @Override
+    public Void visitArithmetic(Arithmetic arithmetic, C context) {
+        visit(arithmetic.child(0), context);
+        if (arithmetic.getArithmeticOperator().isBinary()) {
+            visit(arithmetic.child(1), context);
+        }
+        return null;
+    }
+
+    @Override
+    public Void visitBetween(Between betweenPredicate, C context) {
+        visit(betweenPredicate.getCompareExpr(), context);
+        visit(betweenPredicate.getLowerBound(), context);
+        visit(betweenPredicate.getUpperBound(), context);
+        return null;
+    }
+
+    @Override
+    public Void visitAlias(Alias alias, C context) {
+        return visitNamedExpression(alias, context);
+    }
+
+    @Override
+    public Void visitComparisonPredicate(ComparisonPredicate cp, C context) {
+        visit(cp.left(), context);
+        visit(cp.right(), context);
+        return null;
+    }
+
+    @Override
+    public Void visitEqualTo(EqualTo equalTo, C context) {
+        return visitComparisonPredicate(equalTo, context);
+    }
+
+    @Override
+    public Void visitGreaterThan(GreaterThan greaterThan, C context) {
+        return visitComparisonPredicate(greaterThan, context);
+    }
+
+    @Override
+    public Void visitGreaterThanEqual(GreaterThanEqual greaterThanEqual, C context) {
+        return visitComparisonPredicate(greaterThanEqual, context);
+    }
+
+    @Override
+    public Void visitLessThan(LessThan lessThan, C context) {
+        return visitComparisonPredicate(lessThan, context);
+    }
+
+    @Override
+    public Void visitLessThanEqual(LessThanEqual lessThanEqual, C context) {
+        return visitComparisonPredicate(lessThanEqual, context);
+    }
+
+    @Override
+    public Void visitNullSafeEqual(NullSafeEqual nullSafeEqual, C context) {
+        return visitComparisonPredicate(nullSafeEqual, context);
+    }
+
+    @Override
+    public Void visitSlot(Slot slot, C context) {
+        return null;
+    }
+
+    @Override
+    public Void visitNamedExpression(NamedExpression namedExpression, C context) {
+        for (Expression child : namedExpression.children()) {
+            visit(child, context);
+        }
+        return null;
+    }
+
+    @Override
+    public Void visitBoundFunction(BoundFunction boundFunction, C context) {
+        for (Expression argument : boundFunction.getArguments()) {
+            visit(argument, context);
+        }
+        return null;
+    }
+
+    @Override
+    public Void visitAggregateFunction(AggregateFunction aggregateFunction, C context) {
+        return visitBoundFunction(aggregateFunction, context);
+    }
+
+    @Override
+    public Void visitAdd(Add add, C context) {
+        return visitArithmetic(add, context);
+    }
+
+    @Override
+    public Void visitSubtract(Subtract subtract, C context) {
+        return visitArithmetic(subtract, context);
+    }
+
+    @Override
+    public Void visitMultiply(Multiply multiply, C context) {
+        return visitArithmetic(multiply, context);
+    }
+
+    @Override
+    public Void visitDivide(Divide divide, C context) {
+        return visitArithmetic(divide, context);
+    }
+
+    @Override
+    public Void visitMod(Mod mod, C context) {
+        return visitArithmetic(mod, context);
+    }
+
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Literal.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Literal.java
index
912b72705e..0d958f38c2 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Literal.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/Literal.java
@@ -34,6 +34,8 @@ import java.util.Objects;
  * TODO: Increase the implementation of sub expression. such as Integer.
  */
 public class Literal extends Expression implements LeafExpression {
+    public static final Literal TRUE_LITERAL = new Literal(true);
+    public static final Literal FALSE_LITERAL = new Literal(false);
     private final DataType dataType;
     private final Object value;
 
@@ -97,6 +99,10 @@ public class Literal extends Expression implements LeafExpression {
         return value == null;
     }
 
+    public static Literal of(Object value) {
+        return new Literal(value);
+    }
+
     @Override
     public boolean isConstant() {
         return true;
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotExtractor.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotExtractor.java
new file mode 100644
index 0000000000..68f0b8d2de
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/trees/expressions/SlotExtractor.java
@@ -0,0 +1,68 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.trees.expressions;
+
+import com.google.common.collect.Lists;
+import com.google.common.collect.Sets;
+
+import java.util.Collection;
+import java.util.List;
+import java.util.Set;
+
+/**
+ * Extracts the SlotReference contained in the expression.
+ */
+public class SlotExtractor extends IterationVisitor<List<Slot>> {
+
+    /**
+     * Extract the distinct slot references used by the given expressions, in first-seen order.
+     */
+    public static Set<Slot> extractSlot(Collection<Expression> expressions) {
+
+        Set<Slot> slots = Sets.newLinkedHashSet();
+        for (Expression expression : expressions) {
+            slots.addAll(extractSlot(expression));
+        }
+        return slots;
+    }
+
+    /**
+     * Extract the distinct slot references used by the given expressions, in first-seen order.
+     */
+    public static Set<Slot> extractSlot(Expression... expressions) {
+
+        Set<Slot> slots = Sets.newLinkedHashSet();
+        for (Expression expression : expressions) {
+            slots.addAll(extractSlot(expression));
+        }
+        return slots;
+    }
+
+    private static List<Slot> extractSlot(Expression expression) {
+        List<Slot> slots = Lists.newArrayList();
+        new SlotExtractor().visit(expression, slots);
+        return slots;
+    }
+
+
+    @Override
+    public Void visitSlotReference(SlotReference slotReference, List<Slot> context) {
+        context.add(slotReference);
+        return null;
+    }
+}
diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
new file mode 100644
index 0000000000..b513038b35
--- /dev/null
+++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -0,0 +1,140 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.
You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.util;
+
+import org.apache.doris.nereids.trees.NodeType;
+import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.Literal;
+
+import com.google.common.collect.Lists;
+
+import java.util.LinkedHashSet;
+import java.util.LinkedList;
+import java.util.List;
+import java.util.Objects;
+import java.util.stream.Collectors;
+
+/**
+ * Expression rewrite helper class.
+ */
+public class ExpressionUtils {
+
+    public static boolean isConstant(Expression expr) {
+        return expr.isConstant();
+    }
+
+    public static List<Expression> extractConjunct(Expression expr) {
+        return extract(NodeType.AND, expr);
+    }
+
+
+    public static List<Expression> extractDisjunct(Expression expr) {
+        return extract(NodeType.OR, expr);
+    }
+
+    public static List<Expression> extract(CompoundPredicate expr) {
+        return extract(expr.getType(), expr);
+    }
+
+    private static List<Expression> extract(NodeType op, Expression expr) {
+        List<Expression> result = Lists.newArrayList();
+        extract(op, expr, result);
+        return result;
+    }
+
+    private static void extract(NodeType op, Expression expr, List<Expression> result) {
+        if (expr instanceof CompoundPredicate && expr.getType() == op) {
+            CompoundPredicate predicate = (CompoundPredicate) expr;
+            extract(op, predicate.left(), result);
+            extract(op, predicate.right(), result);
+        } else {
+            result.add(expr);
+        }
+    }
+
+
+    public static Expression add(List<Expression> expressions) {
+        return combine(NodeType.AND, expressions);
+    }
+
+    public static Expression add(Expression... expressions) {
+        return combine(NodeType.AND, Lists.newArrayList(expressions));
+    }
+
+    public static Expression or(Expression... expressions) {
+        return combine(NodeType.OR, Lists.newArrayList(expressions));
+    }
+
+    public static Expression or(List<Expression> expressions) {
+        return combine(NodeType.OR, expressions);
+    }
+
+    /**
+     * Use AND/OR to combine expressions together; any other op is rejected with IllegalArgumentException.
+     */
+    public static Expression combine(NodeType op, List<Expression> expressions) {
+        Objects.requireNonNull(expressions, "expressions is null");
+        // Only AND/OR have an identity literal for the empty case; for any other op an
+        // empty input would recurse forever below, so reject it up front.
+        if (op != NodeType.AND && op != NodeType.OR) {
+            throw new IllegalArgumentException("combine only supports AND/OR, but got " + op);
+        }
+
+        if (expressions.isEmpty()) {
+            // identity element: TRUE for AND, FALSE for OR
+            return op == NodeType.AND ? Literal.TRUE_LITERAL : Literal.FALSE_LITERAL;
+        }
+
+        if (expressions.size() == 1) {
+            return expressions.get(0);
+        }
+
+        List<Expression> distinctExpressions = Lists.newArrayList(new LinkedHashSet<>(expressions));
+        if (op == NodeType.AND) {
+            if (distinctExpressions.contains(Literal.FALSE_LITERAL)) {
+                return Literal.FALSE_LITERAL;
+            }
+            distinctExpressions = distinctExpressions.stream().filter(p -> !p.equals(Literal.TRUE_LITERAL))
+                    .collect(Collectors.toList());
+        }
+
+        if (op == NodeType.OR) {
+            if (distinctExpressions.contains(Literal.TRUE_LITERAL)) {
+                return Literal.TRUE_LITERAL;
+            }
+            distinctExpressions = distinctExpressions.stream().filter(p -> !p.equals(Literal.FALSE_LITERAL))
+                    .collect(Collectors.toList());
+        }
+
+        List<List<Expression>> partitions = Lists.partition(distinctExpressions, 2);
+        List<Expression> result = new LinkedList<>();
+
+        for (List<Expression> partition : partitions) {
+            if (partition.size() == 2) {
+                result.add(new CompoundPredicate(op, partition.get(0), partition.get(1)));
+            }
+            if (partition.size() == 1) {
+                result.add(partition.get(0));
+            }
+        }
+
+        return combine(op, result);
+    }
+}
+
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java new file mode 100644 index 0000000000..ee07b99194 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/rewrite/logical/PushDownPredicateTest.java @@ -0,0 +1,256 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.rewrite.logical; + +import org.apache.doris.catalog.AggregateType; +import org.apache.doris.catalog.Column; +import org.apache.doris.catalog.Table; +import org.apache.doris.catalog.Type; +import org.apache.doris.nereids.OptimizerContext; +import org.apache.doris.nereids.PlannerContext; +import org.apache.doris.nereids.jobs.rewrite.RewriteTopDownJob; +import org.apache.doris.nereids.memo.Group; +import org.apache.doris.nereids.memo.Memo; +import org.apache.doris.nereids.operators.Operator; +import org.apache.doris.nereids.operators.plans.JoinType; +import org.apache.doris.nereids.operators.plans.logical.LogicalFilter; +import org.apache.doris.nereids.operators.plans.logical.LogicalJoin; +import org.apache.doris.nereids.operators.plans.logical.LogicalOlapScan; +import org.apache.doris.nereids.operators.plans.logical.LogicalProject; +import org.apache.doris.nereids.properties.PhysicalProperties; +import org.apache.doris.nereids.rules.Rule; +import org.apache.doris.nereids.trees.expressions.Add; +import org.apache.doris.nereids.trees.expressions.Between; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.GreaterThan; +import org.apache.doris.nereids.trees.expressions.Literal; +import org.apache.doris.nereids.trees.expressions.Subtract; +import org.apache.doris.nereids.trees.plans.Plan; +import org.apache.doris.nereids.trees.plans.Plans; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Lists; +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.BeforeAll; +import org.junit.jupiter.api.Test; +import org.junit.jupiter.api.TestInstance; + +import java.util.List; +import java.util.Optional; + +/** + * plan rewrite ut. 
+ */ +@TestInstance(TestInstance.Lifecycle.PER_CLASS) +public class PushDownPredicateTest implements Plans { + + private Table student; + private Table score; + private Table course; + + private Plan rStudent; + private Plan rScore; + private Plan rCourse; + + /** + * ut before. + */ + @BeforeAll + public final void beforeAll() { + student = new Table(0L, "student", Table.TableType.OLAP, + ImmutableList.<Column>of(new Column("id", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "", ""), + new Column("age", Type.INT, true, AggregateType.NONE, "", ""))); + + score = new Table(0L, "score", Table.TableType.OLAP, + ImmutableList.<Column>of(new Column("sid", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("cid", Type.INT, true, AggregateType.NONE, "", ""), + new Column("grade", Type.DOUBLE, true, AggregateType.NONE, "", ""))); + + course = new Table(0L, "course", Table.TableType.OLAP, + ImmutableList.<Column>of(new Column("cid", Type.INT, true, AggregateType.NONE, "0", ""), + new Column("name", Type.STRING, true, AggregateType.NONE, "", ""), + new Column("teacher", Type.STRING, true, AggregateType.NONE, "", ""))); + + rStudent = plan(new LogicalOlapScan(student, ImmutableList.of("student"))); + + rScore = plan(new LogicalOlapScan(score, ImmutableList.of("score"))); + + rCourse = plan(new LogicalOlapScan(course, ImmutableList.of("course"))); + } + + @Test + public void pushDownPredicateIntoScanTest1() { + // select id,name,grade from student join score on student.id = score.sid and student.id > 1 + // and score.cid > 2 where student.age > 18 and score.grade > 60 + Expression onCondition1 = new EqualTo<>(rStudent.getOutput().get(0), rScore.getOutput().get(0)); + Expression onCondition2 = new GreaterThan<>(rStudent.getOutput().get(0), Literal.of(1)); + Expression onCondition3 = new GreaterThan<>(rScore.getOutput().get(0), Literal.of(2)); + Expression onCondition = ExpressionUtils.add(onCondition1, 
onCondition2, onCondition3); + + Expression whereCondition1 = new GreaterThan<>(rStudent.getOutput().get(1), Literal.of(18)); + Expression whereCondition2 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60)); + Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2); + + + Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.of(onCondition)), rStudent, rScore); + Plan filter = plan(new LogicalFilter(whereCondition), join); + + Plan root = plan(new LogicalProject( + Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2))), + filter); + + Memo memo = new Memo(); + memo.initialize(root); + System.out.println(memo.copyOut().treeString()); + + OptimizerContext optimizerContext = new OptimizerContext(memo); + PlannerContext plannerContext = new PlannerContext(optimizerContext, null, new PhysicalProperties()); + RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(), + ImmutableList.of(new PushPredicateThroughJoin().build()), plannerContext); + plannerContext.getOptimizerContext().pushJob(rewriteTopDownJob); + plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext); + + Group rootGroup = memo.getRoot(); + System.out.println(memo.copyOut().treeString()); + System.out.println(11); + + Operator op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getOperator(); + Operator op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().getOperator(); + Operator op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression().getOperator(); + + Assertions.assertTrue(op1 instanceof LogicalJoin); + Assertions.assertTrue(op2 instanceof LogicalFilter); + Assertions.assertTrue(op3 instanceof LogicalFilter); + LogicalJoin join1 = (LogicalJoin) op1; + LogicalFilter filter1 = (LogicalFilter) op2; + LogicalFilter filter2 = (LogicalFilter) op3; + + 
Assertions.assertEquals(join1.getCondition().get(), onCondition1); + Assertions.assertEquals(filter1.getPredicates(), ExpressionUtils.add(onCondition2, whereCondition1)); + Assertions.assertEquals(filter2.getPredicates(), ExpressionUtils.add(onCondition3, whereCondition2)); + } + + @Test + public void pushDownPredicateIntoScanTest3() { + //select id,name,grade from student left join score on student.id + 1 = score.sid - 2 + //where student.age > 18 and score.grade > 60 + Expression whereCondition1 = new EqualTo<>(new Add<>(rStudent.getOutput().get(0), Literal.of(1)), + new Subtract<>(rScore.getOutput().get(0), Literal.of(2))); + Expression whereCondition2 = new GreaterThan<>(rStudent.getOutput().get(1), Literal.of(18)); + Expression whereCondition3 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60)); + Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2, whereCondition3); + + Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), rStudent, rScore); + Plan filter = plan(new LogicalFilter(whereCondition), join); + + Plan root = plan(new LogicalProject( + Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2))), + filter); + + Memo memo = new Memo(); + memo.initialize(root); + System.out.println(memo.copyOut().treeString()); + + OptimizerContext optimizerContext = new OptimizerContext(memo); + PlannerContext plannerContext = new PlannerContext(optimizerContext, null, new PhysicalProperties()); + RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(), + ImmutableList.of(new PushPredicateThroughJoin().build()), plannerContext); + plannerContext.getOptimizerContext().pushJob(rewriteTopDownJob); + plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext); + + Group rootGroup = memo.getRoot(); + System.out.println(memo.copyOut().treeString()); + + Operator op1 = 
rootGroup.getLogicalExpression().child(0).getLogicalExpression().getOperator(); + Operator op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().getOperator(); + Operator op3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression().getOperator(); + + Assertions.assertTrue(op1 instanceof LogicalJoin); + Assertions.assertTrue(op2 instanceof LogicalFilter); + Assertions.assertTrue(op3 instanceof LogicalFilter); + LogicalJoin join1 = (LogicalJoin) op1; + LogicalFilter filter1 = (LogicalFilter) op2; + LogicalFilter filter2 = (LogicalFilter) op3; + Assertions.assertEquals(join1.getCondition().get(), whereCondition1); + Assertions.assertEquals(filter1.getPredicates(), whereCondition2); + Assertions.assertEquals(filter2.getPredicates(), whereCondition3); + } + + @Test + public void pushDownPredicateIntoScanTest4() { + /* + select + student.name, + course.name, + score.grade + from student,score,course + where on student.id = score.sid and student.age between 18 and 20 and score.grade > 60 and student.id = score.sid + */ + + // student.id = score.sid + Expression whereCondition1 = new EqualTo<>(rStudent.getOutput().get(0), rScore.getOutput().get(0)); + // score.cid = course.cid + Expression whereCondition2 = new EqualTo<>(rScore.getOutput().get(1), rCourse.getOutput().get(0)); + // student.age between 18 and 20 + Expression whereCondition3 = new Between<>(rStudent.getOutput().get(2), Literal.of(18), Literal.of(20)); + // score.grade > 60 + Expression whereCondition4 = new GreaterThan<>(rScore.getOutput().get(2), Literal.of(60)); + + Expression whereCondition = ExpressionUtils.add(whereCondition1, whereCondition2, whereCondition3, whereCondition4); + + Plan join = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), rStudent, rScore); + Plan join1 = plan(new LogicalJoin(JoinType.INNER_JOIN, Optional.empty()), join, rCourse); + Plan filter = plan(new 
LogicalFilter(whereCondition), join1); + + Plan root = plan(new LogicalProject( + Lists.newArrayList(rStudent.getOutput().get(1), rCourse.getOutput().get(1), rScore.getOutput().get(2))), + filter); + + + Memo memo = new Memo(); + memo.initialize(root); + System.out.println(memo.copyOut().treeString()); + + OptimizerContext optimizerContext = new OptimizerContext(memo); + PlannerContext plannerContext = new PlannerContext(optimizerContext, null, new PhysicalProperties()); + List<Rule<Plan>> fakeRules = Lists.newArrayList(new PushPredicateThroughJoin().build()); + RewriteTopDownJob rewriteTopDownJob = new RewriteTopDownJob(memo.getRoot(), fakeRules, plannerContext); + plannerContext.getOptimizerContext().pushJob(rewriteTopDownJob); + plannerContext.getOptimizerContext().getJobScheduler().executeJobPool(plannerContext); + + Group rootGroup = memo.getRoot(); + System.out.println(memo.copyOut().treeString()); + Operator join2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().getOperator(); + Operator join3 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().getOperator(); + Operator op1 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().getOperator(); + Operator op2 = rootGroup.getLogicalExpression().child(0).getLogicalExpression().child(0).getLogicalExpression().child(1).getLogicalExpression().getOperator(); + + Assertions.assertTrue(join2 instanceof LogicalJoin); + Assertions.assertTrue(join3 instanceof LogicalJoin); + Assertions.assertTrue(op1 instanceof LogicalFilter); + Assertions.assertTrue(op2 instanceof LogicalFilter); + + Assertions.assertEquals(((LogicalJoin) join2).getCondition().get(), whereCondition2); + Assertions.assertEquals(((LogicalJoin) join3).getCondition().get(), whereCondition1); + Assertions.assertEquals(((LogicalFilter) op1).getPredicates(), whereCondition3); + Assertions.assertEquals(((LogicalFilter) 
op2).getPredicates(), whereCondition4); + } +} diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/ssb/SSBUtils.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/ssb/SSBUtils.java index cf43c1246b..074e970b31 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/ssb/SSBUtils.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/ssb/SSBUtils.java @@ -84,7 +84,7 @@ public class SSBUtils { + " s_nation,\n" + " d_year,\n" + " SUM(lo_revenue) AS REVENUE\n" - + "FROM customer, lineorder, supplier, dates\n" + + "FROM lineorder, customer, supplier, dates\n" + "WHERE\n" + " lo_custkey = c_custkey\n" + " AND lo_suppkey = s_suppkey\n" @@ -101,7 +101,7 @@ public class SSBUtils { + " s_city,\n" + " d_year,\n" + " SUM(lo_revenue) AS REVENUE\n" - + "FROM customer, lineorder, supplier, dates\n" + + "FROM lineorder, customer , supplier, dates\n" + "WHERE\n" + " lo_custkey = c_custkey\n" + " AND lo_suppkey = s_suppkey\n" @@ -118,7 +118,7 @@ public class SSBUtils { + " s_city,\n" + " d_year,\n" + " SUM(lo_revenue) AS REVENUE\n" - + "FROM customer, lineorder, supplier, dates\n" + + "FROM lineorder, customer, supplier, dates\n" + "WHERE\n" + " lo_custkey = c_custkey\n" + " AND lo_suppkey = s_suppkey\n" @@ -141,7 +141,7 @@ public class SSBUtils { + " s_city,\n" + " d_year,\n" + " SUM(lo_revenue) AS REVENUE\n" - + "FROM customer, lineorder, supplier, dates\n" + + "FROM lineorder, customer, supplier, dates\n" + "WHERE\n" + " lo_custkey = c_custkey\n" + " AND lo_suppkey = s_suppkey\n" @@ -162,7 +162,7 @@ public class SSBUtils { + " d_year,\n" + " c_nation,\n" + " SUM(lo_revenue - lo_supplycost) AS PROFIT\n" - + "FROM dates, customer, supplier, part, lineorder\n" + + "FROM lineorder, dates, customer, supplier, part\n" + "WHERE\n" + " lo_custkey = c_custkey\n" + " AND lo_suppkey = s_suppkey\n" @@ -182,7 +182,7 @@ public class SSBUtils { + " s_nation,\n" + " p_category,\n" + " SUM(lo_revenue - lo_supplycost) AS PROFIT\n" - + "FROM dates, customer, supplier, 
part, lineorder\n" + + "FROM lineorder, dates, customer, supplier, part\n" + "WHERE\n" + " lo_custkey = c_custkey\n" + " AND lo_suppkey = s_suppkey\n" @@ -206,7 +206,7 @@ public class SSBUtils { + " s_city,\n" + " p_brand,\n" + " SUM(lo_revenue - lo_supplycost) AS PROFIT\n" - + "FROM dates, customer, supplier, part, lineorder\n" + + "FROM lineorder, dates, customer, supplier, part\n" + "WHERE\n" + " lo_custkey = c_custkey\n" + " AND lo_suppkey = s_suppkey\n" diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionUtilsTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionUtilsTest.java new file mode 100644 index 0000000000..e60b322804 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/util/ExpressionUtilsTest.java @@ -0,0 +1,101 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.nereids.util; + +import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.trees.expressions.Expression; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +import java.util.List; + +/** + * ExpressionUtils ut. 
+ */ +public class ExpressionUtilsTest { + + private static final NereidsParser PARSER = new NereidsParser(); + + @Test + public void extractConjunctsTest() { + List<Expression> expressions; + Expression expr; + + expr = PARSER.createExpression("a"); + expressions = ExpressionUtils.extractConjunct(expr); + Assertions.assertEquals(expressions.size(), 1); + Assertions.assertEquals(expressions.get(0), expr); + + + expr = PARSER.createExpression("a and b and c"); + Expression a = PARSER.createExpression("a"); + Expression b = PARSER.createExpression("b"); + Expression c = PARSER.createExpression("c"); + + expressions = ExpressionUtils.extractConjunct(expr); + Assertions.assertEquals(expressions.size(), 3); + Assertions.assertEquals(expressions.get(0), a); + Assertions.assertEquals(expressions.get(1), b); + Assertions.assertEquals(expressions.get(2), c); + + + expr = PARSER.createExpression("(a or b) and c and (e or f)"); + expressions = ExpressionUtils.extractConjunct(expr); + Expression aOrb = PARSER.createExpression("a or b"); + Expression eOrf = PARSER.createExpression("e or f"); + Assertions.assertEquals(expressions.size(), 3); + Assertions.assertEquals(expressions.get(0), aOrb); + Assertions.assertEquals(expressions.get(1), c); + Assertions.assertEquals(expressions.get(2), eOrf); + + } + + @Test + public void extractDisjunctsTest() { + List<Expression> expressions; + Expression expr; + + expr = PARSER.createExpression("a"); + expressions = ExpressionUtils.extractDisjunct(expr); + Assertions.assertEquals(expressions.size(), 1); + Assertions.assertEquals(expressions.get(0), expr); + + + expr = PARSER.createExpression("a or b or c"); + Expression a = PARSER.createExpression("a"); + Expression b = PARSER.createExpression("b"); + Expression c = PARSER.createExpression("c"); + + expressions = ExpressionUtils.extractDisjunct(expr); + Assertions.assertEquals(expressions.size(), 3); + Assertions.assertEquals(expressions.get(0), a); + 
Assertions.assertEquals(expressions.get(1), b); + Assertions.assertEquals(expressions.get(2), c); + + + expr = PARSER.createExpression("(a and b) or c or (e and f)"); + expressions = ExpressionUtils.extractDisjunct(expr); + Expression aAndb = PARSER.createExpression("a and b"); + Expression eAndf = PARSER.createExpression("e and f"); + Assertions.assertEquals(expressions.size(), 3); + Assertions.assertEquals(expressions.get(0), aAndb); + Assertions.assertEquals(expressions.get(1), c); + Assertions.assertEquals(expressions.get(2), eAndf); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org