This is an automated email from the ASF dual-hosted git repository. morrysnow pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new ae66464d6b0 [refactor](Nereids) refactor infer predicate rule to avoid lost cast (#25637) ae66464d6b0 is described below commit ae66464d6b039e34771fa330ea85194849d43c43 Author: morrySnow <101034200+morrys...@users.noreply.github.com> AuthorDate: Wed Oct 25 14:12:22 2023 +0800 [refactor](Nereids) refactor infer predicate rule to avoid lost cast (#25637) extract slot and literal in comparison predicate. infer new one by equals predicates. use TypeCoercion to add cast on new comparison predicate to ensure it is correct. This reverts "[Fix](Nereids) Add cast comparison with slot reference when inferring predicate (#21171)" commit 58f2593ba1b65713e7b3c1ed39fc84be8cc3ff2c. --- .../rules/rewrite/PredicatePropagation.java | 223 ++++++++++++++++----- .../apache/doris/nereids/util/ExpressionUtils.java | 29 --- .../infer_predicate/infer_predicate.groovy | 2 +- 3 files changed, 170 insertions(+), 84 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java index d5323fc58a8..9341c3db5c1 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/rules/rewrite/PredicatePropagation.java @@ -17,19 +17,28 @@ package org.apache.doris.nereids.rules.rewrite; +import org.apache.doris.nereids.parser.NereidsParser; +import org.apache.doris.nereids.rules.expression.rules.SimplifyComparisonPredicate; import org.apache.doris.nereids.trees.expressions.Cast; import org.apache.doris.nereids.trees.expressions.ComparisonPredicate; import org.apache.doris.nereids.trees.expressions.EqualTo; import org.apache.doris.nereids.trees.expressions.Expression; import org.apache.doris.nereids.trees.expressions.SlotReference; -import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; +import org.apache.doris.nereids.trees.expressions.literal.IntegerLikeLiteral; import org.apache.doris.nereids.types.DataType; +import org.apache.doris.nereids.types.DateTimeType; +import org.apache.doris.nereids.types.DateTimeV2Type; +import org.apache.doris.nereids.types.DateType; +import org.apache.doris.nereids.types.DateV2Type; +import org.apache.doris.nereids.types.coercion.CharacterType; +import org.apache.doris.nereids.types.coercion.DateLikeType; import org.apache.doris.nereids.types.coercion.IntegralType; -import org.apache.doris.nereids.util.ExpressionUtils; +import org.apache.doris.nereids.util.TypeCoercionUtils; import com.google.common.collect.Sets; -import java.util.List; +import java.util.Objects; +import java.util.Optional; import java.util.Set; import java.util.stream.Collectors; @@ -40,19 +49,61 @@ import java.util.stream.Collectors; */ public class PredicatePropagation { + private enum InferType { + NONE(null), + INTEGRAL(IntegralType.class), + STRING(CharacterType.class), + DATE(DateLikeType.class), + OTHER(DataType.class) + ; + + private final Class<? extends DataType> superClazz; + + InferType(Class<? extends DataType> superClazz) { + this.superClazz = superClazz; + } + } + + private class ComparisonInferInfo { + + public final InferType inferType; + public final Optional<Expression> left; + public final Optional<Expression> right; + public final ComparisonPredicate comparisonPredicate; + + public ComparisonInferInfo(InferType inferType, + Optional<Expression> left, Optional<Expression> right, + ComparisonPredicate comparisonPredicate) { + this.inferType = inferType; + this.left = left; + this.right = right; + this.comparisonPredicate = comparisonPredicate; + } + } + /** * infer additional predicates. */ public Set<Expression> infer(Set<Expression> predicates) { Set<Expression> inferred = Sets.newHashSet(); for (Expression predicate : predicates) { - if (canEquivalentInfer(predicate)) { - List<Expression> newInferred = predicates.stream() - .filter(p -> !p.equals(predicate)) - .map(p -> doInfer(predicate, p)) - .collect(Collectors.toList()); - inferred.addAll(newInferred); + if (!(predicate instanceof ComparisonPredicate)) { + continue; + } + ComparisonInferInfo equalInfo = getEquivalentInferInfo((ComparisonPredicate) predicate); + if (equalInfo.inferType == InferType.NONE) { + continue; } + Set<Expression> newInferred = predicates.stream() + .filter(ComparisonPredicate.class::isInstance) + .filter(p -> !p.equals(predicate)) + .map(ComparisonPredicate.class::cast) + .map(this::inferInferInfo) + .filter(predicateInfo -> predicateInfo.inferType != InferType.NONE) + .map(predicateInfo -> doInfer(equalInfo, predicateInfo)) + .filter(Objects::nonNull) + .collect(Collectors.toSet()); + inferred.addAll(newInferred); } inferred.removeAll(predicates); return inferred; @@ -64,64 +115,128 @@ public class PredicatePropagation { * TODO: We should determine whether `expression` satisfies the condition for replacement * eg: Satisfy `expression` is non-deterministic */ - private Expression doInfer(Expression leftSlotEqualToRightSlot, Expression expression) { - return expression.accept(new DefaultExpressionRewriter<Void>() { + private Expression doInfer(ComparisonInferInfo equalInfo, ComparisonInferInfo predicateInfo) { + Expression predicateLeft = predicateInfo.left.get(); + Expression predicateRight = predicateInfo.right.get(); + Expression equalLeft = equalInfo.left.get(); + Expression equalRight = equalInfo.right.get(); + Expression newLeft = inferOneSide(predicateLeft, equalLeft, equalRight); + Expression newRight = inferOneSide(predicateRight, equalLeft, equalRight); + if (newLeft == null || newRight == null) { + return null; + } + ComparisonPredicate newPredicate = (ComparisonPredicate) predicateInfo + .comparisonPredicate.withChildren(newLeft, newRight); + return SimplifyComparisonPredicate.INSTANCE + .rewrite(TypeCoercionUtils.processComparisonPredicate(newPredicate), null); + } - @Override - public Expression visit(Expression expr, Void context) { - return expr; + private Expression inferOneSide(Expression predicateOneSide, Expression equalLeft, Expression equalRight) { + if (predicateOneSide instanceof SlotReference) { + if (predicateOneSide.equals(equalLeft)) { + return equalRight; + } else if (predicateOneSide.equals(equalRight)) { + return equalLeft; } - - @Override - public Expression visitComparisonPredicate(ComparisonPredicate cp, Void context) { - // we need to get expression covered by cast, because we want to infer different datatype - if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.left()) && (cp.right().isConstant())) { - return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.left())); - } else if (ExpressionUtils.isExpressionSlotCoveredByCast(cp.right()) && cp.left().isConstant()) { - return replaceSlot(cp, ExpressionUtils.getDatatypeCoveredByCast(cp.right())); - } - return super.visit(cp, context); + } else if (predicateOneSide.isConstant()) { + if (predicateOneSide instanceof IntegerLikeLiteral) { + return new NereidsParser().parseExpression(((IntegerLikeLiteral) predicateOneSide).toSql()); + } else { + return predicateOneSide; } + } + return null; + } - private boolean isDataTypeValid(DataType originDataType, Expression expr) { - if ((leftSlotEqualToRightSlot.child(0).getDataType() instanceof IntegralType) - && (leftSlotEqualToRightSlot.child(1).getDataType() instanceof IntegralType) - && (originDataType instanceof IntegralType)) { - // infer filter can not be lower than original datatype, or dataset would be wrong - if (!((IntegralType) originDataType).widerThan( - (IntegralType) leftSlotEqualToRightSlot.child(0).getDataType()) - && !((IntegralType) originDataType).widerThan( - (IntegralType) leftSlotEqualToRightSlot.child(1).getDataType())) { - return true; + private Optional<Expression> validForInfer(Expression expression, InferType inferType) { + if (!inferType.superClazz.isAssignableFrom(expression.getDataType().getClass())) { + return Optional.empty(); + } + if (expression instanceof SlotReference || expression.isConstant()) { + return Optional.of(expression); + } + if (inferType == InferType.INTEGRAL) { + if (expression instanceof Cast) { + // avoid cast from wider type to narrower type, such as cast(int as smallint) + // IntegralType dataType = (IntegralType) expression.getDataType(); + // DataType childType = ((Cast) expression).child().getDataType(); + // if (childType instanceof IntegralType && dataType.widerThan((IntegralType) childType)) { + // return validForInfer(((Cast) expression).child(), inferType); + // } + return validForInfer(((Cast) expression).child(), inferType); + } + } else if (inferType == InferType.DATE) { + if (expression instanceof Cast) { + DataType dataType = expression.getDataType(); + DataType childType = ((Cast) expression).child().getDataType(); + // avoid lost precision + if (dataType instanceof DateType) { + if (childType instanceof DateV2Type || childType instanceof DateType) { + return validForInfer(((Cast) expression).child(), inferType); + } + } else if (dataType instanceof DateV2Type) { + if (childType instanceof DateType || childType instanceof DateV2Type) { + return validForInfer(((Cast) expression).child(), inferType); + } + } else if (dataType instanceof DateTimeType) { + if (!(childType instanceof DateTimeV2Type)) { + return validForInfer(((Cast) expression).child(), inferType); } + } else if (dataType instanceof DateTimeV2Type) { + return validForInfer(((Cast) expression).child(), inferType); } - return false; } - - private Expression replaceSlot(Expression expr, DataType originDataType) { - return expr.rewriteUp(e -> { - if (isDataTypeValid(originDataType, leftSlotEqualToRightSlot)) { - if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(0))) { - return leftSlotEqualToRightSlot.child(1); - } else if (ExpressionUtils.isTwoExpressionEqualWithCast(e, leftSlotEqualToRightSlot.child(1))) { - return leftSlotEqualToRightSlot.child(0); - } - } - return e; - }); + } else if (inferType == InferType.STRING) { + if (expression instanceof Cast) { + DataType dataType = expression.getDataType(); + DataType childType = ((Cast) expression).child().getDataType(); + // avoid substring cast such as cast(char(3) as char(2)) + if (dataType.width() <= 0 || (dataType.width() >= childType.width() && childType.width() >= 0)) { + return validForInfer(((Cast) expression).child(), inferType); + } } - }, null); + } else { + return Optional.empty(); + } + return Optional.empty(); + } + + private ComparisonInferInfo inferInferInfo(ComparisonPredicate comparisonPredicate) { + DataType leftType = comparisonPredicate.left().getDataType(); + InferType inferType; + if (leftType instanceof CharacterType) { + inferType = InferType.STRING; + } else if (leftType instanceof IntegralType) { + inferType = InferType.INTEGRAL; + } else if (leftType instanceof DateLikeType) { + inferType = InferType.DATE; + } else { + inferType = InferType.OTHER; + } + Optional<Expression> left = validForInfer(comparisonPredicate.left(), inferType); + Optional<Expression> right = validForInfer(comparisonPredicate.right(), inferType); + if (!left.isPresent() || !right.isPresent()) { + inferType = InferType.NONE; + } + return new ComparisonInferInfo(inferType, left, right, comparisonPredicate); } /** * Currently only equivalence derivation is supported * and requires that the left and right sides of an expression must be slot */ - private boolean canEquivalentInfer(Expression predicate) { - return predicate instanceof EqualTo - && predicate.children().stream().allMatch(e -> - (e instanceof SlotReference) || (e instanceof Cast && e.child(0) instanceof SlotReference)) - && predicate.child(0).getDataType().equals(predicate.child(1).getDataType()); + private ComparisonInferInfo getEquivalentInferInfo(ComparisonPredicate predicate) { + if (!(predicate instanceof EqualTo)) { + return new ComparisonInferInfo(InferType.NONE, + Optional.of(predicate.left()), Optional.of(predicate.right()), predicate); + } + ComparisonInferInfo info = inferInferInfo(predicate); + if (info.inferType == InferType.NONE) { + return info; + } + if (info.left.get() instanceof SlotReference && info.right.get() instanceof SlotReference) { + return info; + } + return new ComparisonInferInfo(InferType.NONE, info.left, info.right, info.comparisonPredicate); } - } diff --git a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java index 41c4f423045..1e67808c614 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java +++ b/fe/fe-core/src/main/java/org/apache/doris/nereids/util/ExpressionUtils.java @@ -39,7 +39,6 @@ import org.apache.doris.nereids.trees.expressions.literal.BooleanLiteral; import org.apache.doris.nereids.trees.expressions.literal.Literal; import org.apache.doris.nereids.trees.expressions.literal.NullLiteral; import org.apache.doris.nereids.trees.expressions.visitor.DefaultExpressionRewriter; -import org.apache.doris.nereids.types.DataType; import com.google.common.base.Preconditions; import com.google.common.base.Predicate; @@ -253,34 +252,6 @@ public class ExpressionUtils { } } - /** - * get slot covered by cast - * example: input: cast(cast(table.columnA)) output: columnA.datatype - * - */ - public static DataType getDatatypeCoveredByCast(Expression expr) { - if (expr instanceof Cast) { - return getDatatypeCoveredByCast(((Cast) expr).child()); - } - return expr.getDataType(); - } - - /** - * judge if expression is slot covered by cast - * example: cast(cast(table.columnA)) - */ - public static boolean isExpressionSlotCoveredByCast(Expression expr) { - if (expr instanceof Cast) { - return isExpressionSlotCoveredByCast(((Cast) expr).child()); - } - return expr instanceof SlotReference; - } - - public static boolean isTwoExpressionEqualWithCast(Expression left, Expression right) { - return ExpressionUtils.extractSlotOrCastOnSlot(left) - .equals(ExpressionUtils.extractSlotOrCastOnSlot(right)); - } - /** * Replace expression node in the expression tree by `replaceMap` in top-down manner. * For example. diff --git a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy index a1621f1c239..c5942680ea7 100644 --- a/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy +++ b/regression-test/suites/nereids_p0/infer_predicate/infer_predicate.groovy @@ -41,7 +41,7 @@ suite("test_infer_predicate") { explain { sql "select * from infer_tb1 inner join infer_tb2 where cast(infer_tb2.k4 as int) = infer_tb1.k2 and infer_tb2.k4 = 1;" - notContains "PREDICATES: k2" + contains "PREDICATES: k2" } explain { --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org