This is an automated email from the ASF dual-hosted git repository. lingmiao pushed a commit to branch master in repository https://gitbox.apache.org/repos/asf/doris.git
The following commit(s) were added to refs/heads/master by this push: new f758e1166a [fix] Fix RewriteBinaryPredicatesRule which causes wrong query results in some cases. (#10551) f758e1166a is described below commit f758e1166a2844d507806c2ddc5376cea036d31b Author: luozenglin <37725793+luozeng...@users.noreply.github.com> AuthorDate: Wed Jul 6 15:39:27 2022 +0800 [fix] Fix RewriteBinaryPredicatesRule which causes wrong query results in some cases. (#10551) During the query planning phase, the binary-predicate rewrite, which converts a DecimalLiteral to an integer, could overflow and produce wrong results for predicates like "id = 12345678901.0" (see the issue for detailed examples). This PR fixes the possible overflow and also optimizes predicates whose DecimalLiteral falls outside the value range of the column type. Issue Number: close #10544 --- .../org/apache/doris/analysis/DecimalLiteral.java | 6 ++ .../java/org/apache/doris/analysis/IntLiteral.java | 22 ++++ .../doris/rewrite/RewriteBinaryPredicatesRule.java | 72 +++++++++---- .../org/apache/doris/analysis/SelectStmtTest.java | 40 +++---- .../java/org/apache/doris/planner/PlannerTest.java | 4 +- .../rewrite/RewriteBinaryPredicatesRuleTest.java | 118 +++++++++++++++++++++ 6 files changed, 219 insertions(+), 43 deletions(-) diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java index e938a46361..3e5bf9abc7 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java @@ -246,6 +246,12 @@ public class DecimalLiteral extends LiteralExpr { } else if (targetType.isFloatingPointType()) { return new FloatLiteral(value.doubleValue(), targetType); } else if (targetType.isIntegerType()) { + // If the integer part of BigDecimal is too big to fit into long, + // longValue() will only return the low-order 64-bit value.
+ if (value.compareTo(BigDecimal.valueOf(Long.MAX_VALUE)) > 0 + || value.compareTo(BigDecimal.valueOf(Long.MIN_VALUE)) < 0) { + throw new AnalysisException("Integer part of " + value + " exceeds storage range of Long Type."); + } return new IntLiteral(value.longValue(), targetType); } else if (targetType.isStringType()) { return new StringLiteral(value.toString()); diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java b/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java index 00662c5e6a..4d4f673822 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java +++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java @@ -172,6 +172,28 @@ public class IntLiteral extends LiteralExpr { return new IntLiteral(value); } + public static IntLiteral createMaxValue(Type type) { + long value = 0L; + switch (type.getPrimitiveType()) { + case TINYINT: + value = TINY_INT_MAX; + break; + case SMALLINT: + value = SMALL_INT_MAX; + break; + case INT: + value = INT_MAX; + break; + case BIGINT: + value = BIG_INT_MAX; + break; + default: + Preconditions.checkState(false); + } + + return new IntLiteral(value); + } + @Override public boolean isMinValue() { switch (type.getPrimitiveType()) { diff --git a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java index 4ed232b4b1..a18797b657 100644 --- a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java +++ b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java @@ -19,10 +19,13 @@ package org.apache.doris.rewrite; import org.apache.doris.analysis.Analyzer; import org.apache.doris.analysis.BinaryPredicate; +import org.apache.doris.analysis.BinaryPredicate.Operator; import org.apache.doris.analysis.BoolLiteral; import org.apache.doris.analysis.CastExpr; import 
org.apache.doris.analysis.DecimalLiteral; import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.IntLiteral; +import org.apache.doris.analysis.LiteralExpr; import org.apache.doris.analysis.SlotRef; import org.apache.doris.catalog.Type; import org.apache.doris.common.AnalysisException; @@ -53,36 +56,61 @@ public class RewriteBinaryPredicatesRule implements ExprRewriteRule { * 3) "select * from T where t1 != 2.0" is converted to "select * from T where t1 != 2" * 4) "select * from T where t1 != 2.1" is converted to "select * from T" * 5) "select * from T where t1 <= 2.0" is converted to "select * from T where t1 <= 2" - * 6) "select * from T where t1 <= 2.1" is converted to "select * from T where t1 <3" + * 6) "select * from T where t1 <= 2.1" is converted to "select * from T where t1 <=2" * 7) "select * from T where t1 >= 2.0" is converted to "select * from T where t1 >= 2" * 8) "select * from T where t1 >= 2.1" is converted to "select * from T where t1> 2" * 9) "select * from T where t1 <2.0" is converted to "select * from T where t1 <2" - * 10) "select * from T where t1 <2.1" is converted to "select * from T where t1 <3" + * 10) "select * from T where t1 <2.1" is converted to "select * from T where t1 <=2" * 11) "select * from T where t1> 2.0" is converted to "select * from T where t1> 2" * 12) "select * from T where t1> 2.1" is converted to "select * from T where t1> 2" */ - private Expr rewriteBigintSlotRefCompareDecimalLiteral(Expr expr0, Expr expr1, BinaryPredicate.Operator op) - throws AnalysisException { - if (((DecimalLiteral) expr1).getDoubleValue() % (int) (((DecimalLiteral) expr1).getDoubleValue()) != 0) { - if (op == BinaryPredicate.Operator.EQ || op == BinaryPredicate.Operator.EQ_FOR_NULL) { - return new BoolLiteral(false); - } else if (op == BinaryPredicate.Operator.NE) { + private Expr rewriteBigintSlotRefCompareDecimalLiteral(Expr expr0, DecimalLiteral expr1, + BinaryPredicate.Operator op) { + Type columnType = 
expr0.getSrcSlotRef().getColumn().getType(); + try { + // Convert childExpr to column type and compare the converted values. There are 3 possible situations: + // case 1. The value of childExpr exceeds the range of the column type, then castTo() will throw an + // exception. For example, the value of childExpr is 128.0 and the column type is tinyint. + // case 2. childExpr is converted to column type, but the value of childExpr loses precision. + // For example, 2.1 is converted to 2; + // case 3. childExpr is precisely converted to column type. For example, 2.0 is converted to 2. + LiteralExpr newExpr = (LiteralExpr) expr1.castTo(columnType); + int compResult = expr1.compareLiteral(newExpr); + // case 2 + if (compResult != 0) { + if (op == Operator.EQ || op == Operator.EQ_FOR_NULL) { + return new BoolLiteral(false); + } else if (op == Operator.NE) { + return new BoolLiteral(true); + } + + if (compResult > 0) { + if (op == Operator.LT) { + op = Operator.LE; + } else if (op == Operator.GE) { + op = Operator.GT; + } + } else { + if (op == Operator.LE) { + op = Operator.LT; + } else if (op == Operator.GT) { + op = Operator.GE; + } + } + } + // case 3 + return new BinaryPredicate(op, expr0.castTo(columnType), newExpr); + } catch (AnalysisException e) { + // case 1 + IntLiteral colTypeMinValue = IntLiteral.createMinValue(columnType); + IntLiteral colTypeMaxValue = IntLiteral.createMaxValue(columnType); + if (op == Operator.NE || ((expr1).compareLiteral(colTypeMinValue) < 0 && (op == Operator.GE + || op == Operator.GT)) || ((expr1).compareLiteral(colTypeMaxValue) > 0 && (op == Operator.LE + || op == Operator.LT))) { return new BoolLiteral(true); - } else if (op == BinaryPredicate.Operator.LE) { - ((DecimalLiteral) expr1).roundCeiling(); - op = BinaryPredicate.Operator.LT; - } else if (op == BinaryPredicate.Operator.GE) { - ((DecimalLiteral) expr1).roundFloor(); - op = BinaryPredicate.Operator.GT; - } else if (op == BinaryPredicate.Operator.LT) { - ((DecimalLiteral) 
expr1).roundCeiling(); - } else if (op == BinaryPredicate.Operator.GT) { - ((DecimalLiteral) expr1).roundFloor(); } + return new BoolLiteral(false); } - expr0 = expr0.getChild(0); - expr1 = expr1.castTo(Type.BIGINT); - return new BinaryPredicate(op, expr0, expr1); } @Override @@ -95,7 +123,7 @@ public class RewriteBinaryPredicatesRule implements ExprRewriteRule { Expr expr1 = expr.getChild(1); if (expr0 instanceof CastExpr && expr0.getType() == Type.DECIMALV2 && expr0.getChild(0) instanceof SlotRef && expr0.getChild(0).getType().getResultType() == Type.BIGINT && expr1 instanceof DecimalLiteral) { - return rewriteBigintSlotRefCompareDecimalLiteral(expr0, expr1, op); + return rewriteBigintSlotRefCompareDecimalLiteral(expr0, (DecimalLiteral) expr1, op); } return expr; } diff --git a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java index 5577a566be..9083fea261 100755 --- a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java @@ -272,25 +272,27 @@ public class SelectStmtTest { + " );"; SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql, ctx); stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), ctx).getExprRewriter()); - String rewritedFragment1 = "((`t1`.`k2` = `t4`.`k2` AND `t3`.`k3` = `t1`.`k3` " - + "AND ((`t1`.`k4` >= 50 AND `t1`.`k4` <= 200) AND " - + "(`t3`.`k1` = 'D' OR `t3`.`k1` = 'S' OR `t3`.`k1` = 'W') " - + "AND (`t4`.`k3` = '2 yr Degree' OR `t4`.`k3` = 'Advanced Degree' OR `t4`.`k3` = 'Secondary') " - + "AND (`t4`.`k4` = 1 OR `t4`.`k4` = 3))) " - + "AND ((`t3`.`k1` = 'D' AND `t4`.`k3` = '2 yr Degree' " - + "AND `t1`.`k4` >= 100 AND `t1`.`k4` <= 150 AND `t4`.`k4` = 3) " - + "OR (`t3`.`k1` = 'S' AND `t4`.`k3` = 'Secondary' AND `t1`.`k4` >= 50 " - + "AND `t1`.`k4` <= 100 AND `t4`.`k4` = 1) OR (`t3`.`k1` = 'W' AND `t4`.`k3` = 'Advanced Degree' " - + "AND 
`t1`.`k4` >= 150 AND `t1`.`k4` <= 200 AND `t4`.`k4` = 1)))"; - String rewritedFragment2 = "((`t1`.`k1` = `t5`.`k1` AND `t5`.`k2` = 'United States' " - + "AND ((`t1`.`k4` >= 50 AND `t1`.`k4` <= 300) " - + "AND `t5`.`k3` IN ('CO', 'IL', 'MN', 'OH', 'MT', 'NM', 'TX', 'MO', 'MI'))) " - + "AND ((`t5`.`k3` IN ('CO', 'IL', 'MN') AND `t1`.`k4` >= 100 AND `t1`.`k4` <= 200) " - + "OR (`t5`.`k3` IN ('OH', 'MT', 'NM') AND `t1`.`k4` >= 150 AND `t1`.`k4` <= 300) OR (`t5`.`k3` IN " - + "('TX', 'MO', 'MI') AND `t1`.`k4` >= 50 AND `t1`.`k4` <= 250)))"; - System.out.println(stmt.toSql()); - Assert.assertTrue(stmt.toSql().contains(rewritedFragment1)); - Assert.assertTrue(stmt.toSql().contains(rewritedFragment2)); + String commonExpr1 = "`t1`.`k2` = `t4`.`k2`"; + String commonExpr2 = "`t3`.`k3` = `t1`.`k3`"; + String commonExpr3 = "`t1`.`k1` = `t5`.`k1`"; + String commonExpr4 = "t5`.`k2` = 'United States'"; + String betweenExpanded1 = "`t1`.`k4` >= 100 AND `t1`.`k4` <= 150"; + String betweenExpanded2 = "`t1`.`k4` >= 50 AND `t1`.`k4` <= 100"; + String betweenExpanded3 = "`t1`.`k4` >= 50 AND `t1`.`k4` <= 250"; + + String rewrittenSql = stmt.toSql(); + System.out.println(rewrittenSql); + Assert.assertTrue(rewrittenSql.contains(commonExpr1)); + Assert.assertEquals(rewrittenSql.indexOf(commonExpr1), rewrittenSql.lastIndexOf(commonExpr1)); + Assert.assertTrue(rewrittenSql.contains(commonExpr2)); + Assert.assertEquals(rewrittenSql.indexOf(commonExpr2), rewrittenSql.lastIndexOf(commonExpr2)); + Assert.assertTrue(rewrittenSql.contains(commonExpr3)); + Assert.assertEquals(rewrittenSql.indexOf(commonExpr3), rewrittenSql.lastIndexOf(commonExpr3)); + Assert.assertTrue(rewrittenSql.contains(commonExpr4)); + Assert.assertEquals(rewrittenSql.indexOf(commonExpr4), rewrittenSql.lastIndexOf(commonExpr4)); + Assert.assertTrue(rewrittenSql.contains(betweenExpanded1)); + Assert.assertTrue(rewrittenSql.contains(betweenExpanded2)); + Assert.assertTrue(rewrittenSql.contains(betweenExpanded3)); String sql2 = 
"select\n" + " avg(t1.k4)\n" diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java index d52314518a..5b50e1972f 100644 --- a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java +++ b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java @@ -422,11 +422,11 @@ public class PlannerTest extends TestWithFeService { compare.accept("select * from db1.tbl2 where k1 != 2.0", "select * from db1.tbl2 where k1 != 2"); compare.accept("select * from db1.tbl2 where k1 != 2.1", "select * from db1.tbl2"); compare.accept("select * from db1.tbl2 where k1 <= 2.0", "select * from db1.tbl2 where k1 <= 2"); - compare.accept("select * from db1.tbl2 where k1 <= 2.1", "select * from db1.tbl2 where k1 < 3"); + compare.accept("select * from db1.tbl2 where k1 <= 2.1", "select * from db1.tbl2 where k1 <= 2"); compare.accept("select * from db1.tbl2 where k1 >= 2.0", "select * from db1.tbl2 where k1 >= 2"); compare.accept("select * from db1.tbl2 where k1 >= 2.1", "select * from db1.tbl2 where k1 > 2"); compare.accept("select * from db1.tbl2 where k1 < 2.0", "select * from db1.tbl2 where k1 < 2"); - compare.accept("select * from db1.tbl2 where k1 < 2.1", "select * from db1.tbl2 where k1 < 3"); + compare.accept("select * from db1.tbl2 where k1 < 2.1", "select * from db1.tbl2 where k1 <= 2"); compare.accept("select * from db1.tbl2 where k1 > 2.0", "select * from db1.tbl2 where k1 > 2"); compare.accept("select * from db1.tbl2 where k1 > 2.1", "select * from db1.tbl2 where k1 > 2"); } diff --git a/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java new file mode 100644 index 0000000000..0b06502e11 --- /dev/null +++ b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java @@ -0,0 +1,118 @@ +// Licensed to the Apache Software 
Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +package org.apache.doris.rewrite; + +import org.apache.doris.analysis.BinaryPredicate; +import org.apache.doris.analysis.BinaryPredicate.Operator; +import org.apache.doris.analysis.BoolLiteral; +import org.apache.doris.analysis.Expr; +import org.apache.doris.analysis.LiteralExpr; +import org.apache.doris.analysis.SelectStmt; +import org.apache.doris.catalog.PrimitiveType; +import org.apache.doris.qe.StmtExecutor; +import org.apache.doris.utframe.TestWithFeService; + +import org.junit.jupiter.api.Assertions; +import org.junit.jupiter.api.Test; + +public class RewriteBinaryPredicatesRuleTest extends TestWithFeService { + @Override + protected void runBeforeAll() throws Exception { + connectContext = createDefaultCtx(); + createDatabase("db"); + useDatabase("db"); + String createTable = "create table table1(id smallint, cost bigint sum) " + + "aggregate key(`id`) distributed by hash (`id`) buckets 4 " + + "properties (\"replication_num\"=\"1\");"; + createTable(createTable); + } + + @Test + public void testNormal() throws Exception { + testBase(Operator.EQ, "2.0", Operator.EQ, 2L); + testBoolean(Operator.EQ, "2.5", false); + + testBase(Operator.NE, "2.0", Operator.NE, 2L); + testBoolean(Operator.NE, 
"2.5", true); + + testBase(Operator.LE, "2.0", Operator.LE, 2L); + testBase(Operator.LE, "-2.5", Operator.LT, -2L); + testBase(Operator.LE, "2.5", Operator.LE, 2L); + + testBase(Operator.GE, "2.0", Operator.GE, 2L); + testBase(Operator.GE, "-2.5", Operator.GE, -2L); + testBase(Operator.GE, "2.5", Operator.GT, 2L); + + testBase(Operator.LT, "2.0", Operator.LT, 2L); + testBase(Operator.LT, "-2.5", Operator.LT, -2L); + testBase(Operator.LT, "2.5", Operator.LE, 2L); + + testBase(Operator.GT, "2.0", Operator.GT, 2L); + testBase(Operator.GT, "-2.5", Operator.GE, -2L); + testBase(Operator.GT, "2.5", Operator.GT, 2L); + } + + @Test + public void testOutOfRange() throws Exception { + // 32767 -32768 + testBoolean(Operator.EQ, "-32769.0", false); + testBase(Operator.EQ, "32767.0", Operator.EQ, 32767L); + + testBoolean(Operator.NE, "32768.0", true); + + testBoolean(Operator.LE, "32768.2", true); + testBoolean(Operator.LE, "-32769.1", false); + testBase(Operator.LE, "32767.0", Operator.LE, 32767L); + + testBoolean(Operator.GE, "32768.1", false); + testBoolean(Operator.GE, "-32769.1", true); + testBase(Operator.GE, "32767.0", Operator.GE, 32767L); + + testBoolean(Operator.LT, "32768.1", true); + testBoolean(Operator.LT, "-32769.1", false); + testBase(Operator.LT, "32767.1", Operator.LE, 32767L); + + testBoolean(Operator.GT, "32768.1", false); + testBoolean(Operator.GT, "-32769.1", true); + testBase(Operator.GT, "32767.0", Operator.GT, 32767L); + } + + private void testBase(Operator operator, String queryLiteral, Operator expectedOperator, long expectedChild1) + throws Exception { + Expr expr1 = getExpr(operator, queryLiteral); + Assertions.assertTrue(expr1 instanceof BinaryPredicate); + Assertions.assertEquals(expectedOperator, ((BinaryPredicate) expr1).getOp()); + Assertions.assertEquals(PrimitiveType.SMALLINT, expr1.getChild(0).getType().getPrimitiveType()); + Assertions.assertEquals(PrimitiveType.SMALLINT, expr1.getChild(1).getType().getPrimitiveType()); + 
Assertions.assertEquals(expectedChild1, ((LiteralExpr) expr1.getChild(1)).getLongValue()); + } + + private void testBoolean(Operator operator, String queryLiteral, boolean result) throws Exception { + Expr expr1 = getExpr(operator, queryLiteral); + Assertions.assertTrue(expr1 instanceof BoolLiteral); + Assertions.assertEquals(result, ((BoolLiteral) expr1).getValue()); + } + + private Expr getExpr(Operator operator, String queryLiteral) throws Exception { + String queryFormat = "select * from table1 where id %s %s;"; + String query = String.format(queryFormat, operator.toString(), queryLiteral); + StmtExecutor executor1 = getSqlStmtExecutor(query); + Assertions.assertNotNull(executor1); + return ((SelectStmt) executor1.getParsedStmt()).getWhereClause(); + } +} --------------------------------------------------------------------- To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org For additional commands, e-mail: commits-h...@doris.apache.org