andygrove commented on code in PR #615:
URL: https://github.com/apache/datafusion-comet/pull/615#discussion_r1676466184


##########
core/src/execution/datafusion/expressions/cast.rs:
##########
@@ -1208,6 +1277,260 @@ fn do_cast_string_to_int<
     Ok(Some(result))
 }
 
+fn cast_string_to_decimal128(
+    str: &str,
+    precision: u8,
+    scale: i8,
+    eval_mode: EvalMode,
+) -> CometResult<Option<i128>> {
+    let cast = match parse_decimal::<Decimal128Type>(str, precision, scale) {
+        Some(v) => v,
+        None => {
+            if eval_mode == EvalMode::Ansi {
+                let type_name = format!("DECIMAL({},{})", precision, scale);
+                return none_or_err(eval_mode, &type_name, str);
+            } else {
+                return Ok(None);
+            }
+        }
+    };
+    Ok(Some(cast))
+}
+
+fn cast_string_to_decimal256(
+    str: &str,
+    precision: u8,
+    scale: i8,
+    eval_mode: EvalMode,
+) -> CometResult<Option<i256>> {
+    let cast = match parse_decimal::<Decimal256Type>(str, precision, scale) {
+        Some(v) => v,
+        None => {
+            if eval_mode == EvalMode::Ansi {
+                let type_name = format!("DECIMAL({},{})", precision, scale);
+                return none_or_err(eval_mode, &type_name, str);
+            } else {
+                return Ok(None);
+            }
+        }
+    };
+    Ok(Some(cast))
+}
+
+/// Copied from arrow-rs, modified to replicate Spark's behavior
+pub fn parse_decimal<T: DecimalType>(s: &str, precision: u8, scale: i8) -> 
Option<T::Native> {
+    let mut result: <T as ArrowPrimitiveType>::Native = T::Native::usize_as(0);
+    let mut fractionals = 0;
+    let mut digits: u8 = 0;
+    let mut exponent: i8 = 0;
+    let mut leading_zeros: i8 = 0;
+    let mut has_exponent = false;
+    let mut has_negative_exp = false;
+    let base = T::Native::usize_as(10);
+
+    let bs = s.as_bytes();
+    let (bs, negative) = match bs.first() {
+        Some(b'-') => (&bs[1..], true),
+        Some(b'+') => (&bs[1..], false),
+        _ => (bs, false),
+    };
+
+    if bs.is_empty() {
+        return None;
+    }
+
+    let mut bs = bs.iter();
+    while let Some(b) = bs.next() {
+        match b {
+            b'0'..=b'9' => {
+                if digits == 0 && *b == b'0' {
+                    leading_zeros += 1;
+                    continue;
+                }
+                digits += 1;
+                result = result.mul_wrapping(base);
+                result = result.add_wrapping(T::Native::usize_as((b - b'0') as 
usize));
+            }
+            b'.' => {
+                while let Some(b) = bs.next() {
+                    if *b == b'e' || *b == b'E' {
+                        match parse_exponent(
+                            bs.by_ref(),
+                            &mut digits,
+                            &mut fractionals,
+                            &mut leading_zeros,
+                            &mut has_exponent,
+                            &mut has_negative_exp,
+                        ) {
+                            None => {
+                                return None;
+                            }
+                            Some(v) => {
+                                exponent = v;
+                                continue;
+                            }
+                        }
+                    }
+                    if !b.is_ascii_digit() {
+                        return None;
+                    }
+                    fractionals += 1;
+                    digits += 1;
+                    result = result.mul_wrapping(base);
+                    result = result.add_wrapping(T::Native::usize_as((b - 
b'0') as usize));
+                }
+
+                // +00. is a valid decimal value in Spark but +. is invalid.
+                if digits == 0 && leading_zeros == 0 {
+                    return None;
+                }
+            }
+            b'e' | b'E' => match parse_exponent(
+                &mut bs,
+                &mut digits,
+                &mut fractionals,
+                &mut leading_zeros,
+                &mut has_exponent,
+                &mut has_negative_exp,
+            ) {
+                None => {
+                    return None;
+                }
+                Some(v) => {
+                    exponent = v;
+                    continue;
+                }
+            },
+            _ => {
+                return None;
+            }
+        }
+    }
+
+    result = adjust_decimal_scale::<T>(
+        result,
+        precision,
+        scale,
+        base,
+        exponent,
+        fractionals,
+        digits,
+        has_exponent,
+        has_negative_exp,
+    )?;
+
+    return Some(if negative {
+        result.neg_wrapping()
+    } else {
+        result
+    });
+}
+
/// Parses the exponent part of a decimal literal (the bytes after `e`/`E`) from `bs`
/// and folds it into the mantissa bookkeeping kept by `parse_decimal`.
///
/// Iterates `bs` to the end (stopping early on the first invalid byte). The exponent
/// is at most one `+`/`-` sign, which must come before the digits, followed by at
/// least one digit. Sets `has_exponent`, and `has_negative_exp` for a `-` sign.
///
/// For a positive exponent, existing fractional digits absorb the shift first:
/// `fractionals` is reduced and only the remaining shift is returned. For a negative
/// exponent its magnitude is added to `fractionals`. The returned (non-negative)
/// shift is also added to `digits`.
///
/// Returns `None` if no mantissa digit or zero was seen before the exponent, the
/// exponent is malformed or exceeds `i8::MAX`, or `digits`/`fractionals` would
/// overflow.
fn parse_exponent(
    bs: &mut std::slice::Iter<u8>,
    digits: &mut u8,
    fractionals: &mut i8,
    leading_zeros: &mut i8,
    has_exponent: &mut bool,
    has_negative_exp: &mut bool,
) -> Option<i8> {
    // An exponent needs a mantissa in front of it (e.g. "e5" or "+.e5" are invalid).
    if *digits == 0 && *leading_zeros == 0 {
        return None;
    }

    let mut exponent: i8 = 0;
    let mut has_exp_sign = false;
    let mut has_exp_digits = false;
    *has_exponent = true;

    for &b in bs.by_ref() {
        match b {
            // Spark parses decimals via java.math.BigDecimal, whose grammar allows a
            // single optional sign before the exponent digits: reject "1e+-5",
            // "1e++5", "1e5-".
            b'+' | b'-' if !has_exp_sign && !has_exp_digits => {
                has_exp_sign = true;
                if b == b'-' {
                    *has_negative_exp = true;
                }
            }
            b'0'..=b'9' => {
                exponent = exponent.checked_mul(10)?.checked_add((b - b'0') as i8)?;
                has_exp_digits = true;
            }
            _ => return None,
        }
    }

    if !has_exp_digits {
        return None;
    }

    if *has_negative_exp {
        *fractionals = fractionals.checked_add(exponent)?;
    } else {
        // Both operands are non-negative here, so these subtractions cannot overflow.
        let cur_fractional = *fractionals;
        *fractionals = (*fractionals - exponent).max(0);
        exponent = (exponent - cur_fractional).max(0);
    }
    // `exponent` is non-negative at this point; checked to avoid u8 overflow.
    *digits = digits.checked_add(exponent.unsigned_abs())?;

    Some(exponent)
}
+
+fn adjust_decimal_scale<T: DecimalType + ArrowPrimitiveType>(
+    result: T::Native,
+    precision: u8,
+    scale: i8,
+    base: T::Native,
+    exponent: i8,
+    fractionals: i8,
+    digits: u8,
+    has_exponent: bool,
+    has_negative_exp: bool,
+) -> Option<T::Native> {
+    let mut res = result;
+
+    match fractionals.cmp(&scale) {
+        std::cmp::Ordering::Less => {
+            let drop = scale - fractionals;
+            if drop as u8 + digits > precision {
+                if res == T::Native::usize_as(0) && (exponent as u8) < 38 {
+                    return Some(res);
+                }
+                return None;
+            }
+            if has_exponent {
+                res = if !has_negative_exp {
+                    res.mul_wrapping(base.pow_wrapping(exponent as _))
+                } else {
+                    res.div_wrapping(base.pow_wrapping(exponent.abs() as _))
+                };
+            }
+            res = res.mul_wrapping(base.pow_wrapping(drop as _));
+        }
+        std::cmp::Ordering::Greater => {
+            // Since the fractional part is greater than the scale, we need to 
round the result
+            let diff = fractionals - scale;
+            let divisor = base.pow_wrapping(diff as _);
+            let quotient = res.div_wrapping(divisor);
+            let remainder = res.sub_wrapping(quotient.mul_wrapping(divisor));
+            if remainder >= 
T::Native::usize_as(5).mul_wrapping(base.pow_wrapping((diff - 1) as _))
+            {
+                res = quotient.add_wrapping(T::Native::usize_as(1));
+            } else {
+                res = quotient;
+            }
+        }
+        std::cmp::Ordering::Equal => {
+            if digits > precision {
+                return None;
+            }
+        }
+    }
+
+    return Some(res);

Review Comment:
   ```suggestion
       Some(res)
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to