This is an automated email from the ASF dual-hosted git repository. starocean999 pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new 565edd9d133 [feature](nereids) in predicate extract non constant expressions (#46794) 565edd9d133 is described below commit 565edd9d13353f481093731d08915e9ef77e803b Author: yujun <yu...@selectdb.com> AuthorDate: Mon Jan 13 17:29:57 2025 +0800 [feature](nereids) in predicate extract non constant expressions (#46794) Problem Summary: If an IN predicate contains non-literal options, the backend evaluates it less efficiently, so the non-constant options need to be extracted from the IN predicate. This PR adds an expression rewrite rule, InPredicateExtractNonConstant, which extracts all non-constant options out of the IN predicate. For example: ``` k1 in (k2, k3 + 3, 1, 2, 3 + 3) => k1 in (1, 2, 3 + 3) or k1 = k2 or k1 = k3 + 3 ``` --- .../rules/expression/ExpressionNormalization.java | 2 + .../rules/expression/ExpressionRuleType.java | 1 + .../rules/expression/rules/InPredicateDedup.java | 16 ++--- .../rules/InPredicateExtractNonConstant.java | 77 ++++++++++++++++++++++ .../nereids/rules/expression/rules/OrToIn.java | 19 +----- .../rules/expression/rules/RangeInference.java | 3 +- .../rules/expression/rules/SimplifyRange.java | 20 +----- .../apache/doris/nereids/util/ExpressionUtils.java | 8 +++ .../rules/expression/ExpressionRewriteTest.java | 15 +++++ .../rules/InPredicateExtractNonConstantTest.java | 47 +++++++++++++ 10 files changed, 164 insertions(+), 44 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java index b4430d33087..135f80111e9 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java @@ -22,6 +22,7 @@ import org.apache.doris.nereids.rules.expression.rules.ConvertAggStateCast; import 
org.apache.doris.nereids.rules.expression.rules.DigitalMaskingConvert; import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule; import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup; +import org.apache.doris.nereids.rules.expression.rules.InPredicateExtractNonConstant; import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule; import org.apache.doris.nereids.rules.expression.rules.MedianConvert; import org.apache.doris.nereids.rules.expression.rules.MergeDateTrunc; @@ -47,6 +48,7 @@ public class ExpressionNormalization extends ExpressionRewrite { SupportJavaDateFormatter.INSTANCE, NormalizeBinaryPredicatesRule.INSTANCE, InPredicateDedup.INSTANCE, + InPredicateExtractNonConstant.INSTANCE, InPredicateToEqualToRule.INSTANCE, SimplifyNotExprRule.INSTANCE, SimplifyArithmeticRule.INSTANCE, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java index bc12c0459ee..7f83ab8a090 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java @@ -35,6 +35,7 @@ public enum ExpressionRuleType { FOLD_CONSTANT_ON_BE, FOLD_CONSTANT_ON_FE, IN_PREDICATE_DEDUP, + IN_PREDICATE_EXTRACT_NON_CONSTANT, IN_PREDICATE_TO_EQUAL_TO, LIKE_TO_EQUAL, MERGE_DATE_TRUNC, diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java index aaa822ac691..1be5971f6a2 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java @@ -36,29 +36,29 @@ import java.util.Set; public class 
InPredicateDedup implements ExpressionPatternRuleFactory { public static final InPredicateDedup INSTANCE = new InPredicateDedup(); + // In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options. + // It takes a long time to apply this rule. So set a threshold for the max number. + public static final int REWRITE_OPTIONS_MAX_SIZE = 200; + @Override public List<ExpressionPatternMatcher<? extends Expression>> buildRules() { return ImmutableList.of( - matchesType(InPredicate.class).then(InPredicateDedup::dedup) + matchesType(InPredicate.class) + .when(inPredicate -> inPredicate.getOptions().size() <= REWRITE_OPTIONS_MAX_SIZE) + .then(InPredicateDedup::dedup) .toRule(ExpressionRuleType.IN_PREDICATE_DEDUP) ); } /** dedup */ public static Expression dedup(InPredicate inPredicate) { - // In many BI scenarios, the sql is auto-generated, and hence there may be thousands of options. - // It takes a long time to apply this rule. So set a threshold for the max number. 
- int optionSize = inPredicate.getOptions().size(); - if (optionSize > 200) { - return inPredicate; - } ImmutableSet.Builder<Expression> newOptionsBuilder = ImmutableSet.builderWithExpectedSize(inPredicate.arity()); for (Expression option : inPredicate.getOptions()) { newOptionsBuilder.add(option); } Set<Expression> newOptions = newOptionsBuilder.build(); - if (newOptions.size() == optionSize) { + if (newOptions.size() == inPredicate.getOptions().size()) { return inPredicate; } return new InPredicate(inPredicate.getCompareExpr(), newOptions); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java new file mode 100644 index 00000000000..c869dafa0a2 --- /dev/null +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java @@ -0,0 +1,77 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher; +import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory; +import org.apache.doris.nereids.rules.expression.ExpressionRuleType; +import org.apache.doris.nereids.trees.expressions.EqualTo; +import org.apache.doris.nereids.trees.expressions.Expression; +import org.apache.doris.nereids.trees.expressions.InPredicate; +import org.apache.doris.nereids.util.ExpressionUtils; + +import com.google.common.collect.ImmutableList; +import com.google.common.collect.Sets; +import org.apache.hadoop.util.Lists; + +import java.util.List; +import java.util.Set; +import java.util.stream.Collectors; + +/** + * Extract non-constant of InPredicate, For example: + * where k1 in (k2, k3, 10, 20, 30) ==> where k1 in (10, 20, 30) or k1 = k2 or k1 = k3. + * It's because backend handle in predicate which contains none-constant column will reduce performance. + */ +public class InPredicateExtractNonConstant implements ExpressionPatternRuleFactory { + public static final InPredicateExtractNonConstant INSTANCE = new InPredicateExtractNonConstant(); + + @Override + public List<ExpressionPatternMatcher<? 
extends Expression>> buildRules() { + return ImmutableList.of( + matchesType(InPredicate.class) + .when(inPredicate -> inPredicate.getOptions().size() + <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE) + .then(this::rewrite) + .toRule(ExpressionRuleType.IN_PREDICATE_EXTRACT_NON_CONSTANT) + ); + } + + private Expression rewrite(InPredicate inPredicate) { + Set<Expression> nonConstants = Sets.newLinkedHashSetWithExpectedSize(inPredicate.arity()); + for (Expression option : inPredicate.getOptions()) { + if (!option.isConstant()) { + nonConstants.add(option); + } + } + if (nonConstants.isEmpty()) { + return inPredicate; + } + Expression key = inPredicate.getCompareExpr(); + List<Expression> disjunctions = Lists.newArrayListWithExpectedSize(inPredicate.getOptions().size()); + List<Expression> constants = inPredicate.getOptions().stream().filter(Expression::isConstant) + .collect(Collectors.toList()); + if (!constants.isEmpty()) { + disjunctions.add(ExpressionUtils.toInPredicateOrEqualTo(key, constants)); + } + for (Expression option : nonConstants) { + disjunctions.add(new EqualTo(key, option)); + } + return ExpressionUtils.or(disjunctions); + } +} diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java index 136b40af584..16006c02080 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java @@ -31,11 +31,9 @@ import org.apache.doris.nereids.trees.expressions.Slot; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; -import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Sets; -import java.util.ArrayList; import java.util.LinkedHashMap; import java.util.LinkedHashSet; import 
java.util.List; @@ -69,8 +67,6 @@ public class OrToIn { public static final OrToIn EXTRACT_MODE_INSTANCE = new OrToIn(Mode.extractMode); public static final OrToIn REPLACE_MODE_INSTANCE = new OrToIn(Mode.replaceMode); - public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2; - private final Mode mode; public OrToIn(Mode mode) { @@ -196,18 +192,9 @@ public class OrToIn { } private Expression candidatesToFinalResult(Map<Expression, Set<Literal>> candidates) { - List<Expression> conjuncts = new ArrayList<>(); - for (Expression key : candidates.keySet()) { - Set<Literal> literals = candidates.get(key); - if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { - for (Literal literal : literals) { - conjuncts.add(new EqualTo(key, literal)); - } - } else { - conjuncts.add(new InPredicate(key, ImmutableList.copyOf(literals))); - } - } - return ExpressionUtils.and(conjuncts); + return ExpressionUtils.and(candidates.entrySet().stream() + .map(entry -> ExpressionUtils.toInPredicateOrEqualTo(entry.getKey(), entry.getValue())) + .collect(Collectors.toList())); } /* diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java index c78ec7a75fb..7c23ce36a3d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java @@ -110,7 +110,8 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, @Override public ValueDesc visitInPredicate(InPredicate inPredicate, ExpressionRewriteContext context) { // only handle `NumericType` and `DateLikeType` - if (ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions()) + if (inPredicate.getOptions().size() <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE + && ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions()) && 
(ExpressionUtils.matchNumericType(inPredicate.getOptions()) || ExpressionUtils.matchDateLikeType(inPredicate.getOptions()))) { return ValueDesc.discrete(context, inPredicate); diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index 64891882f7d..6ac69f1eb56 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -27,14 +27,11 @@ import org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue; import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc; import org.apache.doris.nereids.trees.expressions.CompoundPredicate; -import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.GreaterThan; import org.apache.doris.nereids.trees.expressions.GreaterThanEqual; -import org.apache.doris.nereids.trees.expressions.InPredicate; import org.apache.doris.nereids.trees.expressions.LessThan; import org.apache.doris.nereids.trees.expressions.LessThanEqual; -import org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; @@ -45,9 +42,7 @@ import com.google.common.collect.Lists; import com.google.common.collect.Range; import org.apache.commons.lang3.NotImplementedException; -import java.util.Iterator; import java.util.List; -import java.util.Set; /** * This class implements the function to simplify expression range. 
@@ -133,20 +128,7 @@ public class SimplifyRange implements ExpressionPatternRuleFactory { } private Expression getExpression(DiscreteValue value) { - Expression reference = value.getReference(); - Set<Literal> values = value.getValues(); - // NOTICE: it's related with `InPredicateToEqualToRule` - // They are same processes, so must change synchronously. - if (values.size() == 1) { - return new EqualTo(reference, values.iterator().next()); - - // this condition should as same as OrToIn, or else meet dead loop - } else if (values.size() < OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) { - Iterator<Literal> iterator = values.iterator(); - return new Or(new EqualTo(reference, iterator.next()), new EqualTo(reference, iterator.next())); - } else { - return new InPredicate(reference, Lists.newArrayList(values)); - } + return ExpressionUtils.toInPredicateOrEqualTo(value.getReference(), value.getValues()); } private Expression getExpression(UnknownValue value) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 723224409bf..3d8aef2c842 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -276,6 +276,14 @@ public class ExpressionUtils { } } + public static Expression toInPredicateOrEqualTo(Expression reference, Collection<? extends Expression> values) { + if (values.size() < 2) { + return or(values.stream().map(value -> new EqualTo(reference, value)).collect(Collectors.toList())); + } else { + return new InPredicate(reference, ImmutableList.copyOf(values)); + } + } + /** * Use AND/OR to combine expressions together. 
*/ diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java index 13f2789c0a9..34b40efcc29 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java @@ -21,6 +21,7 @@ import org.apache.doris.nereids.rules.expression.rules.AddMinMax; import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule; import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup; +import org.apache.doris.nereids.rules.expression.rules.InPredicateExtractNonConstant; import org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule; import org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule; import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule; @@ -361,4 +362,18 @@ class ExpressionRewriteTest extends ExpressionRewriteTestHelper { assertRewriteAfterTypeCoercion("TA between 10 and 20 and TB between 10 and 20 or TA between 30 and 40 and TB between 30 and 40 or TA between 60 and 50 and TB between 60 and 50", "(TA <= 20 and TB <= 20 or TA >= 30 and TB >= 30 or TA is null and null and TB is null) and TA >= 10 and TA <= 40 and TB >= 10 and TB <= 40"); } + + @Test + public void testInPredicateExtractNonConstant() { + executor = new ExpressionRuleExecutor(ImmutableList.of( + bottomUp( + InPredicateExtractNonConstant.INSTANCE + ) + )); + + assertRewriteAfterTypeCoercion("TA in (3, 2, 1)", "TA in (3, 2, 1)"); + assertRewriteAfterTypeCoercion("TA in (TB, TC, TB)", "TA = TB or TA = TC"); + assertRewriteAfterTypeCoercion("TA in (3, 2, 1, TB, TC, TB)", "TA in (3, 2, 1) or TA = TB or TA = TC"); + assertRewriteAfterTypeCoercion("IA in (1 + 2, 2 + 3, 3 + 
TB)", "IA in (cast(1 + 2 as int), cast(2 + 3 as int)) or IA = cast(3 + TB as int)"); + } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java new file mode 100644 index 00000000000..60b511f4475 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java @@ -0,0 +1,47 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
+ +package org.apache.doris.nereids.rules.expression.rules; + +import org.apache.doris.nereids.sqltest.SqlTestBase; +import org.apache.doris.nereids.util.PlanChecker; + +import org.junit.jupiter.api.Test; + +class InPredicateExtractNonConstantTest extends SqlTestBase { + @Test + public void testExtractNonConstant() { + connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION"); + String sql = "select * from T1 where id in (score, score, score + 100)"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalFilter().when(f -> f.getPredicate().toString().equals( + "OR[(id#0 = score#1),(id#0 = (score#1 + 100))]" + ))); + + sql = "select * from T1 where id in (score, score + 10, score + score, score, 10, 20, 30, 100 + 200)"; + PlanChecker.from(connectContext) + .analyze(sql) + .rewrite() + .matches( + logicalFilter().when(f -> f.getPredicate().toString().equals( + "OR[id#0 IN (20, 10, 300, 30),(id#0 = score#1),(id#0 = (score#1 + 10)),(id#0 = (score#1 + score#1))]" + ))); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org