This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new d0b154bd5 bug: Fix string decimal type throw right exception (#3248)
d0b154bd5 is described below

commit d0b154bd5e524c58bfcbfeee26b8fc6621b91463
Author: B Vadlamani <[email protected]>
AuthorDate: Sat Jan 31 16:34:23 2026 -0800

    bug: Fix string decimal type throw right exception (#3248)
---
 native/spark-expr/benches/cast_from_string.rs      |  70 +++++++++
 native/spark-expr/src/conversion_funcs/cast.rs     | 166 ++++++++++++---------
 native/spark-expr/src/error.rs                     |   3 +
 .../scala/org/apache/comet/CometCastSuite.scala    |  35 +++--
 4 files changed, 194 insertions(+), 80 deletions(-)

diff --git a/native/spark-expr/benches/cast_from_string.rs b/native/spark-expr/benches/cast_from_string.rs
index a09afae6e..9b2cb73fb 100644
--- a/native/spark-expr/benches/cast_from_string.rs
+++ b/native/spark-expr/benches/cast_from_string.rs
@@ -68,6 +68,31 @@ fn criterion_benchmark(c: &mut Criterion) {
         b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap());
     });
     group.finish();
+
+    // str -> decimal benchmark
+    let decimal_string_batch = create_decimal_cast_string_batch();
+    for (mode, mode_name) in [
+        (EvalMode::Legacy, "legacy"),
+        (EvalMode::Ansi, "ansi"),
+        (EvalMode::Try, "try"),
+    ] {
+        let spark_cast_options = SparkCastOptions::new(mode, "", false);
+        let cast_to_decimal_38_10 = Cast::new(
+            expr.clone(),
+            DataType::Decimal128(38, 10),
+            spark_cast_options,
+        );
+
+        let mut group = c.benchmark_group(format!("cast_string_to_decimal/{}", mode_name));
+        group.bench_function("decimal_38_10", |b| {
+            b.iter(|| {
+                cast_to_decimal_38_10
+                    .evaluate(&decimal_string_batch)
+                    .unwrap()
+            });
+        });
+        group.finish();
+    }
 }
 
 /// Create batch with small integer strings that fit in i8 range (for i8/i16 benchmarks)
@@ -118,6 +143,51 @@ fn create_decimal_string_batch() -> RecordBatch {
     RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
 }
 
+/// Create batch with decimal strings for string-to-decimal cast perf evaluation
+fn create_decimal_cast_string_batch() -> RecordBatch {
+    let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8, true)]));
+    let mut b = StringBuilder::new();
+    for i in 0..1000 {
+        if i % 10 == 0 {
+            b.append_null();
+        } else {
+            // Generate various decimal formats
+            match i % 5 {
+                0 => {
+                    // gen simple decimals (ex :  "123.45"
+                    let int_part: u32 = rand::random::<u32>() % 1000000;
+                    let dec_part: u32 = rand::random::<u32>() % 100000;
+                    b.append_value(format!("{}.{}", int_part, dec_part));
+                }
+                1 => {
+                    // gen scientific notation like "123e5"
+                    let mantissa: u32 = rand::random::<u32>() % 1000;
+                    let exp: i8 = (rand::random::<i8>() % 10).abs();
+                    b.append_value(format!("{}.{}E{}", mantissa / 100, mantissa % 100, exp));
+                }
+                2 => {
+                    // Negative numbers
+                    let int_part: u32 = rand::random::<u32>() % 1000000;
+                    let dec_part: u32 = rand::random::<u32>() % 100000;
+                    b.append_value(format!("-{}.{}", int_part, dec_part));
+                }
+                3 => {
+                    // Ints only
+                    let val: i32 = rand::random::<i32>() % 1000000;
+                    b.append_value(format!("{}", val));
+                }
+                _ => {
+                    // Small decimals (ex : 0.001)
+                    let dec_part: u32 = rand::random::<u32>() % 100000;
+                    b.append_value(format!("0.{:05}", dec_part));
+                }
+            }
+        }
+    }
+    let array = b.finish();
+    RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
+}
+
 fn config() -> Criterion {
     Criterion::default()
 }
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs b/native/spark-expr/src/conversion_funcs/cast.rs
index 9ccfc3e6a..186a10c9a 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -2321,8 +2321,8 @@ fn cast_string_to_decimal256_impl(
 }
 
 /// Parse a string to decimal following Spark's behavior
-fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
-    let string_bytes = s.as_bytes();
+fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) -> SparkResult<Option<i128>> {
+    let string_bytes = input_str.as_bytes();
     let mut start = 0;
     let mut end = string_bytes.len();
 
@@ -2334,7 +2334,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
         end -= 1;
     }
 
-    let trimmed = &s[start..end];
+    let trimmed = &input_str[start..end];
 
     if trimmed.is_empty() {
         return Ok(None);
@@ -2351,73 +2351,101 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) -> SparkResult<Opt
         return Ok(None);
     }
 
-    // validate and parse mantissa and exponent
-    match parse_decimal_str(trimmed) {
-        Ok((mantissa, exponent)) => {
-            // Convert to target scale
-            let target_scale = scale as i32;
-            let scale_adjustment = target_scale - exponent;
+    // validate and parse mantissa and exponent or bubble up the error
+    let (mantissa, exponent) = parse_decimal_str(trimmed, input_str, precision, scale)?;
 
-            let scaled_value = if scale_adjustment >= 0 {
-                // Need to multiply (increase scale) but return None if scale is too high to fit i128
-                if scale_adjustment > 38 {
-                    return Ok(None);
-                }
-                mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
-            } else {
-                // Need to multiply (increase scale) but return None if scale is too high to fit i128
-                let abs_scale_adjustment = (-scale_adjustment) as u32;
-                if abs_scale_adjustment > 38 {
-                    return Ok(Some(0));
-                }
+    // Early return mantissa 0, Spark checks if it fits digits and throw error in ansi
+    if mantissa == 0 {
+        if exponent < -37 {
+            return Err(SparkError::NumericOutOfRange {
+                value: input_str.to_string(),
+            });
+        }
+        return Ok(Some(0));
+    }
 
-                let divisor = 10_i128.pow(abs_scale_adjustment);
-                let quotient_opt = mantissa.checked_div(divisor);
-                // Check if divisor is 0
-                if quotient_opt.is_none() {
-                    return Ok(None);
-                }
-                let quotient = quotient_opt.unwrap();
-                let remainder = mantissa % divisor;
-
-                // Round half up: if abs(remainder) >= divisor/2, round away from zero
-                let half_divisor = divisor / 2;
-                let rounded = if remainder.abs() >= half_divisor {
-                    if mantissa >= 0 {
-                        quotient + 1
-                    } else {
-                        quotient - 1
-                    }
-                } else {
-                    quotient
-                };
-                Some(rounded)
-            };
+    // scale adjustment
+    let target_scale = scale as i32;
+    let scale_adjustment = target_scale - exponent;
 
-            match scaled_value {
-                Some(value) => {
-                    // Check if it fits target precision
-                    if is_validate_decimal_precision(value, precision) {
-                        Ok(Some(value))
-                    } else {
-                        Ok(None)
-                    }
-                }
-                None => {
-                    // Overflow while scaling
-                    Ok(None)
-                }
+    let scaled_value = if scale_adjustment >= 0 {
+        // Need to multiply (increase scale) but return None if scale is too high to fit i128
+        if scale_adjustment > 38 {
+            return Ok(None);
+        }
+        mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+    } else {
+        // Need to divide (decrease scale)
+        let abs_scale_adjustment = (-scale_adjustment) as u32;
+        if abs_scale_adjustment > 38 {
+            return Ok(Some(0));
+        }
+
+        let divisor = 10_i128.pow(abs_scale_adjustment);
+        let quotient_opt = mantissa.checked_div(divisor);
+        // Check if divisor is 0
+        if quotient_opt.is_none() {
+            return Ok(None);
+        }
+        let quotient = quotient_opt.unwrap();
+        let remainder = mantissa % divisor;
+
+        // Round half up: if abs(remainder) >= divisor/2, round away from zero
+        let half_divisor = divisor / 2;
+        let rounded = if remainder.abs() >= half_divisor {
+            if mantissa >= 0 {
+                quotient + 1
+            } else {
+                quotient - 1
+            }
+        } else {
+            quotient
+        };
+        Some(rounded)
+    };
+
+    match scaled_value {
+        Some(value) => {
+            if is_validate_decimal_precision(value, precision) {
+                Ok(Some(value))
+            } else {
+                // Value ok but exceeds precision mentioned . THrow error
+                Err(SparkError::NumericValueOutOfRange {
+                    value: trimmed.to_string(),
+                    precision,
+                    scale,
+                })
             }
         }
-        Err(_) => Ok(None),
+        None => {
+            // Overflow when scaling raise exception
+            Err(SparkError::NumericValueOutOfRange {
+                value: trimmed.to_string(),
+                precision,
+                scale,
+            })
+        }
     }
 }
 
+fn invalid_decimal_cast(value: &str, precision: u8, scale: i8) -> SparkError {
+    invalid_value(
+        value,
+        "STRING",
+        &format!("DECIMAL({},{})", precision, scale),
+    )
+}
+
 /// Parse a decimal string into mantissa and scale
-/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
-fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3) , 0e50 -> (0,50) etc
+fn parse_decimal_str(
+    s: &str,
+    original_str: &str,
+    precision: u8,
+    scale: i8,
+) -> SparkResult<(i128, i32)> {
     if s.is_empty() {
-        return Err("Empty string".to_string());
+        return Err(invalid_decimal_cast(original_str, precision, scale));
     }
 
     let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e', 'E'].contains(&c)) {
@@ -2426,7 +2454,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
         // Parse exponent
         let exp: i32 = exponent_part
             .parse()
-            .map_err(|e| format!("Invalid exponent: {}", e))?;
+            .map_err(|_| invalid_decimal_cast(original_str, precision, scale))?;
 
         (mantissa_part, exp)
     } else {
@@ -2441,13 +2469,13 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     };
 
     if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
-        return Err("Invalid sign format".to_string());
+        return Err(invalid_decimal_cast(original_str, precision, scale));
     }
 
     let (integral_part, fractional_part) = match mantissa_str.find('.') {
         Some(dot_pos) => {
             if mantissa_str[dot_pos + 1..].contains('.') {
-                return Err("Multiple decimal points".to_string());
+                return Err(invalid_decimal_cast(original_str, precision, scale));
             }
             (&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
         }
@@ -2455,15 +2483,15 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     };
 
     if integral_part.is_empty() && fractional_part.is_empty() {
-        return Err("No digits found".to_string());
+        return Err(invalid_decimal_cast(original_str, precision, scale));
     }
 
     if !integral_part.is_empty() && !integral_part.bytes().all(|b| b.is_ascii_digit()) {
-        return Err("Invalid integral part".to_string());
+        return Err(invalid_decimal_cast(original_str, precision, scale));
     }
 
     if !fractional_part.is_empty() && !fractional_part.bytes().all(|b| b.is_ascii_digit()) {
-        return Err("Invalid fractional part".to_string());
+        return Err(invalid_decimal_cast(original_str, precision, scale));
     }
 
     // Parse integral part
@@ -2473,7 +2501,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     } else {
         integral_part
             .parse()
-            .map_err(|_| "Invalid integral part".to_string())?
+            .map_err(|_| invalid_decimal_cast(original_str, precision, scale))?
     };
 
     // Parse fractional part
@@ -2483,14 +2511,14 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
     } else {
         fractional_part
             .parse()
-            .map_err(|_| "Invalid fractional part".to_string())?
+            .map_err(|_| invalid_decimal_cast(original_str, precision, scale))?
     };
 
     // Combine: value = integral * 10^fractional_scale + fractional
     let mantissa = integral_value
         .checked_mul(10_i128.pow(fractional_scale as u32))
         .and_then(|v| v.checked_add(fractional_value))
-        .ok_or("Overflow in mantissa calculation")?;
+        .ok_or_else(|| invalid_decimal_cast(original_str, precision, scale))?;
 
     let final_mantissa = if negative { -mantissa } else { mantissa };
     // final scale = fractional_scale - exponent
diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs
index 4b00b70eb..c39a05cd4 100644
--- a/native/spark-expr/src/error.rs
+++ b/native/spark-expr/src/error.rs
@@ -39,6 +39,9 @@ pub enum SparkError {
         scale: i8,
     },
 
+    #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be interpreted as a numeric since it has more than 38 digits.")]
+    NumericOutOfRange { value: String },
+
     #[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\" cannot be cast to \"{to_type}\" \
         due to an overflow. Use `try_cast` to tolerate overflow and return NULL instead. If necessary \
         set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 26bb810b7..269925be4 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.internal.SQLConf
 import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType, DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType, StructField, StructType}
 
-import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
 import org.apache.comet.expressions.{CometCast, CometEvalMode}
 import org.apache.comet.rules.CometScanTypeChecker
 import org.apache.comet.serde.Compatible
@@ -709,8 +708,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast StringType to DecimalType(10,2) (does not support fullwidth unicode digits)") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
-      // TODO fix for Spark 4.0.0
-      assume(!isSpark40Plus)
       val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
       Seq(true, false).foreach(ansiEnabled =>
         castTest(values, DataTypes.createDecimalType(10, 2), testAnsi = ansiEnabled))
@@ -719,18 +716,38 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast StringType to DecimalType(2,2)") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
-      // TODO fix for Spark 4.0.0
-      assume(!isSpark40Plus)
       val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
       Seq(true, false).foreach(ansiEnabled =>
         castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
     }
   }
 
+  test("cast StringType to DecimalType check if right exception message is thrown") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      val values = Seq("d11307\n").toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(2,2) check if right exception is being thrown") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      val values = gen.generateInts(10000).map("    " + _).toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(2, 2), testAnsi = ansiEnabled))
+    }
+  }
+
+  test("cast StringType to DecimalType(38,10) high precision - check 0 mantissa") {
+    withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
+      val values = Seq("0e31", "000e3375", "0e40", "0E+695", "0e5887677").toDF("a")
+      Seq(true, false).foreach(ansiEnabled =>
+        castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
+    }
+  }
+
   test("cast StringType to DecimalType(38,10) high precision") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
-      // TODO fix for Spark 4.0.0
-      assume(!isSpark40Plus)
       val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
       Seq(true, false).foreach(ansiEnabled =>
         castTest(values, DataTypes.createDecimalType(38, 10), testAnsi = ansiEnabled))
@@ -739,8 +756,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast StringType to DecimalType(10,2) basic values") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
-      // TODO fix for Spark 4.0.0
-      assume(!isSpark40Plus)
       val values = Seq(
         "123.45",
         "-67.89",
@@ -766,8 +781,6 @@ class CometCastSuite extends CometTestBase with AdaptiveSparkPlanHelper {
 
   test("cast StringType to Decimal type scientific notation") {
     withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) -> "true") {
-      // TODO fix for Spark 4.0.0
-      assume(!isSpark40Plus)
       val values = Seq(
         "1.23E-5",
         "1.23e10",


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to