This is an automated email from the ASF dual-hosted git repository.

starocean999 pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new 565edd9d133 [feature](nereids) in predicate extract non constant 
expressions (#46794)
565edd9d133 is described below

commit 565edd9d13353f481093731d08915e9ef77e803b
Author: yujun <yu...@selectdb.com>
AuthorDate: Mon Jan 13 17:29:57 2025 +0800

    [feature](nereids) in predicate extract non constant expressions (#46794)
    
    Problem Summary:
    If an in predicate contains non-constant options, the backend evaluates it
    less efficiently, so the non-constant options need to be extracted from the
    in predicate.
    
    This PR adds an expression rewrite rule, InPredicateExtractNonConstant, which
    extracts all the non-constant options out of the in predicate. For example:
    
    ```
    k1  in (k2,  k3 + 3,   1, 2, 3 + 3)  => k1 in (1, 2, 3 + 3) or k1 = k2 or 
k1 = k3 + 3
    ```
---
 .../rules/expression/ExpressionNormalization.java  |  2 +
 .../rules/expression/ExpressionRuleType.java       |  1 +
 .../rules/expression/rules/InPredicateDedup.java   | 16 ++---
 .../rules/InPredicateExtractNonConstant.java       | 77 ++++++++++++++++++++++
 .../nereids/rules/expression/rules/OrToIn.java     | 19 +-----
 .../rules/expression/rules/RangeInference.java     |  3 +-
 .../rules/expression/rules/SimplifyRange.java      | 20 +-----
 .../apache/doris/nereids/util/ExpressionUtils.java |  8 +++
 .../rules/expression/ExpressionRewriteTest.java    | 15 +++++
 .../rules/InPredicateExtractNonConstantTest.java   | 47 +++++++++++++
 10 files changed, 164 insertions(+), 44 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
index b4430d33087..135f80111e9 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionNormalization.java
@@ -22,6 +22,7 @@ import 
org.apache.doris.nereids.rules.expression.rules.ConvertAggStateCast;
 import org.apache.doris.nereids.rules.expression.rules.DigitalMaskingConvert;
 import org.apache.doris.nereids.rules.expression.rules.FoldConstantRule;
 import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup;
+import 
org.apache.doris.nereids.rules.expression.rules.InPredicateExtractNonConstant;
 import 
org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
 import org.apache.doris.nereids.rules.expression.rules.MedianConvert;
 import org.apache.doris.nereids.rules.expression.rules.MergeDateTrunc;
@@ -47,6 +48,7 @@ public class ExpressionNormalization extends 
ExpressionRewrite {
                 SupportJavaDateFormatter.INSTANCE,
                 NormalizeBinaryPredicatesRule.INSTANCE,
                 InPredicateDedup.INSTANCE,
+                InPredicateExtractNonConstant.INSTANCE,
                 InPredicateToEqualToRule.INSTANCE,
                 SimplifyNotExprRule.INSTANCE,
                 SimplifyArithmeticRule.INSTANCE,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
index bc12c0459ee..7f83ab8a090 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/ExpressionRuleType.java
@@ -35,6 +35,7 @@ public enum ExpressionRuleType {
     FOLD_CONSTANT_ON_BE,
     FOLD_CONSTANT_ON_FE,
     IN_PREDICATE_DEDUP,
+    IN_PREDICATE_EXTRACT_NON_CONSTANT,
     IN_PREDICATE_TO_EQUAL_TO,
     LIKE_TO_EQUAL,
     MERGE_DATE_TRUNC,
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java
index aaa822ac691..1be5971f6a2 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateDedup.java
@@ -36,29 +36,29 @@ import java.util.Set;
 public class InPredicateDedup implements ExpressionPatternRuleFactory {
     public static final InPredicateDedup INSTANCE = new InPredicateDedup();
 
+    // In many BI scenarios, the sql is auto-generated, and hence there may be 
thousands of options.
+    // It takes a long time to apply this rule. So set a threshold for the max 
number.
+    public static final int REWRITE_OPTIONS_MAX_SIZE = 200;
+
     @Override
     public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
         return ImmutableList.of(
-            matchesType(InPredicate.class).then(InPredicateDedup::dedup)
+            matchesType(InPredicate.class)
+                    .when(inPredicate -> inPredicate.getOptions().size() <= 
REWRITE_OPTIONS_MAX_SIZE)
+                    .then(InPredicateDedup::dedup)
                     .toRule(ExpressionRuleType.IN_PREDICATE_DEDUP)
         );
     }
 
     /** dedup */
     public static Expression dedup(InPredicate inPredicate) {
-        // In many BI scenarios, the sql is auto-generated, and hence there 
may be thousands of options.
-        // It takes a long time to apply this rule. So set a threshold for the 
max number.
-        int optionSize = inPredicate.getOptions().size();
-        if (optionSize > 200) {
-            return inPredicate;
-        }
         ImmutableSet.Builder<Expression> newOptionsBuilder = 
ImmutableSet.builderWithExpectedSize(inPredicate.arity());
         for (Expression option : inPredicate.getOptions()) {
             newOptionsBuilder.add(option);
         }
 
         Set<Expression> newOptions = newOptionsBuilder.build();
-        if (newOptions.size() == optionSize) {
+        if (newOptions.size() == inPredicate.getOptions().size()) {
             return inPredicate;
         }
         return new InPredicate(inPredicate.getCompareExpr(), newOptions);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java
new file mode 100644
index 00000000000..c869dafa0a2
--- /dev/null
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstant.java
@@ -0,0 +1,77 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.rules.expression.ExpressionPatternMatcher;
+import org.apache.doris.nereids.rules.expression.ExpressionPatternRuleFactory;
+import org.apache.doris.nereids.rules.expression.ExpressionRuleType;
+import org.apache.doris.nereids.trees.expressions.EqualTo;
+import org.apache.doris.nereids.trees.expressions.Expression;
+import org.apache.doris.nereids.trees.expressions.InPredicate;
+import org.apache.doris.nereids.util.ExpressionUtils;
+
+import com.google.common.collect.ImmutableList;
+import com.google.common.collect.Sets;
+import org.apache.hadoop.util.Lists;
+
+import java.util.List;
+import java.util.Set;
+import java.util.stream.Collectors;
+
+/**
+ * Extract the non-constant options of an InPredicate. For example:
+ * where k1 in (k2, k3, 10, 20, 30) ==> where k1 in (10, 20, 30) or k1 = k2 or 
k1 = k3.
+ * This is done because the backend handles an in predicate that contains 
non-constant expressions less efficiently.
+ */
+public class InPredicateExtractNonConstant implements 
ExpressionPatternRuleFactory {
+    public static final InPredicateExtractNonConstant INSTANCE = new 
InPredicateExtractNonConstant();
+
+    @Override
+    public List<ExpressionPatternMatcher<? extends Expression>> buildRules() {
+        return ImmutableList.of(
+                matchesType(InPredicate.class)
+                        .when(inPredicate -> inPredicate.getOptions().size()
+                                <= InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE)
+                        .then(this::rewrite)
+                        
.toRule(ExpressionRuleType.IN_PREDICATE_EXTRACT_NON_CONSTANT)
+        );
+    }
+
+    private Expression rewrite(InPredicate inPredicate) {
+        Set<Expression> nonConstants = 
Sets.newLinkedHashSetWithExpectedSize(inPredicate.arity());
+        for (Expression option : inPredicate.getOptions()) {
+            if (!option.isConstant()) {
+                nonConstants.add(option);
+            }
+        }
+        if (nonConstants.isEmpty()) {
+            return inPredicate;
+        }
+        Expression key = inPredicate.getCompareExpr();
+        List<Expression> disjunctions = 
Lists.newArrayListWithExpectedSize(inPredicate.getOptions().size());
+        List<Expression> constants = 
inPredicate.getOptions().stream().filter(Expression::isConstant)
+                .collect(Collectors.toList());
+        if (!constants.isEmpty()) {
+            disjunctions.add(ExpressionUtils.toInPredicateOrEqualTo(key, 
constants));
+        }
+        for (Expression option : nonConstants) {
+            disjunctions.add(new EqualTo(key, option));
+        }
+        return ExpressionUtils.or(disjunctions);
+    }
+}
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
index 136b40af584..16006c02080 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/OrToIn.java
@@ -31,11 +31,9 @@ import org.apache.doris.nereids.trees.expressions.Slot;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
-import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Sets;
 
-import java.util.ArrayList;
 import java.util.LinkedHashMap;
 import java.util.LinkedHashSet;
 import java.util.List;
@@ -69,8 +67,6 @@ public class OrToIn {
     public static final OrToIn EXTRACT_MODE_INSTANCE = new 
OrToIn(Mode.extractMode);
     public static final OrToIn REPLACE_MODE_INSTANCE = new 
OrToIn(Mode.replaceMode);
 
-    public static final int REWRITE_OR_TO_IN_PREDICATE_THRESHOLD = 2;
-
     private final Mode mode;
 
     public OrToIn(Mode mode) {
@@ -196,18 +192,9 @@ public class OrToIn {
     }
 
     private Expression candidatesToFinalResult(Map<Expression, Set<Literal>> 
candidates) {
-        List<Expression> conjuncts = new ArrayList<>();
-        for (Expression key : candidates.keySet()) {
-            Set<Literal> literals = candidates.get(key);
-            if (literals.size() < REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
-                for (Literal literal : literals) {
-                    conjuncts.add(new EqualTo(key, literal));
-                }
-            } else {
-                conjuncts.add(new InPredicate(key, 
ImmutableList.copyOf(literals)));
-            }
-        }
-        return ExpressionUtils.and(conjuncts);
+        return ExpressionUtils.and(candidates.entrySet().stream()
+                .map(entry -> 
ExpressionUtils.toInPredicateOrEqualTo(entry.getKey(), entry.getValue()))
+                .collect(Collectors.toList()));
     }
 
     /*
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
index c78ec7a75fb..7c23ce36a3d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
@@ -110,7 +110,8 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
     @Override
     public ValueDesc visitInPredicate(InPredicate inPredicate, 
ExpressionRewriteContext context) {
         // only handle `NumericType` and `DateLikeType`
-        if (ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions())
+        if (inPredicate.getOptions().size() <= 
InPredicateDedup.REWRITE_OPTIONS_MAX_SIZE
+                && 
ExpressionUtils.isAllNonNullLiteral(inPredicate.getOptions())
                 && (ExpressionUtils.matchNumericType(inPredicate.getOptions())
                 || 
ExpressionUtils.matchDateLikeType(inPredicate.getOptions()))) {
             return ValueDesc.discrete(context, inPredicate);
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
index 64891882f7d..6ac69f1eb56 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
@@ -27,14 +27,11 @@ import 
org.apache.doris.nereids.rules.expression.rules.RangeInference.RangeValue
 import 
org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue;
 import 
org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc;
 import org.apache.doris.nereids.trees.expressions.CompoundPredicate;
-import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.GreaterThan;
 import org.apache.doris.nereids.trees.expressions.GreaterThanEqual;
-import org.apache.doris.nereids.trees.expressions.InPredicate;
 import org.apache.doris.nereids.trees.expressions.LessThan;
 import org.apache.doris.nereids.trees.expressions.LessThanEqual;
-import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
@@ -45,9 +42,7 @@ import com.google.common.collect.Lists;
 import com.google.common.collect.Range;
 import org.apache.commons.lang3.NotImplementedException;
 
-import java.util.Iterator;
 import java.util.List;
-import java.util.Set;
 
 /**
  * This class implements the function to simplify expression range.
@@ -133,20 +128,7 @@ public class SimplifyRange implements 
ExpressionPatternRuleFactory {
     }
 
     private Expression getExpression(DiscreteValue value) {
-        Expression reference = value.getReference();
-        Set<Literal> values = value.getValues();
-        // NOTICE: it's related with `InPredicateToEqualToRule`
-        // They are same processes, so must change synchronously.
-        if (values.size() == 1) {
-            return new EqualTo(reference, values.iterator().next());
-
-            // this condition should as same as OrToIn, or else meet dead loop
-        } else if (values.size() < 
OrToIn.REWRITE_OR_TO_IN_PREDICATE_THRESHOLD) {
-            Iterator<Literal> iterator = values.iterator();
-            return new Or(new EqualTo(reference, iterator.next()), new 
EqualTo(reference, iterator.next()));
-        } else {
-            return new InPredicate(reference, Lists.newArrayList(values));
-        }
+        return ExpressionUtils.toInPredicateOrEqualTo(value.getReference(), 
value.getValues());
     }
 
     private Expression getExpression(UnknownValue value) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 723224409bf..3d8aef2c842 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -276,6 +276,14 @@ public class ExpressionUtils {
         }
     }
 
+    public static Expression toInPredicateOrEqualTo(Expression reference, 
Collection<? extends Expression> values) {
+        if (values.size() < 2) {
+            return or(values.stream().map(value -> new EqualTo(reference, 
value)).collect(Collectors.toList()));
+        } else {
+            return new InPredicate(reference, ImmutableList.copyOf(values));
+        }
+    }
+
     /**
      * Use AND/OR to combine expressions together.
      */
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
index 13f2789c0a9..34b40efcc29 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/ExpressionRewriteTest.java
@@ -21,6 +21,7 @@ import 
org.apache.doris.nereids.rules.expression.rules.AddMinMax;
 import org.apache.doris.nereids.rules.expression.rules.DistinctPredicatesRule;
 import org.apache.doris.nereids.rules.expression.rules.ExtractCommonFactorRule;
 import org.apache.doris.nereids.rules.expression.rules.InPredicateDedup;
+import 
org.apache.doris.nereids.rules.expression.rules.InPredicateExtractNonConstant;
 import 
org.apache.doris.nereids.rules.expression.rules.InPredicateToEqualToRule;
 import 
org.apache.doris.nereids.rules.expression.rules.NormalizeBinaryPredicatesRule;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyCastRule;
@@ -361,4 +362,18 @@ class ExpressionRewriteTest extends 
ExpressionRewriteTestHelper {
         assertRewriteAfterTypeCoercion("TA between 10 and 20 and TB between 10 
and 20 or TA between 30 and 40 and TB between 30 and 40 or TA between 60 and 50 
and TB between 60 and 50",
                 "(TA <= 20 and TB <= 20 or TA >= 30 and TB >= 30 or TA is null 
and null and TB is null) and TA >= 10 and TA <= 40 and TB >= 10 and TB <= 40");
     }
+
+    @Test
+    public void testInPredicateExtractNonConstant() {
+        executor = new ExpressionRuleExecutor(ImmutableList.of(
+                bottomUp(
+                        InPredicateExtractNonConstant.INSTANCE
+                )
+        ));
+
+        assertRewriteAfterTypeCoercion("TA in (3, 2, 1)", "TA in (3, 2, 1)");
+        assertRewriteAfterTypeCoercion("TA in (TB, TC, TB)", "TA = TB or TA = 
TC");
+        assertRewriteAfterTypeCoercion("TA in (3, 2, 1, TB, TC, TB)", "TA in 
(3, 2, 1) or TA = TB or TA = TC");
+        assertRewriteAfterTypeCoercion("IA in (1 + 2, 2 + 3, 3 + TB)", "IA in 
(cast(1 + 2 as int), cast(2 + 3 as int)) or IA = cast(3 + TB as int)");
+    }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java
new file mode 100644
index 00000000000..60b511f4475
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/rules/InPredicateExtractNonConstantTest.java
@@ -0,0 +1,47 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.nereids.rules.expression.rules;
+
+import org.apache.doris.nereids.sqltest.SqlTestBase;
+import org.apache.doris.nereids.util.PlanChecker;
+
+import org.junit.jupiter.api.Test;
+
+class InPredicateExtractNonConstantTest extends SqlTestBase {
+    @Test
+    public void testExtractNonConstant() {
+        
connectContext.getSessionVariable().setDisableNereidsRules("PRUNE_EMPTY_PARTITION");
+        String sql = "select * from T1 where id in (score, score, score + 
100)";
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(
+                        logicalFilter().when(f -> 
f.getPredicate().toString().equals(
+                                "OR[(id#0 = score#1),(id#0 = (score#1 + 100))]"
+                        )));
+
+        sql = "select * from T1 where id in (score,  score + 10, score + 
score, score, 10, 20, 30, 100 + 200)";
+        PlanChecker.from(connectContext)
+                .analyze(sql)
+                .rewrite()
+                .matches(
+                        logicalFilter().when(f -> 
f.getPredicate().toString().equals(
+                                "OR[id#0 IN (20, 10, 300, 30),(id#0 = 
score#1),(id#0 = (score#1 + 10)),(id#0 = (score#1 + score#1))]"
+                )));
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to