This is an automated email from the ASF dual-hosted git repository.

lingmiao pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/doris.git


The following commit(s) were added to refs/heads/master by this push:
     new f758e1166a [fix] Fix RewriteBinaryPredicatesRule which causes wrong 
query results in some cases. (#10551)
f758e1166a is described below

commit f758e1166a2844d507806c2ddc5376cea036d31b
Author: luozenglin <37725793+luozeng...@users.noreply.github.com>
AuthorDate: Wed Jul 6 15:39:27 2022 +0800

    [fix] Fix RewriteBinaryPredicatesRule which causes wrong query results in 
some cases. (#10551)
    
    During the query planning phase, the binary predicate rewrite optimization 
converts a DecimalLiteral to an integer. This conversion may overflow, producing 
wrong results for predicates like "id = 12345678901.0" (see the issue for 
detailed examples).
    
    This PR fixes the possible overflow and optimizes the case where the 
DecimalLiteral is outside the value range of the column type.
    
    Issue Number: close #10544
---
 .../org/apache/doris/analysis/DecimalLiteral.java  |   6 ++
 .../java/org/apache/doris/analysis/IntLiteral.java |  22 ++++
 .../doris/rewrite/RewriteBinaryPredicatesRule.java |  72 +++++++++----
 .../org/apache/doris/analysis/SelectStmtTest.java  |  40 +++----
 .../java/org/apache/doris/planner/PlannerTest.java |   4 +-
 .../rewrite/RewriteBinaryPredicatesRuleTest.java   | 118 +++++++++++++++++++++
 6 files changed, 219 insertions(+), 43 deletions(-)

diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
index e938a46361..3e5bf9abc7 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/DecimalLiteral.java
@@ -246,6 +246,12 @@ public class DecimalLiteral extends LiteralExpr {
         } else if (targetType.isFloatingPointType()) {
             return new FloatLiteral(value.doubleValue(), targetType);
         } else if (targetType.isIntegerType()) {
+            // If the integer part of BigDecimal is too big to fit into long,
+            // longValue() will only return the low-order 64-bit value.
+            if (value.compareTo(BigDecimal.valueOf(Long.MAX_VALUE)) > 0
+                    || value.compareTo(BigDecimal.valueOf(Long.MIN_VALUE)) < 
0) {
+                throw new AnalysisException("Integer part of " + value + " 
exceeds storage range of Long Type.");
+            }
             return new IntLiteral(value.longValue(), targetType);
         } else if (targetType.isStringType()) {
             return new StringLiteral(value.toString());
diff --git a/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java 
b/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java
index 00662c5e6a..4d4f673822 100644
--- a/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java
+++ b/fe/fe-core/src/main/java/org/apache/doris/analysis/IntLiteral.java
@@ -172,6 +172,28 @@ public class IntLiteral extends LiteralExpr {
         return new IntLiteral(value);
     }
 
+    public static IntLiteral createMaxValue(Type type) {
+        long value = 0L;
+        switch (type.getPrimitiveType()) {
+            case TINYINT:
+                value = TINY_INT_MAX;
+                break;
+            case SMALLINT:
+                value = SMALL_INT_MAX;
+                break;
+            case INT:
+                value = INT_MAX;
+                break;
+            case BIGINT:
+                value = BIG_INT_MAX;
+                break;
+            default:
+                Preconditions.checkState(false);
+        }
+
+        return new IntLiteral(value);
+    }
+
     @Override
     public boolean isMinValue() {
         switch (type.getPrimitiveType()) {
diff --git 
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java
 
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java
index 4ed232b4b1..a18797b657 100644
--- 
a/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java
+++ 
b/fe/fe-core/src/main/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRule.java
@@ -19,10 +19,13 @@ package org.apache.doris.rewrite;
 
 import org.apache.doris.analysis.Analyzer;
 import org.apache.doris.analysis.BinaryPredicate;
+import org.apache.doris.analysis.BinaryPredicate.Operator;
 import org.apache.doris.analysis.BoolLiteral;
 import org.apache.doris.analysis.CastExpr;
 import org.apache.doris.analysis.DecimalLiteral;
 import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.IntLiteral;
+import org.apache.doris.analysis.LiteralExpr;
 import org.apache.doris.analysis.SlotRef;
 import org.apache.doris.catalog.Type;
 import org.apache.doris.common.AnalysisException;
@@ -53,36 +56,61 @@ public class RewriteBinaryPredicatesRule implements 
ExprRewriteRule {
      * 3) "select * from T where t1 != 2.0" is converted to "select * from T 
where t1 != 2"
      * 4) "select * from T where t1 != 2.1" is converted to "select * from T"
      * 5) "select * from T where t1 <= 2.0" is converted to "select * from T 
where t1 <= 2"
-     * 6) "select * from T where t1 <= 2.1" is converted to "select * from T 
where t1 <3"
+     * 6) "select * from T where t1 <= 2.1" is converted to "select * from T 
where t1 <=2"
      * 7) "select * from T where t1 >= 2.0" is converted to "select * from T 
where t1 >= 2"
      * 8) "select * from T where t1 >= 2.1" is converted to "select * from T 
where t1> 2"
      * 9) "select * from T where t1 <2.0" is converted to "select * from T 
where t1 <2"
-     * 10) "select * from T where t1 <2.1" is converted to "select * from T 
where t1 <3"
+     * 10) "select * from T where t1 <2.1" is converted to "select * from T 
where t1 <=2"
      * 11) "select * from T where t1> 2.0" is converted to "select * from T 
where t1> 2"
      * 12) "select * from T where t1> 2.1" is converted to "select * from T 
where t1> 2"
      */
-    private Expr rewriteBigintSlotRefCompareDecimalLiteral(Expr expr0, Expr 
expr1, BinaryPredicate.Operator op)
-            throws AnalysisException {
-        if (((DecimalLiteral) expr1).getDoubleValue() % (int) 
(((DecimalLiteral) expr1).getDoubleValue()) != 0) {
-            if (op == BinaryPredicate.Operator.EQ || op == 
BinaryPredicate.Operator.EQ_FOR_NULL) {
-                return new BoolLiteral(false);
-            } else if (op == BinaryPredicate.Operator.NE) {
+    private Expr rewriteBigintSlotRefCompareDecimalLiteral(Expr expr0, 
DecimalLiteral expr1,
+            BinaryPredicate.Operator op) {
+        Type columnType = expr0.getSrcSlotRef().getColumn().getType();
+        try {
+            // Convert childExpr to column type and compare the converted 
values. There are 3 possible situations:
+            // case 1. The value of childExpr exceeds the range of the column 
type, then castTo() will throw an
+            //   exception. For example, the value of childExpr is 128.0 and 
the column type is tinyint.
+            // case 2. childExpr is converted to column type, but the value of 
childExpr loses precision.
+            //   For example, 2.1 is converted to 2;
+            // case 3. childExpr is precisely converted to column type. For 
example, 2.0 is converted to 2.
+            LiteralExpr newExpr = (LiteralExpr) expr1.castTo(columnType);
+            int compResult = expr1.compareLiteral(newExpr);
+            // case 2
+            if (compResult != 0) {
+                if (op == Operator.EQ || op == Operator.EQ_FOR_NULL) {
+                    return new BoolLiteral(false);
+                } else if (op == Operator.NE) {
+                    return new BoolLiteral(true);
+                }
+
+                if (compResult > 0) {
+                    if (op == Operator.LT) {
+                        op = Operator.LE;
+                    } else if (op == Operator.GE) {
+                        op = Operator.GT;
+                    }
+                } else {
+                    if (op == Operator.LE) {
+                        op = Operator.LT;
+                    } else if (op == Operator.GT) {
+                        op = Operator.GE;
+                    }
+                }
+            }
+            // case 3
+            return new BinaryPredicate(op, expr0.castTo(columnType), newExpr);
+        } catch (AnalysisException e) {
+            // case 1
+            IntLiteral colTypeMinValue = IntLiteral.createMinValue(columnType);
+            IntLiteral colTypeMaxValue = IntLiteral.createMaxValue(columnType);
+            if (op == Operator.NE || ((expr1).compareLiteral(colTypeMinValue) 
< 0 && (op == Operator.GE
+                    || op == Operator.GT)) || 
((expr1).compareLiteral(colTypeMaxValue) > 0 && (op == Operator.LE
+                    || op == Operator.LT))) {
                 return new BoolLiteral(true);
-            } else if (op == BinaryPredicate.Operator.LE) {
-                ((DecimalLiteral) expr1).roundCeiling();
-                op = BinaryPredicate.Operator.LT;
-            } else if (op == BinaryPredicate.Operator.GE) {
-                ((DecimalLiteral) expr1).roundFloor();
-                op = BinaryPredicate.Operator.GT;
-            } else if (op == BinaryPredicate.Operator.LT) {
-                ((DecimalLiteral) expr1).roundCeiling();
-            } else if (op == BinaryPredicate.Operator.GT) {
-                ((DecimalLiteral) expr1).roundFloor();
             }
+            return new BoolLiteral(false);
         }
-        expr0 = expr0.getChild(0);
-        expr1 = expr1.castTo(Type.BIGINT);
-        return new BinaryPredicate(op, expr0, expr1);
     }
 
     @Override
@@ -95,7 +123,7 @@ public class RewriteBinaryPredicatesRule implements 
ExprRewriteRule {
         Expr expr1 = expr.getChild(1);
         if (expr0 instanceof CastExpr && expr0.getType() == Type.DECIMALV2 && 
expr0.getChild(0) instanceof SlotRef
                 && expr0.getChild(0).getType().getResultType() == Type.BIGINT 
&& expr1 instanceof DecimalLiteral) {
-            return rewriteBigintSlotRefCompareDecimalLiteral(expr0, expr1, op);
+            return rewriteBigintSlotRefCompareDecimalLiteral(expr0, 
(DecimalLiteral) expr1, op);
         }
         return expr;
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
index 5577a566be..9083fea261 100755
--- a/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/analysis/SelectStmtTest.java
@@ -272,25 +272,27 @@ public class SelectStmtTest {
                 + "   );";
         SelectStmt stmt = (SelectStmt) UtFrameUtils.parseAndAnalyzeStmt(sql, 
ctx);
         stmt.rewriteExprs(new Analyzer(ctx.getCatalog(), 
ctx).getExprRewriter());
-        String rewritedFragment1 = "((`t1`.`k2` = `t4`.`k2` AND `t3`.`k3` = 
`t1`.`k3` "
-                + "AND ((`t1`.`k4` >= 50 AND `t1`.`k4` <= 200) AND "
-                + "(`t3`.`k1` = 'D' OR `t3`.`k1` = 'S' OR `t3`.`k1` = 'W') "
-                + "AND (`t4`.`k3` = '2 yr Degree' OR `t4`.`k3` = 'Advanced 
Degree' OR `t4`.`k3` = 'Secondary') "
-                + "AND (`t4`.`k4` = 1 OR `t4`.`k4` = 3))) "
-                + "AND ((`t3`.`k1` = 'D' AND `t4`.`k3` = '2 yr Degree' "
-                + "AND `t1`.`k4` >= 100 AND `t1`.`k4` <= 150 AND `t4`.`k4` = 
3) "
-                + "OR (`t3`.`k1` = 'S' AND `t4`.`k3` = 'Secondary' AND 
`t1`.`k4` >= 50 "
-                + "AND `t1`.`k4` <= 100 AND `t4`.`k4` = 1) OR (`t3`.`k1` = 'W' 
AND `t4`.`k3` = 'Advanced Degree' "
-                + "AND `t1`.`k4` >= 150 AND `t1`.`k4` <= 200 AND `t4`.`k4` = 
1)))";
-        String rewritedFragment2 = "((`t1`.`k1` = `t5`.`k1` AND `t5`.`k2` = 
'United States' "
-                + "AND ((`t1`.`k4` >= 50 AND `t1`.`k4` <= 300) "
-                + "AND `t5`.`k3` IN ('CO', 'IL', 'MN', 'OH', 'MT', 'NM', 'TX', 
'MO', 'MI'))) "
-                + "AND ((`t5`.`k3` IN ('CO', 'IL', 'MN') AND `t1`.`k4` >= 100 
AND `t1`.`k4` <= 200) "
-                + "OR (`t5`.`k3` IN ('OH', 'MT', 'NM') AND `t1`.`k4` >= 150 
AND `t1`.`k4` <= 300) OR (`t5`.`k3` IN "
-                + "('TX', 'MO', 'MI') AND `t1`.`k4` >= 50 AND `t1`.`k4` <= 
250)))";
-        System.out.println(stmt.toSql());
-        Assert.assertTrue(stmt.toSql().contains(rewritedFragment1));
-        Assert.assertTrue(stmt.toSql().contains(rewritedFragment2));
+        String commonExpr1 = "`t1`.`k2` = `t4`.`k2`";
+        String commonExpr2 = "`t3`.`k3` = `t1`.`k3`";
+        String commonExpr3 = "`t1`.`k1` = `t5`.`k1`";
+        String commonExpr4 = "t5`.`k2` = 'United States'";
+        String betweenExpanded1 = "`t1`.`k4` >= 100 AND `t1`.`k4` <= 150";
+        String betweenExpanded2 = "`t1`.`k4` >= 50 AND `t1`.`k4` <= 100";
+        String betweenExpanded3 = "`t1`.`k4` >= 50 AND `t1`.`k4` <= 250";
+
+        String rewrittenSql = stmt.toSql();
+        System.out.println(rewrittenSql);
+        Assert.assertTrue(rewrittenSql.contains(commonExpr1));
+        Assert.assertEquals(rewrittenSql.indexOf(commonExpr1), 
rewrittenSql.lastIndexOf(commonExpr1));
+        Assert.assertTrue(rewrittenSql.contains(commonExpr2));
+        Assert.assertEquals(rewrittenSql.indexOf(commonExpr2), 
rewrittenSql.lastIndexOf(commonExpr2));
+        Assert.assertTrue(rewrittenSql.contains(commonExpr3));
+        Assert.assertEquals(rewrittenSql.indexOf(commonExpr3), 
rewrittenSql.lastIndexOf(commonExpr3));
+        Assert.assertTrue(rewrittenSql.contains(commonExpr4));
+        Assert.assertEquals(rewrittenSql.indexOf(commonExpr4), 
rewrittenSql.lastIndexOf(commonExpr4));
+        Assert.assertTrue(rewrittenSql.contains(betweenExpanded1));
+        Assert.assertTrue(rewrittenSql.contains(betweenExpanded2));
+        Assert.assertTrue(rewrittenSql.contains(betweenExpanded3));
 
         String sql2 = "select\n"
                 + "   avg(t1.k4)\n"
diff --git a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java 
b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java
index d52314518a..5b50e1972f 100644
--- a/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java
+++ b/fe/fe-core/src/test/java/org/apache/doris/planner/PlannerTest.java
@@ -422,11 +422,11 @@ public class PlannerTest extends TestWithFeService {
         compare.accept("select * from db1.tbl2 where k1 != 2.0", "select * 
from db1.tbl2 where k1 != 2");
         compare.accept("select * from db1.tbl2 where k1 != 2.1", "select * 
from db1.tbl2");
         compare.accept("select * from db1.tbl2 where k1 <= 2.0", "select * 
from db1.tbl2 where k1 <= 2");
-        compare.accept("select * from db1.tbl2 where k1 <= 2.1", "select * 
from db1.tbl2 where k1 < 3");
+        compare.accept("select * from db1.tbl2 where k1 <= 2.1", "select * 
from db1.tbl2 where k1 <= 2");
         compare.accept("select * from db1.tbl2 where k1 >= 2.0", "select * 
from db1.tbl2 where k1 >= 2");
         compare.accept("select * from db1.tbl2 where k1 >= 2.1", "select * 
from db1.tbl2 where k1 > 2");
         compare.accept("select * from db1.tbl2 where k1 < 2.0", "select * from 
db1.tbl2 where k1 < 2");
-        compare.accept("select * from db1.tbl2 where k1 < 2.1", "select * from 
db1.tbl2 where k1 < 3");
+        compare.accept("select * from db1.tbl2 where k1 < 2.1", "select * from 
db1.tbl2 where k1 <= 2");
         compare.accept("select * from db1.tbl2 where k1 > 2.0", "select * from 
db1.tbl2 where k1 > 2");
         compare.accept("select * from db1.tbl2 where k1 > 2.1", "select * from 
db1.tbl2 where k1 > 2");
     }
diff --git 
a/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java
 
b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java
new file mode 100644
index 0000000000..0b06502e11
--- /dev/null
+++ 
b/fe/fe-core/src/test/java/org/apache/doris/rewrite/RewriteBinaryPredicatesRuleTest.java
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+package org.apache.doris.rewrite;
+
+import org.apache.doris.analysis.BinaryPredicate;
+import org.apache.doris.analysis.BinaryPredicate.Operator;
+import org.apache.doris.analysis.BoolLiteral;
+import org.apache.doris.analysis.Expr;
+import org.apache.doris.analysis.LiteralExpr;
+import org.apache.doris.analysis.SelectStmt;
+import org.apache.doris.catalog.PrimitiveType;
+import org.apache.doris.qe.StmtExecutor;
+import org.apache.doris.utframe.TestWithFeService;
+
+import org.junit.jupiter.api.Assertions;
+import org.junit.jupiter.api.Test;
+
+public class RewriteBinaryPredicatesRuleTest extends TestWithFeService {
+    @Override
+    protected void runBeforeAll() throws Exception {
+        connectContext = createDefaultCtx();
+        createDatabase("db");
+        useDatabase("db");
+        String createTable = "create table table1(id smallint, cost bigint 
sum) "
+                + "aggregate key(`id`) distributed by hash (`id`) buckets 4 "
+                + "properties (\"replication_num\"=\"1\");";
+        createTable(createTable);
+    }
+
+    @Test
+    public void testNormal() throws Exception {
+        testBase(Operator.EQ, "2.0", Operator.EQ, 2L);
+        testBoolean(Operator.EQ, "2.5", false);
+
+        testBase(Operator.NE, "2.0", Operator.NE, 2L);
+        testBoolean(Operator.NE, "2.5", true);
+
+        testBase(Operator.LE, "2.0", Operator.LE, 2L);
+        testBase(Operator.LE, "-2.5", Operator.LT, -2L);
+        testBase(Operator.LE, "2.5", Operator.LE, 2L);
+
+        testBase(Operator.GE, "2.0", Operator.GE, 2L);
+        testBase(Operator.GE, "-2.5", Operator.GE, -2L);
+        testBase(Operator.GE, "2.5", Operator.GT, 2L);
+
+        testBase(Operator.LT, "2.0", Operator.LT, 2L);
+        testBase(Operator.LT, "-2.5", Operator.LT, -2L);
+        testBase(Operator.LT, "2.5", Operator.LE, 2L);
+
+        testBase(Operator.GT, "2.0", Operator.GT, 2L);
+        testBase(Operator.GT, "-2.5", Operator.GE, -2L);
+        testBase(Operator.GT, "2.5", Operator.GT, 2L);
+    }
+
+    @Test
+    public void testOutOfRange() throws Exception {
+        // 32767 -32768
+        testBoolean(Operator.EQ, "-32769.0", false);
+        testBase(Operator.EQ, "32767.0", Operator.EQ, 32767L);
+
+        testBoolean(Operator.NE, "32768.0", true);
+
+        testBoolean(Operator.LE, "32768.2", true);
+        testBoolean(Operator.LE, "-32769.1", false);
+        testBase(Operator.LE, "32767.0", Operator.LE, 32767L);
+
+        testBoolean(Operator.GE, "32768.1", false);
+        testBoolean(Operator.GE, "-32769.1", true);
+        testBase(Operator.GE, "32767.0", Operator.GE, 32767L);
+
+        testBoolean(Operator.LT, "32768.1", true);
+        testBoolean(Operator.LT, "-32769.1", false);
+        testBase(Operator.LT, "32767.1", Operator.LE, 32767L);
+
+        testBoolean(Operator.GT, "32768.1", false);
+        testBoolean(Operator.GT, "-32769.1", true);
+        testBase(Operator.GT, "32767.0", Operator.GT, 32767L);
+    }
+
+    private void testBase(Operator operator, String queryLiteral, Operator 
expectedOperator, long expectedChild1)
+            throws Exception {
+        Expr expr1 = getExpr(operator, queryLiteral);
+        Assertions.assertTrue(expr1 instanceof BinaryPredicate);
+        Assertions.assertEquals(expectedOperator, ((BinaryPredicate) 
expr1).getOp());
+        Assertions.assertEquals(PrimitiveType.SMALLINT, 
expr1.getChild(0).getType().getPrimitiveType());
+        Assertions.assertEquals(PrimitiveType.SMALLINT, 
expr1.getChild(1).getType().getPrimitiveType());
+        Assertions.assertEquals(expectedChild1, ((LiteralExpr) 
expr1.getChild(1)).getLongValue());
+    }
+
+    private void testBoolean(Operator operator, String queryLiteral, boolean 
result) throws Exception {
+        Expr expr1 = getExpr(operator, queryLiteral);
+        Assertions.assertTrue(expr1 instanceof BoolLiteral);
+        Assertions.assertEquals(result, ((BoolLiteral) expr1).getValue());
+    }
+
+    private Expr getExpr(Operator operator, String queryLiteral) throws 
Exception {
+        String queryFormat = "select * from table1 where id %s %s;";
+        String query = String.format(queryFormat, operator.toString(), 
queryLiteral);
+        StmtExecutor executor1 = getSqlStmtExecutor(query);
+        Assertions.assertNotNull(executor1);
+        return ((SelectStmt) executor1.getParsedStmt()).getWhereClause();
+    }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: commits-unsubscr...@doris.apache.org
For additional commands, e-mail: commits-h...@doris.apache.org

Reply via email to