This is an automated email from the ASF dual-hosted git repository.

gortiz pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/pinot.git


The following commit(s) were added to refs/heads/master by this push:
     new 2188ca2b81d Fix CAST evaluation with literal-only operands in MSQE (#16421)
2188ca2b81d is described below

commit 2188ca2b81dc4e83dbad6cbeef528dbe031f349b
Author: Yash Mayya <[email protected]>
AuthorDate: Fri Jul 25 14:38:35 2025 +0530

    Fix CAST evaluation with literal-only operands in MSQE (#16421)
---
 .../scalar/DataTypeConversionFunctions.java        | 66 +++++++++++-----------
 .../tests/ErrorCodesIntegrationTest.java           |  2 +-
 .../rel/rules/PinotEvaluateLiteralRule.java        | 13 ++++-
 .../resources/queries/LiteralEvaluationPlans.json  | 32 ++++++++++-
 .../resources/queries/PhysicalOptimizerPlans.json  | 26 ++++-----
 5 files changed, 88 insertions(+), 51 deletions(-)

diff --git a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
index 347031ec864..c18ffed8129 100644
--- a/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
+++ b/pinot-common/src/main/java/org/apache/pinot/common/function/scalar/DataTypeConversionFunctions.java
@@ -38,43 +38,43 @@ public class DataTypeConversionFunctions {
 
   @ScalarFunction
   public static Object cast(Object value, String targetTypeLiteral) {
-    try {
-      Class<?> clazz = value.getClass();
-      // TODO: Support cast for MV
-      Preconditions.checkArgument(!clazz.isArray() | clazz == byte[].class, 
"%s must not be an array type", clazz);
-      PinotDataType sourceType = PinotDataType.getSingleValueType(clazz);
-      String transformed = targetTypeLiteral.toUpperCase();
-      PinotDataType targetDataType;
-      switch (transformed) {
-        case "BIGINT":
-          targetDataType = LONG;
-          break;
-        case "DECIMAL":
-          targetDataType = BIG_DECIMAL;
-          break;
-        case "INT":
-          targetDataType = INTEGER;
-          break;
-        case "VARBINARY":
-          targetDataType = BYTES;
-          break;
-        case "VARCHAR":
-          targetDataType = STRING;
-          break;
-        default:
+    Class<?> clazz = value.getClass();
+    // TODO: Support cast for MV
+    Preconditions.checkArgument(!clazz.isArray() | clazz == byte[].class, "%s 
must not be an array type", clazz);
+    PinotDataType sourceType = PinotDataType.getSingleValueType(clazz);
+    String transformed = targetTypeLiteral.toUpperCase();
+    PinotDataType targetDataType;
+    switch (transformed) {
+      case "BIGINT":
+        targetDataType = LONG;
+        break;
+      case "DECIMAL":
+        targetDataType = BIG_DECIMAL;
+        break;
+      case "INT":
+        targetDataType = INTEGER;
+        break;
+      case "VARBINARY":
+        targetDataType = BYTES;
+        break;
+      case "VARCHAR":
+        targetDataType = STRING;
+        break;
+      default:
+        try {
           targetDataType = PinotDataType.valueOf(transformed);
-          break;
-      }
-      if (sourceType == STRING && (targetDataType == INTEGER || targetDataType 
== LONG)) {
-        if (String.valueOf(value).contains(".")) {
-          // convert integers via double to avoid parse errors
-          return targetDataType.convert(DOUBLE.convert(value, sourceType), 
DOUBLE);
+        } catch (IllegalArgumentException e) {
+          throw new IllegalArgumentException("Unknown data type: " + 
targetTypeLiteral);
         }
+        break;
+    }
+    if (sourceType == STRING && (targetDataType == INTEGER || targetDataType 
== LONG)) {
+      if (String.valueOf(value).contains(".")) {
+        // convert integers via double to avoid parse errors
+        return targetDataType.convert(DOUBLE.convert(value, sourceType), 
DOUBLE);
       }
-      return targetDataType.convert(value, sourceType);
-    } catch (IllegalArgumentException e) {
-      throw new IllegalArgumentException("Unknown data type: " + 
targetTypeLiteral);
     }
+    return targetDataType.convert(value, sourceType);
   }
 
   /**
diff --git a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/ErrorCodesIntegrationTest.java b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/ErrorCodesIntegrationTest.java
index 383a71fd1d4..069e3fa382d 100644
--- a/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/ErrorCodesIntegrationTest.java
+++ b/pinot-integration-tests/src/test/java/org/apache/pinot/integration/tests/ErrorCodesIntegrationTest.java
@@ -138,7 +138,7 @@ public abstract class ErrorCodesIntegrationTest extends 
BaseClusterIntegrationTe
       throws Exception {
     // ArrTime expects a numeric type
     testQueryException("SELECT COUNT(*) FROM mytable where ArrTime = 'potato'",
-        useMultiStageQueryEngine() ? QueryErrorCode.QUERY_EXECUTION : 
QueryErrorCode.QUERY_VALIDATION);
+        useMultiStageQueryEngine() ? QueryErrorCode.QUERY_PLANNING : 
QueryErrorCode.QUERY_VALIDATION);
   }
 
   @Test
diff --git a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java
index ca634a04c2e..fea69ab5d21 100644
--- a/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java
+++ b/pinot-query-planner/src/main/java/org/apache/pinot/calcite/rel/rules/PinotEvaluateLiteralRule.java
@@ -161,15 +161,14 @@ public class PinotEvaluateLiteralRule {
     assert operands.stream().allMatch(
         operand -> operand instanceof RexLiteral || (operand instanceof 
RexCall && ((RexCall) operand).getOperands()
             .stream().allMatch(op -> op instanceof RexLiteral)));
+
     int numArguments = operands.size();
     ColumnDataType[] argumentTypes = new ColumnDataType[numArguments];
     Object[] arguments = new Object[numArguments];
     for (int i = 0; i < numArguments; i++) {
       RexNode rexNode = operands.get(i);
       RexLiteral rexLiteral;
-      if (rexNode instanceof RexCall && ((RexCall) 
rexNode).getOperator().getKind() == SqlKind.CAST) {
-        rexLiteral = (RexLiteral) ((RexCall) rexNode).getOperands().get(0);
-      } else if (rexNode instanceof RexLiteral) {
+      if (rexNode instanceof RexLiteral) {
         rexLiteral = (RexLiteral) rexNode;
       } else {
         // Function operands cannot be evaluated, skip
@@ -178,6 +177,14 @@ public class PinotEvaluateLiteralRule {
       argumentTypes[i] = 
RelToPlanNodeConverter.convertToColumnDataType(rexLiteral.getType());
       arguments[i] = getLiteralValue(rexLiteral);
     }
+
+    if (rexCall.getKind() == SqlKind.CAST) {
+      // Handle separately because the CAST operator only has one operand (the 
value to be cast) and the type to be cast
+      // to is determined by the operator's return type. Pinot's CAST function 
implementation requires two arguments:
+      // the value to be cast and the target type.
+      argumentTypes = new ColumnDataType[]{argumentTypes[0], 
ColumnDataType.STRING};
+      arguments = new Object[]{arguments[0], 
RelToPlanNodeConverter.convertToColumnDataType(rexCall.getType()).name()};
+    }
     String canonicalName = 
FunctionRegistry.canonicalize(PinotRuleUtils.extractFunctionName(rexCall));
     FunctionInfo functionInfo = 
FunctionRegistry.lookupFunctionInfo(canonicalName, argumentTypes);
     if (functionInfo == null) {
diff --git a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
index f23e0662e07..e63b4aab8a6 100644
--- a/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/LiteralEvaluationPlans.json
@@ -45,7 +45,7 @@
         "sql": "EXPLAIN PLAN FOR SELECT timestampDiff(DAY, CAST(ts as 
TIMESTAMP), CAST(dateTrunc('MONTH', FROMDATETIME('1997-02-01 00:00:00', 
'yyyy-MM-dd HH:mm:ss')) as TIMESTAMP)) FROM d",
         "output": [
           "Execution Plan",
-          "\nLogicalProject(EXPR$0=[TIMESTAMPDIFF(FLAG(DAY), 
CAST($7):TIMESTAMP(0) NOT NULL, CAST(854755200000:BIGINT):TIMESTAMP(0) NOT 
NULL)])",
+          "\nLogicalProject(EXPR$0=[TIMESTAMPDIFF(FLAG(DAY), 
CAST($7):TIMESTAMP(0) NOT NULL, 1997-02-01 00:00:00)])",
           "\n  PinotLogicalTableScan(table=[[default, d]])",
           "\n"
         ]
@@ -253,6 +253,36 @@
           "\n"
         ]
       },
+      {
+        "description": "filter with implicit cast",
+        "sql": "EXPLAIN PLAN FOR SELECT * FROM a WHERE col3 > '10.5'",
+        "output": [
+          "Execution Plan",
+          "\nLogicalFilter(condition=[>($2, 10)])",
+          "\n  PinotLogicalTableScan(table=[[default, a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "filter with explicit cast",
+        "sql": "EXPLAIN PLAN FOR SELECT * FROM a WHERE col3 > CAST('10.5' AS 
INT)",
+        "output": [
+          "Execution Plan",
+          "\nLogicalFilter(condition=[>($2, 10)])",
+          "\n  PinotLogicalTableScan(table=[[default, a]])",
+          "\n"
+        ]
+      },
+      {
+        "description": "filter with nested explicit casts",
+        "sql": "EXPLAIN PLAN FOR SELECT * FROM a WHERE col3 > CAST(CAST('10.5' 
AS LONG) AS INT)",
+        "output": [
+          "Execution Plan",
+          "\nLogicalFilter(condition=[>($2, 10)])",
+          "\n  PinotLogicalTableScan(table=[[default, a]])",
+          "\n"
+        ]
+      },
       {
         "description": "select non-exist literal function",
         "sql": "EXPLAIN PLAN FOR Select nonExistFun(1,2) FROM a",
diff --git a/pinot-query-planner/src/test/resources/queries/PhysicalOptimizerPlans.json b/pinot-query-planner/src/test/resources/queries/PhysicalOptimizerPlans.json
index ca7dc973074..106d0e7f9a9 100644
--- a/pinot-query-planner/src/test/resources/queries/PhysicalOptimizerPlans.json
+++ b/pinot-query-planner/src/test/resources/queries/PhysicalOptimizerPlans.json
@@ -466,7 +466,7 @@
     "queries": [
       {
         "description": "Self semi-joins",
-        "sql": "SET usePhysicalOptimizer=true; EXPLAIN PLAN FOR SELECT col1, 
col2 FROM a WHERE col2 IN (SELECT col2 FROM a WHERE col3 = 'foo') AND col2 IN 
(SELECT col2 FROM a WHERE col3 = 'bar')",
+        "sql": "SET usePhysicalOptimizer=true; EXPLAIN PLAN FOR SELECT col1, 
col2 FROM a WHERE col2 IN (SELECT col2 FROM a WHERE col3 = '1') AND col2 IN 
(SELECT col2 FROM a WHERE col3 = '2')",
         "output": [
           "Execution Plan",
           "\nPhysicalExchange(exchangeStrategy=[SINGLETON_EXCHANGE])",
@@ -477,18 +477,18 @@
           "\n          PhysicalTableScan(table=[[default, a]])",
           "\n      PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n        PhysicalProject(col2=[$1])",
-          "\n          PhysicalFilter(condition=[=($2, 
CAST(_UTF-8'foo'):INTEGER NOT NULL)])",
+          "\n          PhysicalFilter(condition=[=($2, 1)])",
           "\n            PhysicalTableScan(table=[[default, a]])",
           "\n    PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n      PhysicalProject(col2=[$1])",
-          "\n        PhysicalFilter(condition=[=($2, CAST(_UTF-8'bar'):INTEGER 
NOT NULL)])",
+          "\n        PhysicalFilter(condition=[=($2, 2)])",
           "\n          PhysicalTableScan(table=[[default, a]])",
           "\n"
         ]
       },
       {
         "description": "Self semi and anti semi-joins",
-        "sql": "SET usePhysicalOptimizer=true; EXPLAIN PLAN FOR SELECT col1, 
col2 FROM a WHERE col2 IN (SELECT col2 FROM a WHERE col3 = 'foo') AND col2 NOT 
IN (SELECT col2 FROM a WHERE col3 = 'bar') AND col2 IN (SELECT col2 FROM a 
WHERE col3 = 'lorem')",
+        "sql": "SET usePhysicalOptimizer=true; EXPLAIN PLAN FOR SELECT col1, 
col2 FROM a WHERE col2 IN (SELECT col2 FROM a WHERE col3 = '1') AND col2 NOT IN 
(SELECT col2 FROM a WHERE col3 = '2') AND col2 IN (SELECT col2 FROM a WHERE 
col3 = '3')",
         "output": [
           "Execution Plan",
           "\nPhysicalExchange(exchangeStrategy=[SINGLETON_EXCHANGE])",
@@ -503,23 +503,23 @@
           "\n                  PhysicalTableScan(table=[[default, a]])",
           "\n              
PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n                PhysicalProject(col2=[$1])",
-          "\n                  PhysicalFilter(condition=[=($2, 
CAST(_UTF-8'foo'):INTEGER NOT NULL)])",
+          "\n                  PhysicalFilter(condition=[=($2, 1)])",
           "\n                    PhysicalTableScan(table=[[default, a]])",
           "\n          PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n            PhysicalAggregate(group=[{0}], agg#0=[MIN($1)], 
aggType=[DIRECT])",
           "\n              PhysicalProject(col2=[$1], $f1=[true])",
-          "\n                PhysicalFilter(condition=[=($2, 
CAST(_UTF-8'bar'):INTEGER NOT NULL)])",
+          "\n                PhysicalFilter(condition=[=($2, 2)])",
           "\n                  PhysicalTableScan(table=[[default, a]])",
           "\n    PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n      PhysicalProject(col2=[$1])",
-          "\n        PhysicalFilter(condition=[=($2, 
CAST(_UTF-8'lorem'):INTEGER NOT NULL)])",
+          "\n        PhysicalFilter(condition=[=($2, 3)])",
           "\n          PhysicalTableScan(table=[[default, a]])",
           "\n"
         ]
       },
       {
         "description": "Self semi and anti semi-joins with aggregation in the 
end",
-        "sql": "SET usePhysicalOptimizer=true; EXPLAIN PLAN FOR SELECT col1, 
COUNT(*) FROM a WHERE col2 IN (SELECT col2 FROM a WHERE col3 = 'foo') AND col2 
NOT IN (SELECT col2 FROM a WHERE col3 = 'bar') AND col2 IN (SELECT col2 FROM a 
WHERE col3 = 'lorem') GROUP BY col1",
+        "sql": "SET usePhysicalOptimizer=true; EXPLAIN PLAN FOR SELECT col1, 
COUNT(*) FROM a WHERE col2 IN (SELECT col2 FROM a WHERE col3 = '1') AND col2 
NOT IN (SELECT col2 FROM a WHERE col3 = '2') AND col2 IN (SELECT col2 FROM a 
WHERE col3 = '3') GROUP BY col1",
         "output": [
           "Execution Plan",
           "\nPhysicalExchange(exchangeStrategy=[SINGLETON_EXCHANGE])",
@@ -537,16 +537,16 @@
           "\n                        PhysicalTableScan(table=[[default, a]])",
           "\n                    
PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n                      PhysicalProject(col2=[$1])",
-          "\n                        PhysicalFilter(condition=[=($2, 
CAST(_UTF-8'foo'):INTEGER NOT NULL)])",
+          "\n                        PhysicalFilter(condition=[=($2, 1)])",
           "\n                          PhysicalTableScan(table=[[default, 
a]])",
           "\n                
PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n                  PhysicalAggregate(group=[{0}], agg#0=[MIN($1)], 
aggType=[DIRECT])",
           "\n                    PhysicalProject(col2=[$1], $f1=[true])",
-          "\n                      PhysicalFilter(condition=[=($2, 
CAST(_UTF-8'bar'):INTEGER NOT NULL)])",
+          "\n                      PhysicalFilter(condition=[=($2, 2)])",
           "\n                        PhysicalTableScan(table=[[default, a]])",
           "\n          PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n            PhysicalProject(col2=[$1])",
-          "\n              PhysicalFilter(condition=[=($2, 
CAST(_UTF-8'lorem'):INTEGER NOT NULL)])",
+          "\n              PhysicalFilter(condition=[=($2, 3)])",
           "\n                PhysicalTableScan(table=[[default, a]])",
           "\n"
         ]
@@ -613,7 +613,7 @@
     "queries": [
       {
         "description": "Union, distinct, etc. but still maximally identity 
exchange",
-        "sql": "SET usePhysicalOptimizer=true; EXPLAIN PLAN FOR WITH tmp AS 
(SELECT col2 FROM a WHERE col1 = 'foo' UNION ALL SELECT col2 FROM a WHERE col3 
= 'bar'), tmp2 AS (SELECT DISTINCT col2 FROM tmp) SELECT COUNT(*), col3 FROM a 
WHERE col2 IN (SELECT col2 FROM tmp2) GROUP BY col3",
+        "sql": "SET usePhysicalOptimizer=true; EXPLAIN PLAN FOR WITH tmp AS 
(SELECT col2 FROM a WHERE col1 = 'foo' UNION ALL SELECT col2 FROM a WHERE col3 
= '1'), tmp2 AS (SELECT DISTINCT col2 FROM tmp) SELECT COUNT(*), col3 FROM a 
WHERE col2 IN (SELECT col2 FROM tmp2) GROUP BY col3",
         "output": [
           "Execution Plan",
           "\nPhysicalExchange(exchangeStrategy=[SINGLETON_EXCHANGE])",
@@ -633,7 +633,7 @@
           "\n                      PhysicalTableScan(table=[[default, a]])",
           "\n                
PhysicalExchange(exchangeStrategy=[IDENTITY_EXCHANGE])",
           "\n                  PhysicalProject(col2=[$1])",
-          "\n                    PhysicalFilter(condition=[=($2, 
CAST(_UTF-8'bar'):INTEGER NOT NULL)])",
+          "\n                    PhysicalFilter(condition=[=($2, 1)])",
           "\n                      PhysicalTableScan(table=[[default, a]])",
           "\n"
         ]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to