This is an automated email from the ASF dual-hosted git repository.

morrysnow pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new ae66464d6b0 [refactor](Nereids) refactor infer predicate rule to avoid 
lost cast (#25637)
ae66464d6b0 is described below

commit ae66464d6b039e34771fa330ea85194849d43c43
Author: morrySnow <101034200+morrys...@users.noreply.github.com>
AuthorDate: Wed Oct 25 14:12:22 2023 +0800

    [refactor](Nereids) refactor infer predicate rule to avoid lost cast 
(#25637)
    
    Extract the slot and literal from each comparison predicate, and infer
    new predicates from the equality predicates.
    Use TypeCoercion to add casts to the new comparison predicate to ensure
    it is correct.
    
    This reverts "[Fix](Nereids) Add cast comparison with slot reference when 
inferring predicate (#21171)"
    commit 58f2593ba1b65713e7b3c1ed39fc84be8cc3ff2c.
---
 .../rules/rewrite/PredicatePropagation.java        | 223 ++++++++++++++++-----
 .../apache/doris/nereids/util/ExpressionUtils.java |  29 ---
 .../infer_predicate/infer_predicate.groovy         |   2 +-
 3 files changed, 170 insertions(+), 84 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
index d5323fc58a8..9341c3db5c1 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java
@@ -17,19 +17,28 @@
 
 package org.apache.doris.nereids.rules.rewrite;
 
+import org.apache.doris.nereids.parser.NereidsParser;
+import 
org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.Cast;
 import org.apache.doris.nereids.trees.expressions.ComparisonPredicate;
 import org.apache.doris.nereids.trees.expressions.EqualTo;
 import org.apache.doris.nereids.trees.expressions.Expression;
 import org.apache.doris.nereids.trees.expressions.SlotReference;
-import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
+import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral;
 import org.apache.doris.nereids.types.DataType;
+import org.apache.doris.nereids.types.DateTimeType;
+import org.apache.doris.nereids.types.DateTimeV2Type;
+import org.apache.doris.nereids.types.DateType;
+import org.apache.doris.nereids.types.DateV2Type;
+import org.apache.doris.nereids.types.coercion.CharacterType;
+import org.apache.doris.nereids.types.coercion.DateLikeType;
 import org.apache.doris.nereids.types.coercion.IntegralType;
-import org.apache.doris.nereids.util.ExpressionUtils;
+import org.apache.doris.nereids.util.TypeCoercionUtils;
 
 import com.google.common.collect.Sets;
 
-import java.util.List;
+import java.util.Objects;
+import java.util.Optional;
 import java.util.Set;
 import java.util.stream.Collectors;
 
@@ -40,19 +49,61 @@ import java.util.stream.Collectors;
  */
 public class PredicatePropagation {
 
+    private enum InferType {
+        NONE(null),
+        INTEGRAL(IntegralType.class),
+        STRING(CharacterType.class),
+        DATE(DateLikeType.class),
+        OTHER(DataType.class)
+        ;
+
+        private final Class<? extends DataType> superClazz;
+
+        InferType(Class<? extends DataType> superClazz) {
+            this.superClazz = superClazz;
+        }
+    }
+
+    private class ComparisonInferInfo {
+
+        public final InferType inferType;
+        public final Optional<Expression> left;
+        public final Optional<Expression> right;
+        public final ComparisonPredicate comparisonPredicate;
+
+        public ComparisonInferInfo(InferType inferType,
+                Optional<Expression> left, Optional<Expression> right,
+                ComparisonPredicate comparisonPredicate) {
+            this.inferType = inferType;
+            this.left = left;
+            this.right = right;
+            this.comparisonPredicate = comparisonPredicate;
+        }
+    }
+
     /**
      * infer additional predicates.
      */
     public Set<Expression> infer(Set<Expression> predicates) {
         Set<Expression> inferred = Sets.newHashSet();
         for (Expression predicate : predicates) {
-            if (canEquivalentInfer(predicate)) {
-                List<Expression> newInferred = predicates.stream()
-                        .filter(p -> !p.equals(predicate))
-                        .map(p -> doInfer(predicate, p))
-                        .collect(Collectors.toList());
-                inferred.addAll(newInferred);
+            if (!(predicate instanceof ComparisonPredicate)) {
+                continue;
+            }
+            ComparisonInferInfo equalInfo = 
getEquivalentInferInfo((ComparisonPredicate) predicate);
+            if (equalInfo.inferType == InferType.NONE) {
+                continue;
             }
+            Set<Expression> newInferred = predicates.stream()
+                    .filter(ComparisonPredicate.class::isInstance)
+                    .filter(p -> !p.equals(predicate))
+                    .map(ComparisonPredicate.class::cast)
+                    .map(this::inferInferInfo)
+                    .filter(predicateInfo -> predicateInfo.inferType != 
InferType.NONE)
+                    .map(predicateInfo -> doInfer(equalInfo, predicateInfo))
+                    .filter(Objects::nonNull)
+                    .collect(Collectors.toSet());
+            inferred.addAll(newInferred);
         }
         inferred.removeAll(predicates);
         return inferred;
@@ -64,64 +115,128 @@ public class PredicatePropagation {
      * TODO: We should determine whether `expression` satisfies the condition 
for replacement
      *       eg: Satisfy `expression` is non-deterministic
      */
-    private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression 
expression) {
-        return expression.accept(new DefaultExpressionRewriter<Void>() {
+    private Expression doInfer(ComparisonInferInfo equalInfo, 
ComparisonInferInfo predicateInfo) {
+        Expression predicateLeft = predicateInfo.left.get();
+        Expression predicateRight = predicateInfo.right.get();
+        Expression equalLeft = equalInfo.left.get();
+        Expression equalRight = equalInfo.right.get();
+        Expression newLeft = inferOneSide(predicateLeft, equalLeft, 
equalRight);
+        Expression newRight = inferOneSide(predicateRight, equalLeft, 
equalRight);
+        if (newLeft == null || newRight == null) {
+            return null;
+        }
+        ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo
+                .comparisonPredicate.withChildren(newLeft, newRight);
+        return SimplifyComparisonPredicate.INSTANCE
+                
.rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null);
+    }
 
-            @Override
-            public Expression visit(Expression expr, Void context) {
-                return expr;
+    private Expression inferOneSide(Expression predicateOneSide, Expression 
equalLeft, Expression equalRight) {
+        if (predicateOneSide instanceof SlotReference) {
+            if (predicateOneSide.equals(equalLeft)) {
+                return equalRight;
+            } else if (predicateOneSide.equals(equalRight)) {
+                return equalLeft;
             }
-
-            @Override
-            public Expression visitComparisonPredicate(ComparisonPredicate cp, 
Void context) {
-                // we need to get expression covered by cast, because we want 
to infer different datatype
-                if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) 
&& (cp.right().isConstant())) {
-                    return replaceSlot(cp, 
ExpressionUtils.getDatatypeCoveredByCast(cp.left()));
-                } else if 
(ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && 
cp.left().isConstant()) {
-                    return replaceSlot(cp, 
ExpressionUtils.getDatatypeCoveredByCast(cp.right()));
-                }
-                return super.visit(cp, context);
+        } else if (predicateOneSide.isConstant()) {
+            if (predicateOneSide instanceof IntegerLikeLiteral) {
+                return new 
NereidsParser().parseExpression(((IntegerLikeLiteral) 
predicateOneSide).toSql());
+            } else {
+                return predicateOneSide;
             }
+        }
+        return null;
+    }
 
-            private boolean isDataTypeValid(DataType originDataType, 
Expression expr) {
-                if ((leftSlotEqualToRightSlot.child(0).getDataType() 
instanceof IntegralType)
-                        && (leftSlotEqualToRightSlot.child(1).getDataType() 
instanceof IntegralType)
-                                && (originDataType instanceof IntegralType)) {
-                    // infer filter can not be lower than original datatype, 
or dataset would be wrong
-                    if (!((IntegralType) originDataType).widerThan(
-                            (IntegralType) 
leftSlotEqualToRightSlot.child(0).getDataType())
-                                    && !((IntegralType) 
originDataType).widerThan(
-                                            (IntegralType) 
leftSlotEqualToRightSlot.child(1).getDataType())) {
-                        return true;
+    private Optional<Expression> validForInfer(Expression expression, 
InferType inferType) {
+        if 
(!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) {
+            return Optional.empty();
+        }
+        if (expression instanceof SlotReference || expression.isConstant()) {
+            return Optional.of(expression);
+        }
+        if (inferType == InferType.INTEGRAL) {
+            if (expression instanceof Cast) {
+                // avoid cast from wider type to narrower type, such as 
cast(int as smallint)
+                // IntegralType dataType = (IntegralType) 
expression.getDataType();
+                // DataType childType = ((Cast) 
expression).child().getDataType();
+                // if (childType instanceof IntegralType && 
dataType.widerThan((IntegralType) childType)) {
+                //     return validForInfer(((Cast) expression).child(), 
inferType);
+                // }
+                return validForInfer(((Cast) expression).child(), inferType);
+            }
+        } else if (inferType == InferType.DATE) {
+            if (expression instanceof Cast) {
+                DataType dataType = expression.getDataType();
+                DataType childType = ((Cast) expression).child().getDataType();
+                // avoid lost precision
+                if (dataType instanceof DateType) {
+                    if (childType instanceof DateV2Type || childType 
instanceof DateType) {
+                        return validForInfer(((Cast) expression).child(), 
inferType);
+                    }
+                } else if (dataType instanceof DateV2Type) {
+                    if (childType instanceof DateType || childType instanceof 
DateV2Type) {
+                        return validForInfer(((Cast) expression).child(), 
inferType);
+                    }
+                } else if (dataType instanceof DateTimeType) {
+                    if (!(childType instanceof DateTimeV2Type)) {
+                        return validForInfer(((Cast) expression).child(), 
inferType);
                     }
+                } else if (dataType instanceof DateTimeV2Type) {
+                    return validForInfer(((Cast) expression).child(), 
inferType);
                 }
-                return false;
             }
-
-            private Expression replaceSlot(Expression expr, DataType 
originDataType) {
-                return expr.rewriteUp(e -> {
-                    if (isDataTypeValid(originDataType, 
leftSlotEqualToRightSlot)) {
-                        if (ExpressionUtils.isTwoExpressionEqualWithCast(e, 
leftSlotEqualToRightSlot.child(0))) {
-                            return leftSlotEqualToRightSlot.child(1);
-                        } else if 
(ExpressionUtils.isTwoExpressionEqualWithCast(e, 
leftSlotEqualToRightSlot.child(1))) {
-                            return leftSlotEqualToRightSlot.child(0);
-                        }
-                    }
-                    return e;
-                });
+        } else if (inferType == InferType.STRING) {
+            if (expression instanceof Cast) {
+                DataType dataType = expression.getDataType();
+                DataType childType = ((Cast) expression).child().getDataType();
+                // avoid substring cast such as cast(char(3) as char(2))
+                if (dataType.width() <= 0 || (dataType.width() >= 
childType.width() && childType.width() >= 0)) {
+                    return validForInfer(((Cast) expression).child(), 
inferType);
+                }
             }
-        }, null);
+        } else {
+            return Optional.empty();
+        }
+        return Optional.empty();
+    }
+
+    private ComparisonInferInfo inferInferInfo(ComparisonPredicate 
comparisonPredicate) {
+        DataType leftType = comparisonPredicate.left().getDataType();
+        InferType inferType;
+        if (leftType instanceof CharacterType) {
+            inferType = InferType.STRING;
+        } else if (leftType instanceof IntegralType) {
+            inferType = InferType.INTEGRAL;
+        } else if (leftType instanceof DateLikeType) {
+            inferType = InferType.DATE;
+        } else {
+            inferType = InferType.OTHER;
+        }
+        Optional<Expression> left = validForInfer(comparisonPredicate.left(), 
inferType);
+        Optional<Expression> right = 
validForInfer(comparisonPredicate.right(), inferType);
+        if (!left.isPresent() || !right.isPresent()) {
+            inferType = InferType.NONE;
+        }
+        return new ComparisonInferInfo(inferType, left, right, 
comparisonPredicate);
     }
 
     /**
      * Currently only equivalence derivation is supported
      * and requires that the left and right sides of an expression must be slot
      */
-    private boolean canEquivalentInfer(Expression predicate) {
-        return predicate instanceof EqualTo
-                && predicate.children().stream().allMatch(e ->
-                    (e instanceof SlotReference) || (e instanceof Cast && 
e.child(0) instanceof SlotReference))
-                && 
predicate.child(0).getDataType().equals(predicate.child(1).getDataType());
+    private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate 
predicate) {
+        if (!(predicate instanceof EqualTo)) {
+            return new ComparisonInferInfo(InferType.NONE,
+                    Optional.of(predicate.left()), 
Optional.of(predicate.right()), predicate);
+        }
+        ComparisonInferInfo info = inferInferInfo(predicate);
+        if (info.inferType == InferType.NONE) {
+            return info;
+        }
+        if (info.left.get() instanceof SlotReference && info.right.get() 
instanceof SlotReference) {
+            return info;
+        }
+        return new ComparisonInferInfo(InferType.NONE, info.left, info.right, 
info.comparisonPredicate);
     }
-
 }
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
index 41c4f423045..1e67808c614 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java
@@ -39,7 +39,6 @@ import 
org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral;
 import org.apache.doris.nereids.trees.expressions.literal.Literal;
 import org.apache.doris.nereids.trees.expressions.literal.NullLiteral;
 import 
org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter;
-import org.apache.doris.nereids.types.DataType;
 
 import com.google.common.base.Preconditions;
 import com.google.common.base.Predicate;
@@ -253,34 +252,6 @@ public class ExpressionUtils {
         }
     }
 
-    /**
-     * get slot covered by cast
-     * example: input: cast(cast(table.columnA)) output: columnA.datatype
-     *
-     */
-    public static DataType getDatatypeCoveredByCast(Expression expr) {
-        if (expr instanceof Cast) {
-            return getDatatypeCoveredByCast(((Cast) expr).child());
-        }
-        return expr.getDataType();
-    }
-
-    /**
-     * judge if expression is slot covered by cast
-     * example: cast(cast(table.columnA))
-     */
-    public static boolean isExpressionSlotCoveredByCast(Expression expr) {
-        if (expr instanceof Cast) {
-            return isExpressionSlotCoveredByCast(((Cast) expr).child());
-        }
-        return expr instanceof SlotReference;
-    }
-
-    public static boolean isTwoExpressionEqualWithCast(Expression left, 
Expression right) {
-        return ExpressionUtils.extractSlotOrCastOnSlot(left)
-            .equals(ExpressionUtils.extractSlotOrCastOnSlot(right));
-    }
-
     /**
      * Replace expression node in the expression tree by `replaceMap` in 
top-down manner.
      * For example.
diff --git 
a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy 
b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
index a1621f1c239..c5942680ea7 100644
--- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
+++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy
@@ -41,7 +41,7 @@ suite("test_infer_predicate") {
 
     explain {
         sql "select * from infer_tb1 inner join infer_tb2 where 
cast(infer_tb2.k4 as int) = infer_tb1.k2  and infer_tb2.k4 = 1;"
-        notContains "PREDICATES: k2"
+        contains "PREDICATES: k2"
     }
 
     explain {


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to