This is an automated email from the ASF dual-hosted git repository.

parthc pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git


The following commit(s) were added to refs/heads/main by this push:
     new c3dd3a4c2 fix: handle scalar decimal value overflow correctly in ANSI mode (#3803)
c3dd3a4c2 is described below

commit c3dd3a4c291f2c89ef043283c84aaaf96a39e8e4
Author: Parth Chandra <[email protected]>
AuthorDate: Fri Mar 27 10:47:01 2026 -0700

    fix: handle scalar decimal value overflow correctly in ANSI mode (#3803)
    
    * fix: handle scalar decimal value overflow correctly.
---
 .../src/math_funcs/internal/checkoverflow.rs       | 183 +++++++++++++++++++--
 .../org/apache/comet/CometExpressionSuite.scala    |  29 ++++
 2 files changed, 196 insertions(+), 16 deletions(-)

diff --git a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
index a9e8f6748..f1fb9c2f0 100644
--- a/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
+++ b/native/spark-expr/src/math_funcs/internal/checkoverflow.rs
@@ -199,22 +199,38 @@ impl PhysicalExpr for CheckOverflow {
                 Ok(ColumnarValue::Array(new_array))
             }
             ColumnarValue::Scalar(ScalarValue::Decimal128(v, precision, scale)) => {
-                // `fail_on_error` is only true when ANSI is enabled, which we don't support yet
-                // (Java side will simply fallback to Spark when it is enabled)
-                assert!(
-                    !self.fail_on_error,
-                    "fail_on_error (ANSI mode) is not supported yet"
-                );
-
-                let new_v: Option<i128> = v.and_then(|v| {
-                    Decimal128Type::validate_decimal_precision(v, precision, scale)
-                        .map(|_| v)
-                        .ok()
-                });
-
-                Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
-                    new_v, precision, scale,
-                )))
+                if self.fail_on_error {
+                    if let Some(val) = v {
+                        Decimal128Type::validate_decimal_precision(val, precision, scale).map_err(
+                            |_| {
+                                let spark_error =
+                                    crate::error::decimal_overflow_error(val, precision, scale);
+                                if let Some(ctx) = &self.query_context {
+                                    DataFusionError::External(Box::new(
+                                        crate::SparkErrorWithContext::with_context(
+                                            spark_error,
+                                            Arc::clone(ctx),
+                                        ),
+                                    ))
+                                } else {
+                                    DataFusionError::External(Box::new(spark_error))
+                                }
+                            },
+                        )?;
+                    }
+                    Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        v, precision, scale,
+                    )))
+                } else {
+                    let new_v: Option<i128> = v.and_then(|v| {
+                        Decimal128Type::validate_decimal_precision(v, precision, scale)
+                            .map(|_| v)
+                            .ok()
+                    });
+                    Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                        new_v, precision, scale,
+                    )))
+                }
             }
             v => Err(DataFusionError::Execution(format!(
                 "CheckOverflow's child expression should be decimal array, but found {v:?}"
@@ -239,3 +255,138 @@ impl PhysicalExpr for CheckOverflow {
         )))
     }
 }
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use arrow::datatypes::{Field, Schema};
+    use arrow::record_batch::RecordBatch;
+    use std::fmt::{Display, Formatter};
+
+    /// Helper that always returns a fixed Decimal128 scalar.
+    #[derive(Debug, Eq, PartialEq, Hash)]
+    struct ScalarChild(Option<i128>, u8, i8);
+
+    impl Display for ScalarChild {
+        fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            write!(f, "ScalarChild({:?})", self.0)
+        }
+    }
+
+    impl PhysicalExpr for ScalarChild {
+        fn as_any(&self) -> &dyn Any {
+            self
+        }
+        fn data_type(&self, _: &Schema) -> datafusion::common::Result<DataType> {
+            Ok(DataType::Decimal128(self.1, self.2))
+        }
+        fn nullable(&self, _: &Schema) -> datafusion::common::Result<bool> {
+            Ok(true)
+        }
+        fn evaluate(&self, _: &RecordBatch) -> datafusion::common::Result<ColumnarValue> {
+            Ok(ColumnarValue::Scalar(ScalarValue::Decimal128(
+                self.0, self.1, self.2,
+            )))
+        }
+        fn children(&self) -> Vec<&Arc<dyn PhysicalExpr>> {
+            vec![]
+        }
+        fn with_new_children(
+            self: Arc<Self>,
+            _: Vec<Arc<dyn PhysicalExpr>>,
+        ) -> datafusion::common::Result<Arc<dyn PhysicalExpr>> {
+            Ok(self)
+        }
+        fn fmt_sql(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
+            Display::fmt(self, f)
+        }
+    }
+
+    fn empty_batch() -> RecordBatch {
+        let schema = Schema::new(vec![Field::new("x", DataType::Decimal128(38, 0), true)]);
+        RecordBatch::new_empty(Arc::new(schema))
+    }
+
+    fn make_check_overflow(
+        value: Option<i128>,
+        precision: u8,
+        scale: i8,
+        fail_on_error: bool,
+    ) -> CheckOverflow {
+        CheckOverflow::new(
+            Arc::new(ScalarChild(value, precision, scale)),
+            DataType::Decimal128(precision, scale),
+            fail_on_error,
+            None,
+            None,
+        )
+    }
+
+    // --- scalar, fail_on_error = false (legacy mode) ---
+
+    #[test]
+    fn test_scalar_no_overflow_legacy() {
+        // 999 fits in precision 3, scale 0 → returned as-is
+        let expr = make_check_overflow(Some(999), 3, 0, false);
+        let result = expr.evaluate(&empty_batch()).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, Some(999)),
+            other => panic!("unexpected: {other:?}"),
+        }
+    }
+
+    #[test]
+    fn test_scalar_overflow_returns_null_in_legacy_mode() {
+        // 1000 does not fit in precision 3 → null, no error
+        let expr = make_check_overflow(Some(1000), 3, 0, false);
+        let result = expr.evaluate(&empty_batch()).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, None),
+            other => panic!("unexpected: {other:?}"),
+        }
+    }
+
+    #[test]
+    fn test_scalar_null_passthrough_legacy() {
+        let expr = make_check_overflow(None, 3, 0, false);
+        let result = expr.evaluate(&empty_batch()).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, None),
+            other => panic!("unexpected: {other:?}"),
+        }
+    }
+
+    // --- scalar, fail_on_error = true (ANSI mode) ---
+
+    #[test]
+    fn test_scalar_no_overflow_ansi() {
+        // 999 fits in precision 3 → returned as-is, no error
+        let expr = make_check_overflow(Some(999), 3, 0, true);
+        let result = expr.evaluate(&empty_batch()).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, Some(999)),
+            other => panic!("unexpected: {other:?}"),
+        }
+    }
+
+    #[test]
+    fn test_scalar_overflow_returns_error_in_ansi_mode() {
+        // 1000 does not fit in precision 3 → error, not Ok(None)
+        // This is the case that previously panicked with "fail_on_error (ANSI mode) is not
+        // supported yet".
+        let expr = make_check_overflow(Some(1000), 3, 0, true);
+        let result = expr.evaluate(&empty_batch());
+        assert!(result.is_err(), "expected error on overflow in ANSI mode");
+    }
+
+    #[test]
+    fn test_scalar_null_passthrough_ansi() {
+        // None input → None output even in ANSI mode (no value to overflow)
+        let expr = make_check_overflow(None, 3, 0, true);
+        let result = expr.evaluate(&empty_batch()).unwrap();
+        match result {
+            ColumnarValue::Scalar(ScalarValue::Decimal128(v, 3, 0)) => assert_eq!(v, None),
+            other => panic!("unexpected: {other:?}"),
+        }
+    }
+}
diff --git a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
index 68c1a82f1..9fdd5a677 100644
--- a/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometExpressionSuite.scala
@@ -1271,6 +1271,35 @@ class CometExpressionSuite extends CometTestBase with AdaptiveSparkPlanHelper {
     }
   }
 
+  test("scalar decimal overflow - legacy mode produces null") {
+    // 1.1e19 * 1.1e19 = 1.21e38 fits in i128 (max ~1.7e38) but exceeds DECIMAL(38,0)'s
+    // max of 10^38-1, so CheckOverflow nulls the result in legacy (non-ANSI) mode.
+    withSQLConf(CometConf.COMET_ENABLED.key -> "true", SQLConf.ANSI_ENABLED.key -> "false") {
+      withParquetTable(Seq((BigDecimal("11000000000000000000"), 0)), "tbl") {
+        checkSparkAnswerAndOperator("SELECT _1 * _1 FROM tbl")
+      }
+    }
+  }
+
+  test("scalar decimal overflow - ANSI mode throws ArithmeticException") {
+    // 1.1e19 * 1.1e19 = 1.21e38 overflows DECIMAL(38,0). With ANSI mode on, both Spark and
+    // Comet must throw — Comet must not panic or silently return null. Spark reports
+    // NUMERIC_VALUE_OUT_OF_RANGE; Comet's WideDecimalBinaryExpr catches the overflow first
+    // and surfaces it as an arithmetic overflow error.
+    withSQLConf(CometConf.COMET_ENABLED.key -> "true", SQLConf.ANSI_ENABLED.key -> "true") {
+      withParquetTable(Seq((BigDecimal("11000000000000000000"), 0)), "tbl") {
+        val res = sql("SELECT _1 * _1 FROM tbl")
+        checkSparkAnswerMaybeThrows(res) match {
+          case (Some(sparkExc), Some(cometExc)) =>
+            assert(sparkExc.getMessage.contains("NUMERIC_VALUE_OUT_OF_RANGE"))
+            assert(cometExc.getMessage.toLowerCase.contains("overflow"))
+          case _ =>
+            fail("Expected exception for decimal overflow in ANSI mode")
+        }
+      }
+    }
+  }
+
   test("cast decimals to int") {
     Seq(16, 1024).foreach { batchSize =>
       withSQLConf(


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to