This is an automated email from the ASF dual-hosted git repository.

parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new 74b41aa57 fix: handle inf/-inf/nan in ShimSparkErrorConverter cast overflow (#3768)
74b41aa57 is described below

commit 74b41aa57924e5dbb3d021221efdff20be734c8f
Author: Manu Zhang <[email protected]>
AuthorDate: Sat Mar 28 00:43:10 2026 +0800

    fix: handle inf/-inf/nan in ShimSparkErrorConverter cast overflow (#3768)
    
    Normalize inf/nan literals for float/double cast overflow conversion across Spark 3.4/3.5/4.0 and add unit tests in SparkErrorConverterSuite for float/double inf/-inf/nan.
    
    Co-authored-by: Codex <[email protected]>
---
 .github/workflows/pr_build_linux.yml               |   1 +
 .github/workflows/pr_build_macos.yml               |   1 +
 .../sql/comet/shims/ShimSparkErrorConverter.scala  |  23 ++++-
 .../sql/comet/shims/ShimSparkErrorConverter.scala  |  23 ++++-
 .../sql/comet/shims/ShimSparkErrorConverter.scala  |  23 ++++-
 .../apache/comet/SparkErrorConverterSuite.scala    | 104 +++++++++++++++++++++
 6 files changed, 169 insertions(+), 6 deletions(-)

diff --git a/.github/workflows/pr_build_linux.yml b/.github/workflows/pr_build_linux.yml
index 899fa6139..6811d6c2b 100644
--- a/.github/workflows/pr_build_linux.yml
+++ b/.github/workflows/pr_build_linux.yml
@@ -338,6 +338,7 @@ jobs:
               org.apache.comet.CometCsvExpressionSuite
               org.apache.comet.CometJsonExpressionSuite
               org.apache.comet.CometDateTimeUtilsSuite
+              org.apache.comet.SparkErrorConverterSuite
               org.apache.comet.expressions.conditional.CometIfSuite
               org.apache.comet.expressions.conditional.CometCoalesceSuite
               org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git a/.github/workflows/pr_build_macos.yml b/.github/workflows/pr_build_macos.yml
index 53001b04e..8362a6cfb 100644
--- a/.github/workflows/pr_build_macos.yml
+++ b/.github/workflows/pr_build_macos.yml
@@ -213,6 +213,7 @@ jobs:
               org.apache.comet.CometJsonExpressionSuite
               org.apache.comet.CometCsvExpressionSuite
               org.apache.comet.CometDateTimeUtilsSuite
+              org.apache.comet.SparkErrorConverterSuite
               org.apache.comet.expressions.conditional.CometIfSuite
               org.apache.comet.expressions.conditional.CometCoalesceSuite
               org.apache.comet.expressions.conditional.CometCaseWhenSuite
diff --git a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index 6eee3f5bc..ba37f8c94 100644
--- a/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ b/spark/src/main/spark-3.4/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -44,6 +44,25 @@ trait ShimSparkErrorConverter {
   private def sqlCtx(context: Array[QueryContext]): SQLQueryContext =
     context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null)
 
+  private def parseFloatLiteral(value: String): Float = {
+    value.toLowerCase match {
+      case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+      case "-inf" | "-infinity" => Float.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Float.NaN
+      case _ => value.toFloat
+    }
+  }
+
+  private def parseDoubleLiteral(value: String): Double = {
+    val normalized = value.toLowerCase.stripSuffix("d")
+    normalized match {
+      case "inf" | "+inf" | "infinity" | "+infinity" => Double.PositiveInfinity
+      case "-inf" | "-infinity" => Double.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Double.NaN
+      case _ => normalized.toDouble
+    }
+  }
+
   def convertErrorType(
       errorType: String,
       errorClass: String,
@@ -207,8 +226,8 @@ trait ShimSparkErrorConverter {
           case LongType =>
             val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) 
else valueStr
             cleanStr.toLong
-          case FloatType => valueStr.toFloat
-          case DoubleType => valueStr.toDouble
+          case FloatType => parseFloatLiteral(valueStr)
+          case DoubleType => parseDoubleLiteral(valueStr)
           case StringType => UTF8String.fromString(valueStr)
           case _ => valueStr
         }
diff --git a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index 75316c51e..1d140e190 100644
--- a/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ b/spark/src/main/spark-3.5/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -44,6 +44,25 @@ trait ShimSparkErrorConverter {
   private def sqlCtx(context: Array[QueryContext]): SQLQueryContext =
     context.headOption.map(_.asInstanceOf[SQLQueryContext]).getOrElse(null)
 
+  private def parseFloatLiteral(value: String): Float = {
+    value.toLowerCase match {
+      case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+      case "-inf" | "-infinity" => Float.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Float.NaN
+      case _ => value.toFloat
+    }
+  }
+
+  private def parseDoubleLiteral(value: String): Double = {
+    val normalized = value.toLowerCase.stripSuffix("d")
+    normalized match {
+      case "inf" | "+inf" | "infinity" | "+infinity" => Double.PositiveInfinity
+      case "-inf" | "-infinity" => Double.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Double.NaN
+      case _ => normalized.toDouble
+    }
+  }
+
   def convertErrorType(
       errorType: String,
       errorClass: String,
@@ -205,8 +224,8 @@ trait ShimSparkErrorConverter {
           case LongType =>
             val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) 
else valueStr
             cleanStr.toLong
-          case FloatType => valueStr.toFloat
-          case DoubleType => valueStr.toDouble
+          case FloatType => parseFloatLiteral(valueStr)
+          case DoubleType => parseDoubleLiteral(valueStr)
           case StringType => UTF8String.fromString(valueStr)
           case _ => valueStr
         }
diff --git a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
index fc13a58a4..a787fb801 100644
--- a/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
+++ b/spark/src/main/spark-4.0/org/apache/spark/sql/comet/shims/ShimSparkErrorConverter.scala
@@ -37,6 +37,25 @@ object ShimSparkErrorConverter {
  */
 trait ShimSparkErrorConverter {
 
+  private def parseFloatLiteral(value: String): Float = {
+    value.toLowerCase match {
+      case "inf" | "+inf" | "infinity" | "+infinity" => Float.PositiveInfinity
+      case "-inf" | "-infinity" => Float.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Float.NaN
+      case _ => value.toFloat
+    }
+  }
+
+  private def parseDoubleLiteral(value: String): Double = {
+    val normalized = value.toLowerCase.stripSuffix("d")
+    normalized match {
+      case "inf" | "+inf" | "infinity" | "+infinity" => Double.PositiveInfinity
+      case "-inf" | "-infinity" => Double.NegativeInfinity
+      case "nan" | "+nan" | "-nan" => Double.NaN
+      case _ => normalized.toDouble
+    }
+  }
+
   /**
    * Convert error type string and parameters to appropriate Spark exception. 
Version-specific
    * implementations call the correct QueryExecutionErrors.* methods.
@@ -213,8 +232,8 @@ trait ShimSparkErrorConverter {
             // Strip "L" suffix for BIGINT literals
             val cleanStr = if (valueStr.endsWith("L")) valueStr.dropRight(1) 
else valueStr
             cleanStr.toLong
-          case FloatType => valueStr.toFloat
-          case DoubleType => valueStr.toDouble
+          case FloatType => parseFloatLiteral(valueStr)
+          case DoubleType => parseDoubleLiteral(valueStr)
           case StringType => UTF8String.fromString(valueStr)
           case _ => valueStr // Fallback to string
         }
diff --git a/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala b/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala
new file mode 100644
index 000000000..d3e2c2c64
--- /dev/null
+++ b/spark/src/test/scala/org/apache/comet/SparkErrorConverterSuite.scala
@@ -0,0 +1,104 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+
+package org.apache.comet
+
+import org.scalatest.funsuite.AnyFunSuite
+
+class SparkErrorConverterSuite extends AnyFunSuite {
+  private def castOverflowError(fromType: String, value: String): Throwable = {
+    SparkErrorConverter
+      .convertErrorType(
+        "CastOverFlow",
+        "CAST_OVERFLOW",
+        Map("fromType" -> fromType, "toType" -> "INT", "value" -> value),
+        Array.empty,
+        null)
+      .getOrElse(fail("Expected CastOverFlow to be converted to a Spark 
exception"))
+  }
+
+  private def assertCastOverflowContains(
+      fromType: String,
+      value: String,
+      expectedMessagePart: String): Unit = {
+    val err = castOverflowError(fromType, value)
+    assert(
+      !err.isInstanceOf[NumberFormatException],
+      s"Unexpected parse failure for $fromType $value")
+    assert(
+      err.getMessage.contains(expectedMessagePart),
+      s"Expected '${err.getMessage}' to contain '$expectedMessagePart' for 
$fromType $value")
+  }
+
+  private def assertCastOverflowContainsNaN(fromType: String, value: String): 
Unit = {
+    val err = castOverflowError(fromType, value)
+    assert(
+      !err.isInstanceOf[NumberFormatException],
+      s"Unexpected parse failure for $fromType $value")
+    assert(
+      err.getMessage.toLowerCase.contains("nan"),
+      s"Expected '${err.getMessage}' to contain NaN for $fromType $value")
+  }
+
+  test("CastOverFlow conversion handles all float positive infinity literals") 
{
+    Seq("inf", "+inf", "infinity", "+infinity").foreach { value =>
+      assertCastOverflowContains("FLOAT", value, "Infinity")
+    }
+  }
+
+  test("CastOverFlow conversion handles all float negative infinity literals") 
{
+    Seq("-inf", "-infinity").foreach { value =>
+      assertCastOverflowContains("FLOAT", value, "-Infinity")
+    }
+  }
+
+  test("CastOverFlow conversion handles all float NaN literals") {
+    Seq("nan", "+nan", "-nan").foreach { value =>
+      assertCastOverflowContainsNaN("FLOAT", value)
+    }
+  }
+
+  test("CastOverFlow conversion handles float standard numeric literal 
fallback") {
+    assertCastOverflowContains("FLOAT", "1.5", "1.5")
+  }
+
+  test("CastOverFlow conversion handles all double positive infinity 
literals") {
+    Seq("inf", "infd", "+inf", "+infd", "infinity", "infinityd", "+infinity", 
"+infinityd")
+      .foreach { value =>
+        assertCastOverflowContains("DOUBLE", value, "Infinity")
+      }
+  }
+
+  test("CastOverFlow conversion handles all double negative infinity 
literals") {
+    Seq("-inf", "-infd", "-infinity", "-infinityd").foreach { value =>
+      assertCastOverflowContains("DOUBLE", value, "-Infinity")
+    }
+  }
+
+  test("CastOverFlow conversion handles all double NaN literals") {
+    Seq("nan", "nand", "+nan", "+nand", "-nan", "-nand").foreach { value =>
+      assertCastOverflowContainsNaN("DOUBLE", value)
+    }
+  }
+
+  test("CastOverFlow conversion handles double standard numeric literal 
fallback") {
+    assertCastOverflowContains("DOUBLE", "1.5", "1.5")
+    assertCastOverflowContains("DOUBLE", "1.5d", "1.5")
+  }
+}


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to