This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new bbf8a819de1 [opt](nereids) opt range inference for or expression when out of order (#46303) bbf8a819de1 is described below commit bbf8a819de19d36a31aa057e6f10c60c2466eb58 Author: yujun <yu...@selectdb.com> AuthorDate: Thu Jan 9 12:06:45 2025 +0800 [opt](nereids) opt range inference for or expression when out of order (#46303) ### What problem does this PR solve? Problem Summary: For range inference, it will merge multiple value desc whose references are the same. It will merge two value desc step by step. Different merge orders may produce different results. For range inference: `x1 op x2 op x3 op x4` If op is `AND`, then the merge order doesn't matter. It will always get the same result. But if op is `OR`, then the merge order does matter. For example: `(a < 10) or ( a > 30) or (a >= 15 and a <= 35)`. When merging at the first OP, it will get an UnknownValue whose source is `[ (-00, 10), (30, +00) ]`; later it will merge this UnknownValue with RangeValue `[15, 35]`. Since an UnknownValue unioned with another value desc yields a new UnknownValue, the final result is UnknownValue(UnknownValue(RangeValue(`a<10`) or RangeValue(`a>30`)) or RangeValue(`a>=15 and a <= 35`)). This is bad. It should merge the 1st and 3rd value desc first, later merge the 2nd value desc, and then the final merge result is 'TRUE'. In order to achieve this, use a RangeSet to record all the ranges, then RangeSet will auto merge the results. What's more, this PR also: 1. opt 'a > 20 or a = 20' to 'a >= 20'; 2. for the discrete value's options, if an option is in one range, then the option will be eliminated. For example: `a <= 10 or a in [1, 2, 3, 11, 12, 13]` will opt to `a <= 10 or a in [11, 12, 13]`; 3. 
delete toExpr in RangeInference; --- .../rules/expression/rules/RangeInference.java | 197 +++++++++++---------- .../rules/expression/rules/SimplifyRange.java | 23 +-- .../rules/expression/SimplifyRangeTest.java | 63 +++++-- .../apache/doris/nereids/sqltest/InferTest.java | 4 +- 4 files changed, 171 insertions(+), 116 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java index 247856578c2..c78ec7a75fb 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/RangeInference.java @@ -34,15 +34,17 @@ import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.ExpressionVisitor; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; import com.google.common.collect.Multimap; import com.google.common.collect.Multimaps; import com.google.common.collect.Range; +import com.google.common.collect.RangeSet; import com.google.common.collect.Sets; +import com.google.common.collect.TreeRangeSet; import java.util.ArrayList; -import java.util.Arrays; import java.util.Collection; import java.util.LinkedHashMap; import java.util.List; @@ -118,18 +120,17 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, @Override public ValueDesc visitAnd(And and, ExpressionRewriteContext context) { - return simplify(context, and, ExpressionUtils.extractConjunction(and), + return simplify(context, ExpressionUtils.extractConjunction(and), ValueDesc::intersect, true); } @Override public ValueDesc visitOr(Or or, ExpressionRewriteContext context) { - return simplify(context, or, 
ExpressionUtils.extractDisjunction(or), + return simplify(context, ExpressionUtils.extractDisjunction(or), ValueDesc::union, false); } - private ValueDesc simplify(ExpressionRewriteContext context, - Expression originExpr, List<Expression> predicates, + private ValueDesc simplify(ExpressionRewriteContext context, List<Expression> predicates, BinaryOperator<ValueDesc> op, boolean isAnd) { boolean convertIsNullToEmptyValue = isAnd && predicates.stream().anyMatch(expr -> expr instanceof NullLiteral); @@ -144,7 +145,7 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, // but we don't consider this case here, we should fold IsNull(a) to FALSE using other rule. ValueDesc valueDesc = null; if (convertIsNullToEmptyValue && predicate instanceof IsNull) { - valueDesc = new EmptyValue(context, ((IsNull) predicate).child(), predicate); + valueDesc = new EmptyValue(context, ((IsNull) predicate).child()); } else { valueDesc = predicate.accept(this, context); } @@ -154,7 +155,11 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, List<ValueDesc> valuePerRefs = Lists.newArrayList(); for (Entry<Expression, Collection<ValueDesc>> referenceValues : groupByReference.asMap().entrySet()) { + Expression reference = referenceValues.getKey(); List<ValueDesc> valuePerReference = (List) referenceValues.getValue(); + if (!isAnd) { + valuePerReference = ValueDesc.unionDiscreteAndRange(context, reference, valuePerReference); + } // merge per reference ValueDesc simplifiedValue = valuePerReference.get(0); @@ -170,7 +175,7 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, } // use UnknownValue to wrap different references - return new UnknownValue(context, originExpr, valuePerRefs, isAnd); + return new UnknownValue(context, valuePerRefs, isAnd); } /** @@ -178,12 +183,10 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, */ public abstract static class ValueDesc { 
ExpressionRewriteContext context; - Expression toExpr; Expression reference; - public ValueDesc(ExpressionRewriteContext context, Expression reference, Expression toExpr) { + public ValueDesc(ExpressionRewriteContext context, Expression reference) { this.context = context; - this.toExpr = toExpr; this.reference = reference; } @@ -191,10 +194,6 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, return reference; } - public Expression getOriginExpr() { - return toExpr; - } - public ExpressionRewriteContext getExpressionRewriteContext() { return context; } @@ -204,16 +203,62 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, /** or */ public static ValueDesc union(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete, boolean reverseOrder) { - long count = discrete.values.stream().filter(x -> range.range.test(x)).count(); - if (count == discrete.values.size()) { + if (discrete.values.stream().allMatch(x -> range.range.test(x))) { return range; } - Expression toExpr = FoldConstantRuleOnFE.evaluate( - new Or(range.toExpr, discrete.toExpr), context); List<ValueDesc> sourceValues = reverseOrder ? 
ImmutableList.of(discrete, range) : ImmutableList.of(range, discrete); - return new UnknownValue(context, toExpr, sourceValues, false); + return new UnknownValue(context, sourceValues, false); + } + + /** merge discrete and ranges only, no merge other value desc */ + public static List<ValueDesc> unionDiscreteAndRange(ExpressionRewriteContext context, + Expression reference, List<ValueDesc> valueDescs) { + Set<Literal> discreteValues = Sets.newHashSet(); + for (ValueDesc valueDesc : valueDescs) { + if (valueDesc instanceof DiscreteValue) { + discreteValues.addAll(((DiscreteValue) valueDesc).getValues()); + } + } + + // for 'a > 8 or a = 8', then range (8, +00) can convert to [8, +00) + RangeSet<Literal> rangeSet = TreeRangeSet.create(); + for (ValueDesc valueDesc : valueDescs) { + if (valueDesc instanceof RangeValue) { + Range<Literal> range = ((RangeValue) valueDesc).range; + rangeSet.add(range); + if (range.hasLowerBound() + && range.lowerBoundType() == BoundType.OPEN + && discreteValues.contains(range.lowerEndpoint())) { + rangeSet.add(Range.singleton(range.lowerEndpoint())); + } + if (range.hasUpperBound() + && range.upperBoundType() == BoundType.OPEN + && discreteValues.contains(range.upperEndpoint())) { + rangeSet.add(Range.singleton(range.upperEndpoint())); + } + } + } + + if (!rangeSet.isEmpty()) { + discreteValues.removeIf(x -> rangeSet.contains(x)); + } + + List<ValueDesc> result = Lists.newArrayListWithExpectedSize(valueDescs.size()); + if (!discreteValues.isEmpty()) { + result.add(new DiscreteValue(context, reference, discreteValues)); + } + for (Range<Literal> range : rangeSet.asRanges()) { + result.add(new RangeValue(context, reference, range)); + } + for (ValueDesc valueDesc : valueDescs) { + if (!(valueDesc instanceof DiscreteValue) && !(valueDesc instanceof RangeValue)) { + result.add(valueDesc); + } + } + + return result; } /** intersect */ @@ -221,19 +266,19 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, /** 
intersect */ public static ValueDesc intersect(ExpressionRewriteContext context, RangeValue range, DiscreteValue discrete) { - DiscreteValue result = new DiscreteValue(context, discrete.reference, discrete.toExpr); - discrete.values.stream().filter(x -> range.range.contains(x)).forEach(result.values::add); - if (!result.values.isEmpty()) { - return result; + Set<Literal> newValues = discrete.values.stream().filter(x -> range.range.contains(x)) + .collect(Collectors.toSet()); + if (newValues.isEmpty()) { + return new EmptyValue(context, range.reference); + } else { + return new DiscreteValue(context, range.reference, newValues); } - Expression originExpr = FoldConstantRuleOnFE.evaluate(new And(range.toExpr, discrete.toExpr), context); - return new EmptyValue(context, range.reference, originExpr); } private static ValueDesc range(ExpressionRewriteContext context, ComparisonPredicate predicate) { Literal value = (Literal) predicate.right(); if (predicate instanceof EqualTo) { - return new DiscreteValue(context, predicate.left(), predicate, value); + return new DiscreteValue(context, predicate.left(), Sets.newHashSet(value)); } Range<Literal> range = null; if (predicate instanceof GreaterThanEqual) { @@ -246,13 +291,13 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, range = Range.lessThan(value); } - return new RangeValue(context, predicate.left(), predicate, range); + return new RangeValue(context, predicate.left(), range); } public static ValueDesc discrete(ExpressionRewriteContext context, InPredicate in) { // Set<Literal> literals = (Set) Utils.fastToImmutableSet(in.getOptions()); Set<Literal> literals = in.getOptions().stream().map(Literal.class::cast).collect(Collectors.toSet()); - return new DiscreteValue(context, in.getCompareExpr(), in, literals); + return new DiscreteValue(context, in.getCompareExpr(), literals); } } @@ -261,8 +306,8 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, */ public 
static class EmptyValue extends ValueDesc { - public EmptyValue(ExpressionRewriteContext context, Expression reference, Expression toExpr) { - super(context, reference, toExpr); + public EmptyValue(ExpressionRewriteContext context, Expression reference) { + super(context, reference); } @Override @@ -284,9 +329,8 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, public static class RangeValue extends ValueDesc { Range<Literal> range; - public RangeValue(ExpressionRewriteContext context, Expression reference, - Expression toExpr, Range<Literal> range) { - super(context, reference, toExpr); + public RangeValue(ExpressionRewriteContext context, Expression reference, Range<Literal> range) { + super(context, reference); this.range = range; } @@ -300,20 +344,16 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, return other.union(this); } if (other instanceof RangeValue) { - Expression originExpr = FoldConstantRuleOnFE.evaluate(new Or(toExpr, other.toExpr), context); RangeValue o = (RangeValue) other; if (range.isConnected(o.range)) { - return new RangeValue(context, reference, originExpr, range.span(o.range)); + return new RangeValue(context, reference, range.span(o.range)); } - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), false); + return new UnknownValue(context, ImmutableList.of(this, other), false); } if (other instanceof DiscreteValue) { return union(context, this, (DiscreteValue) other, false); } - Expression originExpr = FoldConstantRuleOnFE.evaluate(new Or(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), false); + return new UnknownValue(context, ImmutableList.of(this, other), false); } @Override @@ -322,19 +362,16 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, return other.intersect(this); } if (other instanceof RangeValue) { - Expression originExpr = 
FoldConstantRuleOnFE.evaluate(new And(toExpr, other.toExpr), context); RangeValue o = (RangeValue) other; if (range.isConnected(o.range)) { - return new RangeValue(context, reference, originExpr, range.intersection(o.range)); + return new RangeValue(context, reference, range.intersection(o.range)); } - return new EmptyValue(context, reference, originExpr); + return new EmptyValue(context, reference); } if (other instanceof DiscreteValue) { return intersect(context, this, (DiscreteValue) other); } - Expression originExpr = FoldConstantRuleOnFE.evaluate(new And(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), true); + return new UnknownValue(context, ImmutableList.of(this, other), true); } @Override @@ -349,17 +386,12 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, * a in (1,2,3) => [1,2,3] */ public static class DiscreteValue extends ValueDesc { - Set<Literal> values; + final Set<Literal> values; public DiscreteValue(ExpressionRewriteContext context, - Expression reference, Expression toExpr, Literal... 
values) { - this(context, reference, toExpr, Arrays.asList(values)); - } - - public DiscreteValue(ExpressionRewriteContext context, - Expression reference, Expression toExpr, Collection<Literal> values) { - super(context, reference, toExpr); - this.values = Sets.newHashSet(values); + Expression reference, Set<Literal> values) { + super(context, reference); + this.values = values; } public Set<Literal> getValues() { @@ -372,20 +404,15 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, return other.union(this); } if (other instanceof DiscreteValue) { - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.or(toExpr, other.toExpr), context); - DiscreteValue discreteValue = new DiscreteValue(context, reference, originExpr); - discreteValue.values.addAll(((DiscreteValue) other).values); - discreteValue.values.addAll(this.values); - return discreteValue; + Set<Literal> newValues = Sets.newHashSet(); + newValues.addAll(((DiscreteValue) other).values); + newValues.addAll(this.values); + return new DiscreteValue(context, reference, newValues); } if (other instanceof RangeValue) { return union(context, (RangeValue) other, this, true); } - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.or(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), false); + return new UnknownValue(context, ImmutableList.of(this, other), false); } @Override @@ -394,24 +421,19 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, return other.intersect(this); } if (other instanceof DiscreteValue) { - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.and(toExpr, other.toExpr), context); - DiscreteValue discreteValue = new DiscreteValue(context, reference, originExpr); - discreteValue.values.addAll(((DiscreteValue) other).values); - discreteValue.values.retainAll(this.values); - if (discreteValue.values.isEmpty()) { - 
return new EmptyValue(context, reference, originExpr); + Set<Literal> newValues = Sets.newHashSet(); + newValues.addAll(((DiscreteValue) other).values); + newValues.retainAll(this.values); + if (newValues.isEmpty()) { + return new EmptyValue(context, reference); } else { - return discreteValue; + return new DiscreteValue(context, reference, newValues); } } if (other instanceof RangeValue) { return intersect(context, (RangeValue) other, this); } - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.and(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), true); + return new UnknownValue(context, ImmutableList.of(this, other), true); } @Override @@ -428,14 +450,14 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, private final boolean isAnd; private UnknownValue(ExpressionRewriteContext context, Expression expr) { - super(context, expr, expr); + super(context, expr); sourceValues = ImmutableList.of(); isAnd = false; } - public UnknownValue(ExpressionRewriteContext context, Expression toExpr, + private UnknownValue(ExpressionRewriteContext context, List<ValueDesc> sourceValues, boolean isAnd) { - super(context, getReference(sourceValues, toExpr), toExpr); + super(context, getReference(context, sourceValues, isAnd)); this.sourceValues = ImmutableList.copyOf(sourceValues); this.isAnd = isAnd; } @@ -455,11 +477,12 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, // E union UnknownValue1 = E.union(UnknownValue1) = UnknownValue1, // 2. 
since E and UnknownValue2's reference not equals, then // E union UnknownValue2 = UnknownValue3(E union UnknownValue2, reference=E union UnknownValue2) - private static Expression getReference(List<ValueDesc> sourceValues, Expression toExpr) { + private static Expression getReference(ExpressionRewriteContext context, + List<ValueDesc> sourceValues, boolean isAnd) { Expression reference = sourceValues.get(0).reference; for (int i = 1; i < sourceValues.size(); i++) { if (!reference.equals(sourceValues.get(i).reference)) { - return toExpr; + return SimplifyRange.INSTANCE.getExpression(context, sourceValues, isAnd); } } return reference; @@ -480,10 +503,7 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, if (other instanceof EmptyValue) { return other.union(this); } - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.or(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), false); + return new UnknownValue(context, ImmutableList.of(this, other), false); } @Override @@ -493,10 +513,7 @@ public class RangeInference extends ExpressionVisitor<RangeInference.ValueDesc, if (other instanceof EmptyValue) { return other.intersect(this); } - Expression originExpr = FoldConstantRuleOnFE.evaluate( - ExpressionUtils.and(toExpr, other.toExpr), context); - return new UnknownValue(context, originExpr, - ImmutableList.of(this, other), true); + return new UnknownValue(context, ImmutableList.of(this, other), true); } } } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java index 576ef6bbf4d..64891882f7d 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/expression/rules/SimplifyRange.java @@ -38,6 +38,7 @@ import 
org.apache.doris.nereids.trees.expressions.Or; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.util.ExpressionUtils; +import com.google.common.base.Preconditions; import com.google.common.collect.BoundType; import com.google.common.collect.ImmutableList; import com.google.common.collect.Lists; @@ -150,26 +151,28 @@ public class SimplifyRange implements ExpressionPatternRuleFactory { private Expression getExpression(UnknownValue value) { List<ValueDesc> sourceValues = value.getSourceValues(); - Expression originExpr = value.getOriginExpr(); if (sourceValues.isEmpty()) { - return originExpr; + return value.getReference(); + } else { + return getExpression(value.getExpressionRewriteContext(), sourceValues, value.isAnd()); } + } + + /** getExpression */ + public Expression getExpression(ExpressionRewriteContext context, + List<ValueDesc> sourceValues, boolean isAnd) { + Preconditions.checkArgument(!sourceValues.isEmpty()); List<Expression> sourceExprs = Lists.newArrayListWithExpectedSize(sourceValues.size()); for (ValueDesc sourceValue : sourceValues) { Expression expr = getExpression(sourceValue); - if (value.isAnd()) { + if (isAnd) { sourceExprs.addAll(ExpressionUtils.extractConjunction(expr)); } else { sourceExprs.addAll(ExpressionUtils.extractDisjunction(expr)); } } - Expression result = value.isAnd() ? ExpressionUtils.and(sourceExprs) : ExpressionUtils.or(sourceExprs); - result = FoldConstantRuleOnFE.evaluate(result, value.getExpressionRewriteContext()); - // ATTN: we must return original expr, because OrToIn is implemented with MutableState, - // newExpr will lose these states leading to dead loop by OrToIn -> SimplifyRange -> FoldConstantByFE - if (result.equals(originExpr)) { - return originExpr; - } + Expression result = isAnd ? 
ExpressionUtils.and(sourceExprs) : ExpressionUtils.or(sourceExprs); + result = FoldConstantRuleOnFE.evaluate(result, context); return result; } } diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java index 784600577c3..7393439c5e6 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/rules/expression/SimplifyRangeTest.java @@ -22,6 +22,10 @@ import org.apache.doris.nereids.analyzer.UnboundRelation; import org.apache.doris.nereids.analyzer.UnboundSlot; import org.apache.doris.nereids.parser.NereidsParser; import org.apache.doris.nereids.rules.analysis.ExpressionAnalyzer; +import org.apache.doris.nereids.rules.expression.rules.RangeInference; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.EmptyValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.UnknownValue; +import org.apache.doris.nereids.rules.expression.rules.RangeInference.ValueDesc; import org.apache.doris.nereids.rules.expression.rules.SimplifyRange; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.InPredicate; @@ -60,6 +64,24 @@ public class SimplifyRangeTest extends ExpressionRewrite { context = new ExpressionRewriteContext(cascadesContext); } + @Test + public void testRangeInference() { + ValueDesc valueDesc = getValueDesc("TA IS NULL"); + Assertions.assertInstanceOf(UnknownValue.class, valueDesc); + List<ValueDesc> sourceValues = ((UnknownValue) valueDesc).getSourceValues(); + Assertions.assertEquals(0, sourceValues.size()); + Assertions.assertEquals("TA IS NULL", valueDesc.getReference().toSql()); + + valueDesc = getValueDesc("TA IS NULL AND TB IS NULL AND NULL"); + Assertions.assertInstanceOf(UnknownValue.class, valueDesc); + sourceValues = 
((UnknownValue) valueDesc).getSourceValues(); + Assertions.assertEquals(3, sourceValues.size()); + Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(0)); + Assertions.assertInstanceOf(EmptyValue.class, sourceValues.get(1)); + Assertions.assertEquals("TA", sourceValues.get(0).getReference().toSql()); + Assertions.assertEquals("TB", sourceValues.get(1).getReference().toSql()); + } + @Test public void testSimplify() { executor = new ExpressionRuleExecutor(ImmutableList.of( @@ -69,8 +91,15 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("TA > 3 or TA > null", "TA > 3 OR NULL"); assertRewrite("TA > 3 or TA < null", "TA > 3 OR NULL"); assertRewrite("TA > 3 or TA = null", "TA > 3 OR NULL"); + assertRewrite("TA > 3 or TA = 3 or TA < null", "TA >= 3 OR NULL"); + assertRewrite("TA < 10 or TA in (1, 2, 3, 11, 12, 13)", "TA in (11, 12, 13) OR TA < 10"); + assertRewrite("TA < 10 or TA in (1, 2, 3, 10, 11, 12, 13) or TA > 13 or TA < 10 or TA in (1, 2, 3, 10, 11, 12, 13) or TA > 13", + "TA in (11, 12) OR TA <= 10 OR TA >= 13"); assertRewrite("TA > 3 or TA <> null", "TA > 3 or null"); assertRewrite("TA > 3 or TA <=> null", "TA > 3 or TA <=> null"); + assertRewrite("(TA < 1 or TA > 2) or (TA >= 0 and TA <= 3)", "TA IS NOT NULL OR NULL"); + assertRewrite("TA between 10 and 20 or TA between 100 and 120 or TA between 15 and 25 or TA between 115 and 125", + "TA between 10 and 25 or TA between 100 and 125"); assertRewriteNotNull("TA > 3 and TA > null", "TA > 3 and NULL"); assertRewriteNotNull("TA > 3 and TA < null", "TA > 3 and NULL"); assertRewriteNotNull("TA > 3 and TA = null", "TA > 3 and NULL"); @@ -88,13 +117,13 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("TA >= 3 and TA < 3", "TA >= 3 and TA < 3"); assertRewriteNotNull("TA = 1 and TA > 10", "FALSE"); assertRewrite("TA = 1 and TA > 10", "TA is null and null"); - assertRewrite("TA > 5 or TA < 1", "TA > 5 or TA < 1"); + assertRewrite("TA > 5 or TA < 1", "TA < 1 
or TA > 5"); assertRewrite("TA > 5 or TA > 1 or TA > 10", "TA > 1"); assertRewrite("TA > 5 or TA > 1 or TA < 10", "TA is not null or null"); assertRewriteNotNull("TA > 5 or TA > 1 or TA < 10", "TRUE"); assertRewrite("TA > 5 and TA > 1 and TA > 10", "TA > 10"); assertRewrite("TA > 5 and TA > 1 and TA < 10", "TA > 5 and TA < 10"); - assertRewrite("TA > 1 or TA < 1", "TA > 1 or TA < 1"); + assertRewrite("TA > 1 or TA < 1", "TA < 1 or TA > 1"); assertRewrite("TA > 1 or TA < 10", "TA is not null or null"); assertRewriteNotNull("TA > 1 or TA < 10", "TRUE"); assertRewrite("TA > 5 and TA < 10", "TA > 5 and TA < 10"); @@ -109,7 +138,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("(TA > 10 or TA > 20) and (TB > 10 and TB > 20)", "TA > 10 and TB > 20"); assertRewrite("((TB > 30 and TA > 40) and TA > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA > 40"); assertRewrite("(TA > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA > 10 and TB > 10 or TB > 20"); - assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))"); + assertRewrite("((TA > 10 or TA > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA > 5 and TB > 10) or (TB > 10 and (TB < 10 or TB > 20))"); assertRewriteNotNull("TA in (1,2,3) and TA > 10", "FALSE"); assertRewrite("TA in (1,2,3) and TA > 10", "TA is null and null"); assertRewrite("TA in (1,2,3) and TA >= 1", "TA in (1,2,3)"); @@ -119,7 +148,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("TA in (1,2,3) and TA < 10", "TA in (1,2,3)"); assertRewriteNotNull("TA in (1,2,3) and TA < 1", "FALSE"); assertRewrite("TA in (1,2,3) and TA < 1", "TA is null and null"); - assertRewrite("TA in (1,2,3) or TA < 1", "TA in (1,2,3) or TA < 1"); + assertRewrite("TA in (1,2,3) or TA < 1", "TA in (2,3) or TA <= 1"); assertRewrite("TA in (1,2,3) or TA in (2,3,4)", "TA in (1,2,3,4)"); assertRewrite("TA in (1,2,3) or TA in 
(4,5,6)", "TA in (1,2,3,4,5,6)"); assertRewrite("TA in (1,2,3) and TA in (4,5,6)", "TA is null and null"); @@ -150,12 +179,12 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("TA + TC >= 3 and TA + TC < 3", "TA + TC >= 3 and TA + TC < 3"); assertRewriteNotNull("TA + TC = 1 and TA + TC > 10", "FALSE"); assertRewrite("TA + TC = 1 and TA + TC > 10", "(TA + TC) is null and null"); - assertRewrite("TA + TC > 5 or TA + TC < 1", "TA + TC > 5 or TA + TC < 1"); + assertRewrite("TA + TC > 5 or TA + TC < 1", "TA + TC < 1 or TA + TC > 5"); assertRewrite("TA + TC > 5 or TA + TC > 1 or TA + TC > 10", "TA + TC > 1"); assertRewrite("TA + TC > 5 or TA + TC > 1 or TA + TC < 10", "(TA + TC) is not null or null"); assertRewrite("TA + TC > 5 and TA + TC > 1 and TA + TC > 10", "TA + TC > 10"); assertRewrite("TA + TC > 5 and TA + TC > 1 and TA + TC < 10", "TA + TC > 5 and TA + TC < 10"); - assertRewrite("TA + TC > 1 or TA + TC < 1", "TA + TC > 1 or TA + TC < 1"); + assertRewrite("TA + TC > 1 or TA + TC < 1", "TA + TC < 1 or TA + TC > 1"); assertRewrite("TA + TC > 1 or TA + TC < 10", "(TA + TC) is not null or null"); assertRewrite("TA + TC > 5 and TA + TC < 10", "TA + TC > 5 and TA + TC < 10"); assertRewrite("TA + TC > 5 and TA + TC > 10", "TA + TC > 10"); @@ -168,7 +197,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("(TA + TC > 10 or TA + TC > 20) and (TB > 10 and TB > 20)", "TA + TC > 10 and TB > 20"); assertRewrite("((TB > 30 and TA + TC > 40) and TA + TC > 20) and (TB > 10 and TB > 20)", "TB > 30 and TA + TC > 40"); assertRewrite("(TA + TC > 10 and TB > 10) or (TB > 10 and TB > 20)", "TA + TC > 10 and TB > 10 or TB > 20"); - assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))"); + assertRewrite("((TA + TC > 10 or TA + TC > 5) and TB > 10) or (TB > 10 and (TB > 20 or TB < 10))", "(TA + TC > 5 and TB > 10) or (TB > 10 
and (TB < 10 or TB > 20))"); assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC > 10", "FALSE"); assertRewrite("TA + TC in (1,2,3) and TA + TC > 10", "(TA + TC) is null and null"); assertRewrite("TA + TC in (1,2,3) and TA + TC >= 1", "TA + TC in (1,2,3)"); @@ -178,7 +207,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("TA + TC in (1,2,3) and TA + TC < 10", "TA + TC in (1,2,3)"); assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC < 1", "FALSE"); assertRewrite("TA + TC in (1,2,3) and TA + TC < 1", "(TA + TC) is null and null"); - assertRewrite("TA + TC in (1,2,3) or TA + TC < 1", "TA + TC in (1,2,3) or TA + TC < 1"); + assertRewrite("TA + TC in (1,2,3) or TA + TC < 1", "TA + TC in (2,3) or TA + TC <= 1"); assertRewrite("TA + TC in (1,2,3) or TA + TC in (2,3,4)", "TA + TC in (1,2,3,4)"); assertRewrite("TA + TC in (1,2,3) or TA + TC in (4,5,6)", "TA + TC in (1,2,3,4,5,6)"); assertRewriteNotNull("TA + TC in (1,2,3) and TA + TC in (4,5,6)", "FALSE"); @@ -221,7 +250,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewriteNotNull("AA = date '2024-01-01' and AA > date '2024-01-10'", "FALSE"); assertRewrite("AA = date '2024-01-01' and AA > date '2024-01-10'", "AA is null and null"); assertRewrite("AA > date '2024-01-05' or AA < date '2024-01-01'", - "AA > date '2024-01-05' or AA < date '2024-01-01'"); + "AA < date '2024-01-01' or AA > date '2024-01-05'"); assertRewrite("AA > date '2024-01-05' or AA > date '2024-01-01' or AA > date '2024-01-10'", "AA > date '2024-01-01'"); assertRewrite("AA > date '2024-01-05' or AA > date '2024-01-01' or AA < date '2024-01-10'", "AA is not null or null"); @@ -231,7 +260,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("AA > date '2024-01-05' and AA > date '2024-01-01' and AA < date '2024-01-10'", "AA > date '2024-01-05' and AA < date '2024-01-10'"); assertRewrite("AA > date '2024-01-05' or AA < date '2024-01-05'", - "AA > date '2024-01-05' or AA < date 
'2024-01-05'"); + "AA < date '2024-01-05' or AA > date '2024-01-05'"); assertRewrite("AA > date '2024-01-01' or AA < date '2024-01-10'", "AA is not null or null"); assertRewriteNotNull("AA > date '2024-01-01' or AA < date '2024-01-10'", "TRUE"); assertRewrite("AA > date '2024-01-05' and AA < date '2024-01-10'", @@ -261,7 +290,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') and AA < date '2024-01-01'", "AA is null and null"); assertRewrite("AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') or AA < date '2024-01-01'", - "AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03') or AA < date '2024-01-01'"); + "AA in (date '2024-01-02',date '2024-01-03') or AA <= date '2024-01-01'"); assertRewrite("AA in (date '2024-01-01',date '2024-01-02') or AA in (date '2024-01-02', date '2024-01-03')", "AA in (date '2024-01-01',date '2024-01-02',date '2024-01-03')"); assertRewriteNotNull("AA in (date '2024-01-01',date '2024-01-02') and AA in (date '2024-01-03', date '2024-01-04')", @@ -301,7 +330,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewriteNotNull("CA = timestamp '2024-01-01 10:00:10' and CA > timestamp '2024-01-10 00:00:10'", "FALSE"); assertRewrite("CA = timestamp '2024-01-01 10:00:10' and CA > timestamp '2024-01-10 00:00:10'", "CA is null and null"); assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-01 00:00:10'", - "CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-01 00:00:10'"); + "CA < timestamp '2024-01-01 00:00:10' or CA > timestamp '2024-01-05 00:00:10'"); assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-01 00:00:10' or CA > timestamp '2024-01-10 00:00:10'", "CA > timestamp '2024-01-01 00:00:10'"); assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-01 00:00:10' or CA < timestamp '2024-01-10 00:00:10'", "CA is not null or null"); @@ 
-311,7 +340,7 @@ public class SimplifyRangeTest extends ExpressionRewrite { assertRewrite("CA > timestamp '2024-01-05 00:00:10' and CA > timestamp '2024-01-01 00:00:10' and CA < timestamp '2024-01-10 00:00:10'", "CA > timestamp '2024-01-05 00:00:10' and CA < timestamp '2024-01-10 00:00:10'"); assertRewrite("CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-05 00:00:10'", - "CA > timestamp '2024-01-05 00:00:10' or CA < timestamp '2024-01-05 00:00:10'"); + "CA < timestamp '2024-01-05 00:00:10' or CA > timestamp '2024-01-05 00:00:10'"); assertRewrite("CA > timestamp '2024-01-01 00:02:10' or CA < timestamp '2024-01-10 00:02:10'", "CA is not null or null"); assertRewriteNotNull("CA > timestamp '2024-01-01 00:00:00' or CA < timestamp '2024-01-10 00:00:00'", "TRUE"); assertRewrite("CA > timestamp '2024-01-05 01:00:00' and CA < timestamp '2024-01-10 01:00:00'", @@ -364,7 +393,13 @@ public class SimplifyRangeTest extends ExpressionRewrite { "(CA is null and null) OR CB < timestamp '2024-01-05 00:50:00'"); } - @Test + private ValueDesc getValueDesc(String expression) { + Map<String, Slot> mem = Maps.newHashMap(); + Expression parseExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem); + parseExpression = typeCoercion(parseExpression); + return (new RangeInference()).getValue(parseExpression, context); + } + private void assertRewrite(String expression, String expected) { Map<String, Slot> mem = Maps.newHashMap(); Expression needRewriteExpression = replaceUnboundSlot(PARSER.parseExpression(expression), mem); diff --git a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java index 3d88c131c97..cdc36164ae9 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/nereids/sqltest/InferTest.java @@ -58,7 +58,7 @@ public class InferTest extends SqlTestBase { f -> 
ExpressionUtils.and(f.getConjuncts().stream() .sorted((a, b) -> a.toString().compareTo(b.toString())) .collect(Collectors.toList())) - .toString().equals("AND[(id#0 >= 4),OR[(id#0 = 4),(id#0 > 4)]]")) + .toString().equals("(id#0 >= 4)")) ) ); @@ -76,7 +76,7 @@ public class InferTest extends SqlTestBase { logicalFilter( leftOuterLogicalJoin( logicalFilter().when( - f -> f.getPredicate().toString().equals("AND[(id#0 >= 4),OR[(id#0 = 4),(id#0 > 4)]]")), + f -> f.getPredicate().toString().equals("(id#0 >= 4)")), logicalFilter().when( f -> f.getPredicate().toString().equals("(id#2 >= 4)") ) --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org