This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/datafusion-comet.git
The following commit(s) were added to refs/heads/main by this push:
new d0b154bd5 bug: Fix string decimal type throw right exception (#3248)
d0b154bd5 is described below
commit d0b154bd5e524c58bfcbfeee26b8fc6621b91463
Author: B Vadlamani <[email protected]>
AuthorDate: Sat Jan 31 16:34:23 2026 -0800
bug: Fix string decimal type throw right exception (#3248)
---
native/spark-expr/benches/cast_from_string.rs | 70 +++++++++
native/spark-expr/src/conversion_funcs/cast.rs | 166 ++++++++++++---------
native/spark-expr/src/error.rs | 3 +
.../scala/org/apache/comet/CometCastSuite.scala | 35 +++--
4 files changed, 194 insertions(+), 80 deletions(-)
diff --git a/native/spark-expr/benches/cast_from_string.rs
b/native/spark-expr/benches/cast_from_string.rs
index a09afae6e..9b2cb73fb 100644
--- a/native/spark-expr/benches/cast_from_string.rs
+++ b/native/spark-expr/benches/cast_from_string.rs
@@ -68,6 +68,31 @@ fn criterion_benchmark(c: &mut Criterion) {
b.iter(|| cast_to_i64.evaluate(&decimal_batch).unwrap());
});
group.finish();
+
+ // str -> decimal benchmark
+ let decimal_string_batch = create_decimal_cast_string_batch();
+ for (mode, mode_name) in [
+ (EvalMode::Legacy, "legacy"),
+ (EvalMode::Ansi, "ansi"),
+ (EvalMode::Try, "try"),
+ ] {
+ let spark_cast_options = SparkCastOptions::new(mode, "", false);
+ let cast_to_decimal_38_10 = Cast::new(
+ expr.clone(),
+ DataType::Decimal128(38, 10),
+ spark_cast_options,
+ );
+
+ let mut group = c.benchmark_group(format!("cast_string_to_decimal/{}",
mode_name));
+ group.bench_function("decimal_38_10", |b| {
+ b.iter(|| {
+ cast_to_decimal_38_10
+ .evaluate(&decimal_string_batch)
+ .unwrap()
+ });
+ });
+ group.finish();
+ }
}
/// Create batch with small integer strings that fit in i8 range (for i8/i16
benchmarks)
@@ -118,6 +143,51 @@ fn create_decimal_string_batch() -> RecordBatch {
RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
}
+/// Create batch with decimal strings for string-to-decimal cast perf
evaluation
+fn create_decimal_cast_string_batch() -> RecordBatch {
+ let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Utf8,
true)]));
+ let mut b = StringBuilder::new();
+ for i in 0..1000 {
+ if i % 10 == 0 {
+ b.append_null();
+ } else {
+ // Generate various decimal formats
+ match i % 5 {
+ 0 => {
+                    // gen simple decimals (e.g., "123.45")
+ let int_part: u32 = rand::random::<u32>() % 1000000;
+ let dec_part: u32 = rand::random::<u32>() % 100000;
+ b.append_value(format!("{}.{}", int_part, dec_part));
+ }
+ 1 => {
+                    // gen scientific notation like "1.23E5"
+ let mantissa: u32 = rand::random::<u32>() % 1000;
+ let exp: i8 = (rand::random::<i8>() % 10).abs();
+ b.append_value(format!("{}.{}E{}", mantissa / 100,
mantissa % 100, exp));
+ }
+ 2 => {
+ // Negative numbers
+ let int_part: u32 = rand::random::<u32>() % 1000000;
+ let dec_part: u32 = rand::random::<u32>() % 100000;
+ b.append_value(format!("-{}.{}", int_part, dec_part));
+ }
+ 3 => {
+ // Ints only
+ let val: i32 = rand::random::<i32>() % 1000000;
+ b.append_value(format!("{}", val));
+ }
+ _ => {
+                    // Small decimals (e.g., "0.001")
+ let dec_part: u32 = rand::random::<u32>() % 100000;
+ b.append_value(format!("0.{:05}", dec_part));
+ }
+ }
+ }
+ }
+ let array = b.finish();
+ RecordBatch::try_new(schema, vec![Arc::new(array)]).unwrap()
+}
+
fn config() -> Criterion {
Criterion::default()
}
diff --git a/native/spark-expr/src/conversion_funcs/cast.rs
b/native/spark-expr/src/conversion_funcs/cast.rs
index 9ccfc3e6a..186a10c9a 100644
--- a/native/spark-expr/src/conversion_funcs/cast.rs
+++ b/native/spark-expr/src/conversion_funcs/cast.rs
@@ -2321,8 +2321,8 @@ fn cast_string_to_decimal256_impl(
}
/// Parse a string to decimal following Spark's behavior
-fn parse_string_to_decimal(s: &str, precision: u8, scale: i8) ->
SparkResult<Option<i128>> {
- let string_bytes = s.as_bytes();
+fn parse_string_to_decimal(input_str: &str, precision: u8, scale: i8) ->
SparkResult<Option<i128>> {
+ let string_bytes = input_str.as_bytes();
let mut start = 0;
let mut end = string_bytes.len();
@@ -2334,7 +2334,7 @@ fn parse_string_to_decimal(s: &str, precision: u8, scale:
i8) -> SparkResult<Opt
end -= 1;
}
- let trimmed = &s[start..end];
+ let trimmed = &input_str[start..end];
if trimmed.is_empty() {
return Ok(None);
@@ -2351,73 +2351,101 @@ fn parse_string_to_decimal(s: &str, precision: u8,
scale: i8) -> SparkResult<Opt
return Ok(None);
}
- // validate and parse mantissa and exponent
- match parse_decimal_str(trimmed) {
- Ok((mantissa, exponent)) => {
- // Convert to target scale
- let target_scale = scale as i32;
- let scale_adjustment = target_scale - exponent;
+ // validate and parse mantissa and exponent or bubble up the error
+ let (mantissa, exponent) = parse_decimal_str(trimmed, input_str,
precision, scale)?;
- let scaled_value = if scale_adjustment >= 0 {
- // Need to multiply (increase scale) but return None if scale
is too high to fit i128
- if scale_adjustment > 38 {
- return Ok(None);
- }
- mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
- } else {
- // Need to multiply (increase scale) but return None if scale
is too high to fit i128
- let abs_scale_adjustment = (-scale_adjustment) as u32;
- if abs_scale_adjustment > 38 {
- return Ok(Some(0));
- }
+    // Early return when the mantissa is 0; Spark still validates the digit
count and throws an error in ANSI mode
+ if mantissa == 0 {
+ if exponent < -37 {
+ return Err(SparkError::NumericOutOfRange {
+ value: input_str.to_string(),
+ });
+ }
+ return Ok(Some(0));
+ }
- let divisor = 10_i128.pow(abs_scale_adjustment);
- let quotient_opt = mantissa.checked_div(divisor);
- // Check if divisor is 0
- if quotient_opt.is_none() {
- return Ok(None);
- }
- let quotient = quotient_opt.unwrap();
- let remainder = mantissa % divisor;
-
- // Round half up: if abs(remainder) >= divisor/2, round away
from zero
- let half_divisor = divisor / 2;
- let rounded = if remainder.abs() >= half_divisor {
- if mantissa >= 0 {
- quotient + 1
- } else {
- quotient - 1
- }
- } else {
- quotient
- };
- Some(rounded)
- };
+ // scale adjustment
+ let target_scale = scale as i32;
+ let scale_adjustment = target_scale - exponent;
- match scaled_value {
- Some(value) => {
- // Check if it fits target precision
- if is_validate_decimal_precision(value, precision) {
- Ok(Some(value))
- } else {
- Ok(None)
- }
- }
- None => {
- // Overflow while scaling
- Ok(None)
- }
+ let scaled_value = if scale_adjustment >= 0 {
+ // Need to multiply (increase scale) but return None if scale is too
high to fit i128
+ if scale_adjustment > 38 {
+ return Ok(None);
+ }
+ mantissa.checked_mul(10_i128.pow(scale_adjustment as u32))
+ } else {
+ // Need to divide (decrease scale)
+ let abs_scale_adjustment = (-scale_adjustment) as u32;
+ if abs_scale_adjustment > 38 {
+ return Ok(Some(0));
+ }
+
+ let divisor = 10_i128.pow(abs_scale_adjustment);
+ let quotient_opt = mantissa.checked_div(divisor);
+        // Defensive: checked_div returns None only on division by zero;
divisor is a positive power of 10, so this should not occur
+ if quotient_opt.is_none() {
+ return Ok(None);
+ }
+ let quotient = quotient_opt.unwrap();
+ let remainder = mantissa % divisor;
+
+ // Round half up: if abs(remainder) >= divisor/2, round away from zero
+ let half_divisor = divisor / 2;
+ let rounded = if remainder.abs() >= half_divisor {
+ if mantissa >= 0 {
+ quotient + 1
+ } else {
+ quotient - 1
+ }
+ } else {
+ quotient
+ };
+ Some(rounded)
+ };
+
+ match scaled_value {
+ Some(value) => {
+ if is_validate_decimal_precision(value, precision) {
+ Ok(Some(value))
+ } else {
+            // Value parsed OK but exceeds the target precision; throw an error
+ Err(SparkError::NumericValueOutOfRange {
+ value: trimmed.to_string(),
+ precision,
+ scale,
+ })
}
}
- Err(_) => Ok(None),
+ None => {
+        // Overflow while scaling; raise an exception
+ Err(SparkError::NumericValueOutOfRange {
+ value: trimmed.to_string(),
+ precision,
+ scale,
+ })
+ }
}
}
+fn invalid_decimal_cast(value: &str, precision: u8, scale: i8) -> SparkError {
+ invalid_value(
+ value,
+ "STRING",
+ &format!("DECIMAL({},{})", precision, scale),
+ )
+}
+
/// Parse a decimal string into mantissa and scale
-/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3)
-fn parse_decimal_str(s: &str) -> Result<(i128, i32), String> {
+/// e.g., "123.45" -> (12345, 2), "-0.001" -> (-1, 3), "0e50" -> (0, -50), etc.
+fn parse_decimal_str(
+ s: &str,
+ original_str: &str,
+ precision: u8,
+ scale: i8,
+) -> SparkResult<(i128, i32)> {
if s.is_empty() {
- return Err("Empty string".to_string());
+ return Err(invalid_decimal_cast(original_str, precision, scale));
}
let (mantissa_str, exponent) = if let Some(e_pos) = s.find(|c| ['e',
'E'].contains(&c)) {
@@ -2426,7 +2454,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32),
String> {
// Parse exponent
let exp: i32 = exponent_part
.parse()
- .map_err(|e| format!("Invalid exponent: {}", e))?;
+ .map_err(|_| invalid_decimal_cast(original_str, precision,
scale))?;
(mantissa_part, exp)
} else {
@@ -2441,13 +2469,13 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32),
String> {
};
if mantissa_str.starts_with('+') || mantissa_str.starts_with('-') {
- return Err("Invalid sign format".to_string());
+ return Err(invalid_decimal_cast(original_str, precision, scale));
}
let (integral_part, fractional_part) = match mantissa_str.find('.') {
Some(dot_pos) => {
if mantissa_str[dot_pos + 1..].contains('.') {
- return Err("Multiple decimal points".to_string());
+ return Err(invalid_decimal_cast(original_str, precision,
scale));
}
(&mantissa_str[..dot_pos], &mantissa_str[dot_pos + 1..])
}
@@ -2455,15 +2483,15 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32),
String> {
};
if integral_part.is_empty() && fractional_part.is_empty() {
- return Err("No digits found".to_string());
+ return Err(invalid_decimal_cast(original_str, precision, scale));
}
if !integral_part.is_empty() && !integral_part.bytes().all(|b|
b.is_ascii_digit()) {
- return Err("Invalid integral part".to_string());
+ return Err(invalid_decimal_cast(original_str, precision, scale));
}
if !fractional_part.is_empty() && !fractional_part.bytes().all(|b|
b.is_ascii_digit()) {
- return Err("Invalid fractional part".to_string());
+ return Err(invalid_decimal_cast(original_str, precision, scale));
}
// Parse integral part
@@ -2473,7 +2501,7 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32),
String> {
} else {
integral_part
.parse()
- .map_err(|_| "Invalid integral part".to_string())?
+ .map_err(|_| invalid_decimal_cast(original_str, precision, scale))?
};
// Parse fractional part
@@ -2483,14 +2511,14 @@ fn parse_decimal_str(s: &str) -> Result<(i128, i32),
String> {
} else {
fractional_part
.parse()
- .map_err(|_| "Invalid fractional part".to_string())?
+ .map_err(|_| invalid_decimal_cast(original_str, precision, scale))?
};
// Combine: value = integral * 10^fractional_scale + fractional
let mantissa = integral_value
.checked_mul(10_i128.pow(fractional_scale as u32))
.and_then(|v| v.checked_add(fractional_value))
- .ok_or("Overflow in mantissa calculation")?;
+ .ok_or_else(|| invalid_decimal_cast(original_str, precision, scale))?;
let final_mantissa = if negative { -mantissa } else { mantissa };
// final scale = fractional_scale - exponent
diff --git a/native/spark-expr/src/error.rs b/native/spark-expr/src/error.rs
index 4b00b70eb..c39a05cd4 100644
--- a/native/spark-expr/src/error.rs
+++ b/native/spark-expr/src/error.rs
@@ -39,6 +39,9 @@ pub enum SparkError {
scale: i8,
},
+ #[error("[NUMERIC_OUT_OF_SUPPORTED_RANGE] The value {value} cannot be
interpreted as a numeric since it has more than 38 digits.")]
+ NumericOutOfRange { value: String },
+
#[error("[CAST_OVERFLOW] The value {value} of the type \"{from_type}\"
cannot be cast to \"{to_type}\" \
due to an overflow. Use `try_cast` to tolerate overflow and return
NULL instead. If necessary \
set \"spark.sql.ansi.enabled\" to \"false\" to bypass this error.")]
diff --git a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
index 26bb810b7..269925be4 100644
--- a/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
+++ b/spark/src/test/scala/org/apache/comet/CometCastSuite.scala
@@ -33,7 +33,6 @@ import org.apache.spark.sql.functions.col
import org.apache.spark.sql.internal.SQLConf
import org.apache.spark.sql.types.{ArrayType, BooleanType, ByteType, DataType,
DataTypes, DecimalType, IntegerType, LongType, ShortType, StringType,
StructField, StructType}
-import org.apache.comet.CometSparkSessionExtensions.isSpark40Plus
import org.apache.comet.expressions.{CometCast, CometEvalMode}
import org.apache.comet.rules.CometScanTypeChecker
import org.apache.comet.serde.Compatible
@@ -709,8 +708,6 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("cast StringType to DecimalType(10,2) (does not support fullwidth
unicode digits)") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
- // TODO fix for Spark 4.0.0
- assume(!isSpark40Plus)
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(10, 2), testAnsi =
ansiEnabled))
@@ -719,18 +716,38 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("cast StringType to DecimalType(2,2)") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
- // TODO fix for Spark 4.0.0
- assume(!isSpark40Plus)
val values = gen.generateStrings(dataSize, numericPattern, 12).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(2, 2), testAnsi =
ansiEnabled))
}
}
+ test("cast StringType to DecimalType check if right exception message is
thrown") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
+ val values = Seq("d11307\n").toDF("a")
+ Seq(true, false).foreach(ansiEnabled =>
+ castTest(values, DataTypes.createDecimalType(2, 2), testAnsi =
ansiEnabled))
+ }
+ }
+
+ test("cast StringType to DecimalType(2,2) check if right exception is being
thrown") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
+ val values = gen.generateInts(10000).map(" " + _).toDF("a")
+ Seq(true, false).foreach(ansiEnabled =>
+ castTest(values, DataTypes.createDecimalType(2, 2), testAnsi =
ansiEnabled))
+ }
+ }
+
+ test("cast StringType to DecimalType(38,10) high precision - check 0
mantissa") {
+ withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
+ val values = Seq("0e31", "000e3375", "0e40", "0E+695",
"0e5887677").toDF("a")
+ Seq(true, false).foreach(ansiEnabled =>
+ castTest(values, DataTypes.createDecimalType(38, 10), testAnsi =
ansiEnabled))
+ }
+ }
+
test("cast StringType to DecimalType(38,10) high precision") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
- // TODO fix for Spark 4.0.0
- assume(!isSpark40Plus)
val values = gen.generateStrings(dataSize, numericPattern, 38).toDF("a")
Seq(true, false).foreach(ansiEnabled =>
castTest(values, DataTypes.createDecimalType(38, 10), testAnsi =
ansiEnabled))
@@ -739,8 +756,6 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("cast StringType to DecimalType(10,2) basic values") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
- // TODO fix for Spark 4.0.0
- assume(!isSpark40Plus)
val values = Seq(
"123.45",
"-67.89",
@@ -766,8 +781,6 @@ class CometCastSuite extends CometTestBase with
AdaptiveSparkPlanHelper {
test("cast StringType to Decimal type scientific notation") {
withSQLConf(CometConf.getExprAllowIncompatConfigKey(classOf[Cast]) ->
"true") {
- // TODO fix for Spark 4.0.0
- assume(!isSpark40Plus)
val values = Seq(
"1.23E-5",
"1.23e10",
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]