This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new bbf8a819de1 [opt](nereids) opt range inference for or expression when 
out of order (#46303)
bbf8a819de1 is described below

commit bbf8a819de19d36a31aa057e6f10c60c2466eb58
Author: yujun <yu...@selectdb.com>
AuthorDate: Thu Jan 9 12:06:45 2025 +0800

    [opt](nereids) opt range inference for or expression when out of order 
(#46303)
    
    ### What problem does this PR solve?
    
    Problem Summary:
    
    Range inference merges the value descs that share the same reference.
    It merges them two at a time, step by step, so a different merge order
    may produce a different result.
    
    For range Inference: `x1 op x2 op x3 op x4`
    
    If op is `AND`, then the merge order doesn't matter. It will always get
    the same result.
    
    But if op is `OR`, then the merge order does matter. For example: `(a <
    10) or (a > 30) or (a >= 5 and a <= 35)`. Merging the first two
    operands yields an UnknownValue whose sources are `[ (-00, 10), (30,
    +00) ]`; this UnknownValue is later merged with the RangeValue `[5,
    35]`. Since the union of an UnknownValue with another value desc is a
    new UnknownValue, the final result is
    UnknownValue(UnknownValue(RangeValue(`a < 10`) or RangeValue(`a > 30`))
    or RangeValue(`a >= 5 and a <= 35`)). This is bad. It should merge the
    1st and 3rd value descs first and then merge the 2nd value desc, so
    that the final merge result is 'TRUE'.
    
    To achieve this, all the ranges are recorded in a RangeSet, which
    automatically merges connected ranges regardless of insertion order.
    
    What's more, this PR also:
    1. optimizes 'a > 20 or a = 20' to 'a >= 20';
    2. for a discrete value's options, eliminates any option that is already
    covered by a range. For example: `a <= 10 or a in [1, 2, 3, 11, 12, 13]`
    is optimized to `a <= 10 or a in [11, 12, 13]`;
    3. deletes toExpr in RangeInference;
---
 .../rules/expression/rules/RangeInference.java     | 197 +++++++++++----------
 .../rules/expression/rules/SimplifyRange.java      |  23 +--
 .../rules/expression/SimplifyRangeTest.java        |  63 +++++--
 .../apache/doris/nereids/sqltest/InferTest.java    |   4 +-
 4 files changed, 171 insertions(+), 116 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
index 247856578c2..c78ec7a75fb 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java
@@ -34,15 +34,17 @@ import 
org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
 import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
+import com.google.common.collect.BoundType;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
 import com.google.common.collect.Multimap;
 import com.google.common.collect.Multimaps;
 import com.google.common.collect.Range;
+import com.google.common.collect.RangeSet;
 import com.google.common.collect.Sets;
+import com.google.common.collect.TreeRangeSet;
 
 import java.util.ArrayList;
-import java.util.Arrays;
 import java.util.Collection;
 import java.util.LinkedHashMap;
 import java.util.List;
@@ -118,18 +120,17 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
 
     @Override
     public ValueDesc visitAnd(And and, ExpressionRewriteContext context) {
-        return simplify(context, and, ExpressionUtils.extractConjunction(and),
+        return simplify(context, ExpressionUtils.extractConjunction(and),
                 ValueDesc::intersect, true);
     }
 
     @Override
     public ValueDesc visitOr(Or or, ExpressionRewriteContext context) {
-        return simplify(context, or, ExpressionUtils.extractDisjunction(or),
+        return simplify(context, ExpressionUtils.extractDisjunction(or),
                 ValueDesc::union, false);
     }
 
-    private ValueDesc simplify(ExpressionRewriteContext context,
-            Expression originExpr, List<Expression> predicates,
+    private ValueDesc simplify(ExpressionRewriteContext context, 
List<Expression> predicates,
             BinaryOperator<ValueDesc> op, boolean isAnd) {
 
         boolean convertIsNullToEmptyValue = isAnd && 
predicates.stream().anyMatch(expr -> expr instanceof NullLiteral);
@@ -144,7 +145,7 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
             // but we don't consider this case here, we should fold IsNull(a) 
to FALSE using other rule.
             ValueDesc valueDesc = null;
             if (convertIsNullToEmptyValue && predicate instanceof IsNull) {
-                valueDesc = new EmptyValue(context, ((IsNull) 
predicate).child(), predicate);
+                valueDesc = new EmptyValue(context, ((IsNull) 
predicate).child());
             } else {
                 valueDesc = predicate.accept(this, context);
             }
@@ -154,7 +155,11 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
 
         List<ValueDesc> valuePerRefs = Lists.newArrayList();
         for (Entry<Expression, Collection<ValueDesc>> referenceValues : 
groupByReference.asMap().entrySet()) {
+            Expression reference = referenceValues.getKey();
             List<ValueDesc> valuePerReference = (List) 
referenceValues.getValue();
+            if (!isAnd) {
+                valuePerReference = ValueDesc.unionDiscreteAndRange(context, 
reference, valuePerReference);
+            }
 
             // merge per reference
             ValueDesc simplifiedValue = valuePerReference.get(0);
@@ -170,7 +175,7 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
         }
 
         // use UnknownValue to wrap different references
-        return new UnknownValue(context, originExpr, valuePerRefs, isAnd);
+        return new UnknownValue(context, valuePerRefs, isAnd);
     }
 
     /**
@@ -178,12 +183,10 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
      */
     public abstract static class ValueDesc {
         ExpressionRewriteContext context;
-        Expression toExpr;
         Expression reference;
 
-        public ValueDesc(ExpressionRewriteContext context, Expression 
reference, Expression toExpr) {
+        public ValueDesc(ExpressionRewriteContext context, Expression 
reference) {
             this.context = context;
-            this.toExpr = toExpr;
             this.reference = reference;
         }
 
@@ -191,10 +194,6 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
             return reference;
         }
 
-        public Expression getOriginExpr() {
-            return toExpr;
-        }
-
         public ExpressionRewriteContext getExpressionRewriteContext() {
             return context;
         }
@@ -204,16 +203,62 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
         /** or */
         public static ValueDesc union(ExpressionRewriteContext context,
                 RangeValue range, DiscreteValue discrete, boolean 
reverseOrder) {
-            long count = discrete.values.stream().filter(x -> 
range.range.test(x)).count();
-            if (count == discrete.values.size()) {
+            if (discrete.values.stream().allMatch(x -> range.range.test(x))) {
                 return range;
             }
-            Expression toExpr = FoldConstantRuleOnFE.evaluate(
-                    new Or(range.toExpr, discrete.toExpr), context);
             List<ValueDesc> sourceValues = reverseOrder
                     ? ImmutableList.of(discrete, range)
                     : ImmutableList.of(range, discrete);
-            return new UnknownValue(context, toExpr, sourceValues, false);
+            return new UnknownValue(context, sourceValues, false);
+        }
+
+        /** merge discrete and ranges only, no merge other value desc */
+        public static List<ValueDesc> 
unionDiscreteAndRange(ExpressionRewriteContext context,
+                Expression reference, List<ValueDesc> valueDescs) {
+            Set<Literal> discreteValues = Sets.newHashSet();
+            for (ValueDesc valueDesc : valueDescs) {
+                if (valueDesc instanceof DiscreteValue) {
+                    discreteValues.addAll(((DiscreteValue) 
valueDesc).getValues());
+                }
+            }
+
+            // for 'a > 8 or a = 8', then range (8, +00) can convert to [8, 
+00)
+            RangeSet<Literal> rangeSet = TreeRangeSet.create();
+            for (ValueDesc valueDesc : valueDescs) {
+                if (valueDesc instanceof RangeValue) {
+                    Range<Literal> range = ((RangeValue) valueDesc).range;
+                    rangeSet.add(range);
+                    if (range.hasLowerBound()
+                            && range.lowerBoundType() == BoundType.OPEN
+                            && discreteValues.contains(range.lowerEndpoint())) 
{
+                        rangeSet.add(Range.singleton(range.lowerEndpoint()));
+                    }
+                    if (range.hasUpperBound()
+                            && range.upperBoundType() == BoundType.OPEN
+                            && discreteValues.contains(range.upperEndpoint())) 
{
+                        rangeSet.add(Range.singleton(range.upperEndpoint()));
+                    }
+                }
+            }
+
+            if (!rangeSet.isEmpty()) {
+                discreteValues.removeIf(x -> rangeSet.contains(x));
+            }
+
+            List<ValueDesc> result = 
Lists.newArrayListWithExpectedSize(valueDescs.size());
+            if (!discreteValues.isEmpty()) {
+                result.add(new DiscreteValue(context, reference, 
discreteValues));
+            }
+            for (Range<Literal> range : rangeSet.asRanges()) {
+                result.add(new RangeValue(context, reference, range));
+            }
+            for (ValueDesc valueDesc : valueDescs) {
+                if (!(valueDesc instanceof DiscreteValue) && !(valueDesc 
instanceof RangeValue)) {
+                    result.add(valueDesc);
+                }
+            }
+
+            return result;
         }
 
         /** intersect */
@@ -221,19 +266,19 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
 
         /** intersect */
         public static ValueDesc intersect(ExpressionRewriteContext context, 
RangeValue range, DiscreteValue discrete) {
-            DiscreteValue result = new DiscreteValue(context, 
discrete.reference, discrete.toExpr);
-            discrete.values.stream().filter(x -> 
range.range.contains(x)).forEach(result.values::add);
-            if (!result.values.isEmpty()) {
-                return result;
+            Set<Literal> newValues = discrete.values.stream().filter(x -> 
range.range.contains(x))
+                    .collect(Collectors.toSet());
+            if (newValues.isEmpty()) {
+                return new EmptyValue(context, range.reference);
+            } else {
+                return new DiscreteValue(context, range.reference, newValues);
             }
-            Expression originExpr = FoldConstantRuleOnFE.evaluate(new 
And(range.toExpr, discrete.toExpr), context);
-            return new EmptyValue(context, range.reference, originExpr);
         }
 
         private static ValueDesc range(ExpressionRewriteContext context, 
ComparisonPredicate predicate) {
             Literal value = (Literal) predicate.right();
             if (predicate instanceof EqualTo) {
-                return new DiscreteValue(context, predicate.left(), predicate, 
value);
+                return new DiscreteValue(context, predicate.left(), 
Sets.newHashSet(value));
             }
             Range<Literal> range = null;
             if (predicate instanceof GreaterThanEqual) {
@@ -246,13 +291,13 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
                 range = Range.lessThan(value);
             }
 
-            return new RangeValue(context, predicate.left(), predicate, range);
+            return new RangeValue(context, predicate.left(), range);
         }
 
         public static ValueDesc discrete(ExpressionRewriteContext context, 
InPredicate in) {
             // Set<Literal> literals = (Set) 
Utils.fastToImmutableSet(in.getOptions());
             Set<Literal> literals = 
in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet());
-            return new DiscreteValue(context, in.getCompareExpr(), in, 
literals);
+            return new DiscreteValue(context, in.getCompareExpr(), literals);
         }
     }
 
@@ -261,8 +306,8 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
      */
     public static class EmptyValue extends ValueDesc {
 
-        public EmptyValue(ExpressionRewriteContext context, Expression 
reference, Expression toExpr) {
-            super(context, reference, toExpr);
+        public EmptyValue(ExpressionRewriteContext context, Expression 
reference) {
+            super(context, reference);
         }
 
         @Override
@@ -284,9 +329,8 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
     public static class RangeValue extends ValueDesc {
         Range<Literal> range;
 
-        public RangeValue(ExpressionRewriteContext context, Expression 
reference,
-                Expression toExpr, Range<Literal> range) {
-            super(context, reference, toExpr);
+        public RangeValue(ExpressionRewriteContext context, Expression 
reference, Range<Literal> range) {
+            super(context, reference);
             this.range = range;
         }
 
@@ -300,20 +344,16 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
                 return other.union(this);
             }
             if (other instanceof RangeValue) {
-                Expression originExpr = FoldConstantRuleOnFE.evaluate(new 
Or(toExpr, other.toExpr), context);
                 RangeValue o = (RangeValue) other;
                 if (range.isConnected(o.range)) {
-                    return new RangeValue(context, reference, originExpr, 
range.span(o.range));
+                    return new RangeValue(context, reference, 
range.span(o.range));
                 }
-                return new UnknownValue(context, originExpr,
-                        ImmutableList.of(this, other), false);
+                return new UnknownValue(context, ImmutableList.of(this, 
other), false);
             }
             if (other instanceof DiscreteValue) {
                 return union(context, this, (DiscreteValue) other, false);
             }
-            Expression originExpr = FoldConstantRuleOnFE.evaluate(new 
Or(toExpr, other.toExpr), context);
-            return new UnknownValue(context, originExpr,
-                    ImmutableList.of(this, other), false);
+            return new UnknownValue(context, ImmutableList.of(this, other), 
false);
         }
 
         @Override
@@ -322,19 +362,16 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
                 return other.intersect(this);
             }
             if (other instanceof RangeValue) {
-                Expression originExpr = FoldConstantRuleOnFE.evaluate(new 
And(toExpr, other.toExpr), context);
                 RangeValue o = (RangeValue) other;
                 if (range.isConnected(o.range)) {
-                    return new RangeValue(context, reference, originExpr, 
range.intersection(o.range));
+                    return new RangeValue(context, reference, 
range.intersection(o.range));
                 }
-                return new EmptyValue(context, reference, originExpr);
+                return new EmptyValue(context, reference);
             }
             if (other instanceof DiscreteValue) {
                 return intersect(context, this, (DiscreteValue) other);
             }
-            Expression originExpr = FoldConstantRuleOnFE.evaluate(new 
And(toExpr, other.toExpr), context);
-            return new UnknownValue(context, originExpr,
-                    ImmutableList.of(this, other), true);
+            return new UnknownValue(context, ImmutableList.of(this, other), 
true);
         }
 
         @Override
@@ -349,17 +386,12 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
      * a in (1,2,3) => [1,2,3]
      */
     public static class DiscreteValue extends ValueDesc {
-        Set<Literal> values;
+        final Set<Literal> values;
 
         public DiscreteValue(ExpressionRewriteContext context,
-                Expression reference, Expression toExpr, Literal... values) {
-            this(context, reference, toExpr, Arrays.asList(values));
-        }
-
-        public DiscreteValue(ExpressionRewriteContext context,
-                Expression reference, Expression toExpr, Collection<Literal> 
values) {
-            super(context, reference, toExpr);
-            this.values = Sets.newHashSet(values);
+                Expression reference, Set<Literal> values) {
+            super(context, reference);
+            this.values = values;
         }
 
         public Set<Literal> getValues() {
@@ -372,20 +404,15 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
                 return other.union(this);
             }
             if (other instanceof DiscreteValue) {
-                Expression originExpr = FoldConstantRuleOnFE.evaluate(
-                        ExpressionUtils.or(toExpr, other.toExpr), context);
-                DiscreteValue discreteValue = new DiscreteValue(context, 
reference, originExpr);
-                discreteValue.values.addAll(((DiscreteValue) other).values);
-                discreteValue.values.addAll(this.values);
-                return discreteValue;
+                Set<Literal> newValues = Sets.newHashSet();
+                newValues.addAll(((DiscreteValue) other).values);
+                newValues.addAll(this.values);
+                return new DiscreteValue(context, reference, newValues);
             }
             if (other instanceof RangeValue) {
                 return union(context, (RangeValue) other, this, true);
             }
-            Expression originExpr = FoldConstantRuleOnFE.evaluate(
-                    ExpressionUtils.or(toExpr, other.toExpr), context);
-            return new UnknownValue(context, originExpr,
-                    ImmutableList.of(this, other), false);
+            return new UnknownValue(context, ImmutableList.of(this, other), 
false);
         }
 
         @Override
@@ -394,24 +421,19 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
                 return other.intersect(this);
             }
             if (other instanceof DiscreteValue) {
-                Expression originExpr = FoldConstantRuleOnFE.evaluate(
-                        ExpressionUtils.and(toExpr, other.toExpr), context);
-                DiscreteValue discreteValue = new DiscreteValue(context, 
reference, originExpr);
-                discreteValue.values.addAll(((DiscreteValue) other).values);
-                discreteValue.values.retainAll(this.values);
-                if (discreteValue.values.isEmpty()) {
-                    return new EmptyValue(context, reference, originExpr);
+                Set<Literal> newValues = Sets.newHashSet();
+                newValues.addAll(((DiscreteValue) other).values);
+                newValues.retainAll(this.values);
+                if (newValues.isEmpty()) {
+                    return new EmptyValue(context, reference);
                 } else {
-                    return discreteValue;
+                    return new DiscreteValue(context, reference, newValues);
                 }
             }
             if (other instanceof RangeValue) {
                 return intersect(context, (RangeValue) other, this);
             }
-            Expression originExpr = FoldConstantRuleOnFE.evaluate(
-                    ExpressionUtils.and(toExpr, other.toExpr), context);
-            return new UnknownValue(context, originExpr,
-                    ImmutableList.of(this, other), true);
+            return new UnknownValue(context, ImmutableList.of(this, other), 
true);
         }
 
         @Override
@@ -428,14 +450,14 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
         private final boolean isAnd;
 
         private UnknownValue(ExpressionRewriteContext context, Expression 
expr) {
-            super(context, expr, expr);
+            super(context, expr);
             sourceValues = ImmutableList.of();
             isAnd = false;
         }
 
-        public UnknownValue(ExpressionRewriteContext context, Expression 
toExpr,
+        private UnknownValue(ExpressionRewriteContext context,
                 List<ValueDesc> sourceValues, boolean isAnd) {
-            super(context, getReference(sourceValues, toExpr), toExpr);
+            super(context, getReference(context, sourceValues, isAnd));
             this.sourceValues = ImmutableList.copyOf(sourceValues);
             this.isAnd = isAnd;
         }
@@ -455,11 +477,12 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
         //    E union UnknownValue1 = E.union(UnknownValue1) = UnknownValue1,
         // 2. since E and UnknownValue2's reference not equals, then
         //    E union UnknownValue2 = UnknownValue3(E union UnknownValue2, 
reference=E union UnknownValue2)
-        private static Expression getReference(List<ValueDesc> sourceValues, 
Expression toExpr) {
+        private static Expression getReference(ExpressionRewriteContext 
context,
+                List<ValueDesc> sourceValues, boolean isAnd) {
             Expression reference = sourceValues.get(0).reference;
             for (int i = 1; i < sourceValues.size(); i++) {
                 if (!reference.equals(sourceValues.get(i).reference)) {
-                    return toExpr;
+                    return SimplifyRange.INSTANCE.getExpression(context, 
sourceValues, isAnd);
                 }
             }
             return reference;
@@ -480,10 +503,7 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
             if (other instanceof EmptyValue) {
                 return other.union(this);
             }
-            Expression originExpr = FoldConstantRuleOnFE.evaluate(
-                    ExpressionUtils.or(toExpr, other.toExpr), context);
-            return new UnknownValue(context, originExpr,
-                    ImmutableList.of(this, other), false);
+            return new UnknownValue(context, ImmutableList.of(this, other), 
false);
         }
 
         @Override
@@ -493,10 +513,7 @@ public class RangeInference extends 
ExpressionVisitor<RangeInference.ValueDesc,
             if (other instanceof EmptyValue) {
                 return other.intersect(this);
             }
-            Expression originExpr = FoldConstantRuleOnFE.evaluate(
-                    ExpressionUtils.and(toExpr, other.toExpr), context);
-            return new UnknownValue(context, originExpr,
-                    ImmutableList.of(this, other), true);
+            return new UnknownValue(context, ImmutableList.of(this, other), 
true);
         }
     }
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
index 576ef6bbf4d..64891882f7d 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java
@@ -38,6 +38,7 @@ import org.apache.doris.nereids.trees.expressions.Or;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.util.ExpressionUtils;
 
+import com.google.common.base.Preconditions;
 import com.google.common.collect.BoundType;
 import com.google.common.collect.ImmutableList;
 import com.google.common.collect.Lists;
@@ -150,26 +151,28 @@ public class SimplifyRange implements 
ExpressionPatternRuleFactory {
 
     private Expression getExpression(UnknownValue value) {
         List<ValueDesc> sourceValues = value.getSourceValues();
-        Expression originExpr = value.getOriginExpr();
         if (sourceValues.isEmpty()) {
-            return originExpr;
+            return value.getReference();
+        } else {
+            return getExpression(value.getExpressionRewriteContext(), 
sourceValues, value.isAnd());
         }
+    }
+
+    /** getExpression */
+    public Expression getExpression(ExpressionRewriteContext context,
+            List<ValueDesc> sourceValues, boolean isAnd) {
+        Preconditions.checkArgument(!sourceValues.isEmpty());
         List<Expression> sourceExprs = 
Lists.newArrayListWithExpectedSize(sourceValues.size());
         for (ValueDesc sourceValue : sourceValues) {
             Expression expr = getExpression(sourceValue);
-            if (value.isAnd()) {
+            if (isAnd) {
                 sourceExprs.addAll(ExpressionUtils.extractConjunction(expr));
             } else {
                 sourceExprs.addAll(ExpressionUtils.extractDisjunction(expr));
             }
         }
-        Expression result = value.isAnd() ? ExpressionUtils.and(sourceExprs) : 
ExpressionUtils.or(sourceExprs);
-        result = FoldConstantRuleOnFE.evaluate(result, 
value.getExpressionRewriteContext());
-        // ATTN: we must return original expr, because OrToIn is implemented 
with MutableState,
-        //   newExpr will lose these states leading to dead loop by OrToIn -> 
SimplifyRange -> FoldConstantByFE
-        if (result.equals(originExpr)) {
-            return originExpr;
-        }
+        Expression result = isAnd ? ExpressionUtils.and(sourceExprs) : 
ExpressionUtils.or(sourceExprs);
+        result = FoldConstantRuleOnFE.evaluate(result, context);
         return result;
     }
 }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
index 784600577c3..7393439c5e6 100644
--- 
a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java
@@ -22,6 +22,10 @@ import org.apache.doris.nereids.analyzer.UnboundRelation;
 import org.apache.doris.nereids.analyzer.UnboundSlot;
 import org.apache.doris.nereids.parser.NereidsParser;
 import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer;
+import org.apache.doris.nereids.rules.expression.rules.RangeInference;
+import 
org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue;
+import 
org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue;
+import 
org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc;
 import org.apache.doris.nereids.rules.expression.rules.SimplifyRange;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.InPredicate;
@@ -60,6 +64,24 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         context = new ExpressionRewriteContext(cascadesContext);
     }
 
+    @Test
+    public void testRangeInference() {
+        ValueDesc valueDesc = getValueDesc("TA IS NULL");
+        Assertions.assertInstanceOf(UnknownValue.class, valueDesc);
+        List<ValueDesc> sourceValues = ((UnknownValue) 
valueDesc).getSourceValues();
+        Assertions.assertEquals(0, sourceValues.size());
+        Assertions.assertEquals("TA IS NULL", 
valueDesc.getReference().toSql());
+
+        valueDesc = getValueDesc("TA IS NULL AND TB IS NULL AND NULL");
+        Assertions.assertInstanceOf(UnknownValue.class, valueDesc);
+        sourceValues = ((UnknownValue) valueDesc).getSourceValues();
+        Assertions.assertEquals(3, sourceValues.size());
+        Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(0));
+        Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(1));
+        Assertions.assertEquals("TA", 
sourceValues.get(0).getReference().toSql());
+        Assertions.assertEquals("TB", 
sourceValues.get(1).getReference().toSql());
+    }
+
     @Test
     public void testSimplify() {
         executor = new ExpressionRuleExecutor(ImmutableList.of(
@@ -69,8 +91,15 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("TA > 3 or TA > null", "TA > 3 OR NULL");
         assertRewrite("TA > 3 or TA < null", "TA > 3 OR NULL");
         assertRewrite("TA > 3 or TA = null", "TA > 3 OR NULL");
+        assertRewrite("TA > 3 or TA = 3 or TA < null", "TA >= 3 OR NULL");
+        assertRewrite("TA < 10 or TA in (1, 2, 3, 11, 12, 13)", "TA in (11, 
12, 13) OR TA < 10");
+        assertRewrite("TA < 10 or TA in (1, 2, 3, 10, 11, 12, 13) or TA > 13 
or TA < 10 or TA in (1, 2, 3, 10, 11, 12, 13) or TA > 13",
+                "TA in (11, 12) OR TA <= 10 OR TA >= 13");
         assertRewrite("TA > 3 or TA <> null", "TA > 3 or null");
         assertRewrite("TA > 3 or TA <=> null", "TA > 3 or TA <=> null");
+        assertRewrite("(TA < 1 or TA > 2) or (TA >= 0 and TA <= 3)", "TA IS 
NOT NULL OR NULL");
+        assertRewrite("TA between 10 and 20 or TA between 100 and 120 or TA 
between 15 and 25 or TA between 115 and 125",
+                "TA between 10 and 25 or TA between 100 and 125");
         assertRewriteNotNull("TA > 3 and TA > null", "TA > 3 and NULL");
         assertRewriteNotNull("TA > 3 and TA < null", "TA > 3 and NULL");
         assertRewriteNotNull("TA > 3 and TA = null", "TA > 3 and NULL");
@@ -88,13 +117,13 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("TA >= 3 and TA < 3", "TA >= 3 and TA < 3");
         assertRewriteNotNull("TA = 1 and TA > 10", "FALSE");
         assertRewrite("TA = 1 and TA > 10", "TA is null and null");
-        assertRewrite("TA > 5 or TA < 1", "TA > 5 or TA < 1");
+        assertRewrite("TA > 5 or TA < 1", "TA < 1 or TA > 5");
         assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1");
         assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA is not null or null");
         assertRewriteNotNull("TA > 5 or TA > 1 or TA < 10", "TRUE");
         assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10");
         assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10");
-        assertRewrite("TA > 1 or TA < 1", "TA > 1 or TA < 1");
+        assertRewrite("TA > 1 or TA < 1", "TA < 1 or TA > 1");
         assertRewrite("TA > 1 or TA < 10", "TA is not null or null");
         assertRewriteNotNull("TA > 1 or TA < 10", "TRUE");
         assertRewrite("TA > 5 and TA < 10", "TA > 5 and TA < 10");
@@ -109,7 +138,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB > 20)", "TA > 
10 and TB > 20");
         assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB 
> 20)", "TB > 30 and TA > 40");
         assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA > 
10 and TB > 10 or TB > 20");
-        assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 
20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
+        assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 
20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB < 10 or TB > 20))");
         assertRewriteNotNull("TA in (1,2,3) and TA > 10", "FALSE");
         assertRewrite("TA in (1,2,3) and TA > 10", "TA is null and null");
         assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)");
@@ -119,7 +148,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)");
         assertRewriteNotNull("TA in (1,2,3) and TA < 1", "FALSE");
         assertRewrite("TA in (1,2,3) and TA < 1", "TA is null and null");
-        assertRewrite("TA in (1,2,3) or TA < 1", "TA in (1,2,3) or TA < 1");
+        assertRewrite("TA in (1,2,3) or TA < 1", "TA in (2,3) or TA <= 1");
         assertRewrite("TA in (1,2,3) or TA in (2,3,4)", "TA in (1,2,3,4)");
         assertRewrite("TA in (1,2,3) or TA in (4,5,6)", "TA in (1,2,3,4,5,6)");
         assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "TA is null and null");
@@ -150,12 +179,12 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("TA + TC >= 3 and TA + TC < 3", "TA + TC >= 3 and TA + TC < 3");
         assertRewriteNotNull("TA + TC = 1 and TA + TC > 10", "FALSE");
         assertRewrite("TA + TC = 1 and TA + TC > 10", "(TA + TC) is null and null");
-        assertRewrite("TA + TC > 5 or TA + TC < 1", "TA + TC > 5 or TA + TC < 1");
+        assertRewrite("TA + TC > 5 or TA + TC < 1", "TA + TC < 1 or TA + TC > 5");
         assertRewrite("TA + TC > 5 or TA + TC > 1 or TA + TC > 10", "TA + TC > 1");
         assertRewrite("TA + TC > 5 or TA + TC > 1 or TA + TC < 10", "(TA + TC) is not null or null");
         assertRewrite("TA + TC > 5 and TA + TC > 1 and TA + TC > 10", "TA + TC > 10");
         assertRewrite("TA + TC > 5 and TA + TC > 1 and TA + TC < 10", "TA + TC > 5 and TA + TC < 10");
-        assertRewrite("TA + TC > 1 or TA + TC < 1", "TA + TC > 1 or TA + TC < 1");
+        assertRewrite("TA + TC > 1 or TA + TC < 1", "TA + TC < 1 or TA + TC > 1");
         assertRewrite("TA + TC > 1 or TA + TC < 10", "(TA + TC) is not null or null");
         assertRewrite("TA + TC > 5 and TA + TC < 10", "TA + TC > 5 and TA + TC < 10");
         assertRewrite("TA + TC > 5 and TA + TC > 10", "TA + TC > 10");
@@ -168,7 +197,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("(TA + TC > 10 or TA + TC > 20) and (TB > 10 and TB > 20)", "TA + TC > 10 and TB > 20");
         assertRewrite("((TB > 30 and TA + TC > 40) and TA + TC > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA + TC > 40");
         assertRewrite("(TA + TC > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA + TC > 10 and TB > 10 or TB > 20");
-        assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))");
+        assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB < 10 or TB > 20))");
         assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC > 10", "FALSE");
         assertRewrite("TA + TC in (1,2,3) and TA + TC > 10", "(TA + TC) is null and null");
         assertRewrite("TA + TC in (1,2,3) and TA + TC >= 1", "TA + TC in (1,2,3)");
@@ -178,7 +207,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("TA + TC in (1,2,3) and TA + TC < 10", "TA + TC in (1,2,3)");
         assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC < 1", "FALSE");
         assertRewrite("TA + TC in (1,2,3) and TA + TC < 1", "(TA + TC) is null and null");
-        assertRewrite("TA + TC in (1,2,3) or TA + TC < 1", "TA + TC in (1,2,3) or TA + TC < 1");
+        assertRewrite("TA + TC in (1,2,3) or TA + TC < 1", "TA + TC in (2,3) or TA + TC <= 1");
         assertRewrite("TA + TC in (1,2,3) or TA + TC in (2,3,4)", "TA + TC in (1,2,3,4)");
         assertRewrite("TA + TC in (1,2,3) or TA + TC in (4,5,6)", "TA + TC in (1,2,3,4,5,6)");
         assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC in (4,5,6)", "FALSE");
@@ -221,7 +250,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewriteNotNull("AA = date '2024-01-01' and AA > date '2024-01-10'", "FALSE");
         assertRewrite("AA = date '2024-01-01' and AA > date '2024-01-10'", "AA is null and null");
         assertRewrite("AA > date '2024-01-05' or AA < date '2024-01-01'",
-                "AA > date '2024-01-05' or AA < date '2024-01-01'");
+                "AA < date '2024-01-01' or AA > date '2024-01-05'");
         assertRewrite("AA > date '2024-01-05' or AA > date '2024-01-01' or AA > date '2024-01-10'",
                 "AA > date '2024-01-01'");
         assertRewrite("AA > date '2024-01-05' or AA > date '2024-01-01' or AA < date '2024-01-10'", "AA is not null or null");
@@ -231,7 +260,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("AA > date '2024-01-05' and AA > date '2024-01-01' and AA < date '2024-01-10'",
                 "AA > date '2024-01-05' and AA < date '2024-01-10'");
         assertRewrite("AA > date '2024-01-05' or AA < date '2024-01-05'",
-                "AA > date '2024-01-05' or AA < date '2024-01-05'");
+                "AA < date '2024-01-05' or AA > date '2024-01-05'");
         assertRewrite("AA > date '2024-01-01' or AA < date '2024-01-10'", "AA is not null or null");
         assertRewriteNotNull("AA > date '2024-01-01' or AA < date '2024-01-10'", "TRUE");
         assertRewrite("AA > date '2024-01-05' and AA < date '2024-01-10'",
@@ -261,7 +290,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') and AA < date '2024-01-01'",
                 "AA is null and null");
         assertRewrite("AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') or AA < date '2024-01-01'",
-                "AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') or AA < date '2024-01-01'");
+                "AA in (date '2024-01-02',date '2024-01-03') or AA <= date '2024-01-01'");
         assertRewrite("AA in (date '2024-01-01',date '2024-01-02') or AA in (date '2024-01-02', date '2024-01-03')",
                 "AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03')");
         assertRewriteNotNull("AA in (date '2024-01-01',date '2024-01-02') and AA in (date '2024-01-03', date '2024-01-04')",
@@ -301,7 +330,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewriteNotNull("CA = timestamp '2024-01-01 10:00:10' and CA > timestamp '2024-01-10 00:00:10'", "FALSE");
         assertRewrite("CA = timestamp '2024-01-01 10:00:10' and CA > timestamp '2024-01-10 00:00:10'", "CA is null and null");
         assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-01 00:00:10'",
-                "CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-01 00:00:10'");
+                "CA < timestamp '2024-01-01 00:00:10' or CA > timestamp '2024-01-05 00:00:10'");
         assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-01 00:00:10' or CA > timestamp '2024-01-10 00:00:10'",
                 "CA > timestamp '2024-01-01 00:00:10'");
         assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-01 00:00:10' or CA < timestamp '2024-01-10 00:00:10'", "CA is not null or null");
@@ -311,7 +340,7 @@ public class SimplifyRangeTest extends ExpressionRewrite {
         assertRewrite("CA > timestamp '2024-01-05 00:00:10' and CA > timestamp '2024-01-01 00:00:10' and CA < timestamp '2024-01-10 00:00:10'",
                 "CA > timestamp '2024-01-05 00:00:10' and CA < timestamp '2024-01-10 00:00:10'");
         assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-05 00:00:10'",
-                "CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-05 00:00:10'");
+                "CA < timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-05 00:00:10'");
         assertRewrite("CA > timestamp '2024-01-01 00:02:10' or CA < timestamp '2024-01-10 00:02:10'", "CA is not null or null");
         assertRewriteNotNull("CA > timestamp '2024-01-01 00:00:00' or CA < timestamp '2024-01-10 00:00:00'", "TRUE");
         assertRewrite("CA > timestamp '2024-01-05 01:00:00' and CA < timestamp '2024-01-10 01:00:00'",
@@ -364,7 +393,13 @@ public class SimplifyRangeTest extends ExpressionRewrite {
                 "(CA is null and null) OR CB < timestamp '2024-01-05 00:50:00'");
     }
 
-    @Test
+    private ValueDesc getValueDesc(String expression) {
+        Map<String, Slot> mem = Maps.newHashMap();
+        Expression parseExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem);
+        parseExpression = typeCoercion(parseExpression);
+        return (new RangeInference()).getValue(parseExpression, context);
+    }
+
     private void assertRewrite(String expression, String expected) {
         Map<String, Slot> mem = Maps.newHashMap();
         Expression needRewriteExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem);
diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java
index 3d88c131c97..cdc36164ae9 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java
@@ -58,7 +58,7 @@ public class InferTest extends SqlTestBase {
                                 f -> ExpressionUtils.and(f.getConjuncts().stream()
                                         .sorted((a, b) -> a.toString().compareTo(b.toString()))
                                         .collect(Collectors.toList()))
-                                        .toString().equals("AND[(id#0 >= 4),OR[(id#0 = 4),(id#0 > 4)]]"))
+                                        .toString().equals("(id#0 >= 4)"))
                     )
 
                 );
@@ -76,7 +76,7 @@ public class InferTest extends SqlTestBase {
                         logicalFilter(
                             leftOuterLogicalJoin(
                                 logicalFilter().when(
-                                        f -> f.getPredicate().toString().equals("AND[(id#0 >= 4),OR[(id#0 = 4),(id#0 > 4)]]")),
+                                        f -> f.getPredicate().toString().equals("(id#0 >= 4)")),
                                 logicalFilter().when(
                                         f -> f.getPredicate().toString().equals("(id#2 >= 4)")
                                 )


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org


Reply via email to