Jefffrey commented on code in PR #22655:
URL: https://github.com/apache/datafusion/pull/22655#discussion_r3392336262
##########
datafusion/sqllogictest/test_files/math.slt:
##########
@@ -686,6 +686,22 @@ select gcd(-9223372036854775808, 0);
query error DataFusion error: Arrow error: Compute error: Signed integer
overflow in GCD\(0, \-9223372036854775808\)
select gcd(0, -9223372036854775808);
+# gcd decimal
Review Comment:
could we add test cases for:
- negative number inputs
- decimals with different scale/precision inputs
- inputs with a fractional part, e.g. `1.23`
##########
datafusion/functions/src/utils.rs:
##########
@@ -123,6 +123,7 @@ where
}
/// Computes a binary math function for input arrays using a specified
function.
+/// Deprecated, use [`calculate_binary_math_cast`] instead.
Review Comment:
I'm not sure if we should deprecate this, as the `calculate_binary_math_cast`
alternative introduces a new argument that is only really relevant for decimals
🤔
##########
datafusion/functions/src/math/gcd.rs:
##########
@@ -141,44 +227,271 @@ fn compute_gcd_with_scalar(arr: &ArrayRef, scalar:
Option<i64>) -> Result<Column
}
Some(scalar_value) => {
let result: PrimitiveArray<Int64Type> =
- prim.try_unary(|val| compute_gcd(val, scalar_value))?;
+ prim.try_unary(|val| gcd_signed_int(val, scalar_value))?;
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
}
None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
}
}
-/// Computes gcd of two unsigned integers using Binary GCD algorithm.
-pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
- if a == 0 {
- return b;
+#[cfg(test)]
+mod tests {
Review Comment:
could we move these tests to SLTs?
##########
datafusion/functions/src/utils.rs:
##########
@@ -133,6 +134,69 @@ pub fn calculate_binary_math<L, R, O, F>(
right: &ColumnarValue,
fun: F,
) -> Result<Arc<PrimitiveArray<O>>>
+where
+ L: ArrowPrimitiveType,
+ R: ArrowPrimitiveType,
+ O: ArrowPrimitiveType,
+ F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
+ R::Native: TryFrom<ScalarValue>,
+{
+ calculate_binary_math_cast::<L, R, O, F>(left, right, fun, &R::DATA_TYPE)
+}
+
+/// Computes a binary math function for input arrays using a specified function
+/// and applies rescaling to given precision and scale.
+/// Deprecated, use [`calculate_binary_decimal_math_cast`] instead.
+/// Generic types:
+/// - `L`: Left array decimal type
+/// - `R`: Right array primitive type
+/// - `O`: Output array decimal type
+/// - `F`: Functor computing `fun(l: L, r: R) -> Result<OutputType>`
+pub fn calculate_binary_decimal_math<L, R, O, F>(
+ left: &dyn Array,
+ right: &ColumnarValue,
+ fun: F,
+ precision: u8,
+ scale: i8,
+) -> Result<Arc<PrimitiveArray<O>>>
+where
+ L: DecimalType,
+ R: ArrowPrimitiveType,
+ O: DecimalType,
+ F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
+ R::Native: TryFrom<ScalarValue>,
+{
+ calculate_binary_decimal_math_cast::<L, R, O, F>(
+ left,
+ right,
+ fun,
+ precision,
+ scale,
+ &R::DATA_TYPE,
+ )
+}
+
+/// Computes a binary math function for input arrays using a specified
function.
+///
+/// It casts the right operand to `cast_target` instead of the default
`R::DATA_TYPE` to preserve
+/// the right operand scale.
+///
+/// # Type Parameters
+/// - `L`: Left array primitive type
+/// - `R`: Right array primitive type
+/// - `O`: Output array primitive type
+/// - `F`: Functor computing `fun(l: L, r: R) -> Result<OutputType>`
+/// # Arguments
+/// - `left`: Left input array
+/// - `right`: Right input array or scalar value
+/// - `fun`: Function of type `F`
+/// - `cast_target`: Data type to cast right operand to before applying
function
+pub fn calculate_binary_math_cast<L, R, O, F>(
Review Comment:
Should this just be a private function that the others use internally?
Otherwise it's a bit confusing to have this as public when only decimals can
really take advantage of `cast_target`.
##########
datafusion/functions/src/math/gcd.rs:
##########
@@ -76,37 +76,123 @@ impl ScalarUDFImpl for GcdFunc {
&self.signature
}
- fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
- Ok(DataType::Int64)
+ fn return_type(&self, arg_types: &[DataType]) -> Result<DataType> {
+ Ok(arg_types[0].clone())
+ }
+
+ fn coerce_types(&self, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+ let [arg1, arg2] = take_function_args(self.name(), arg_types)?;
+
+ let coerced_type = match (arg1, arg2) {
+ (DataType::Null, _) | (_, DataType::Null) => Ok(DataType::Int64),
+ (lhs, rhs) if lhs.is_integer() && rhs.is_integer() =>
Ok(DataType::Int64),
+ (lhs, rhs) if lhs.is_decimal() || rhs.is_decimal() => {
+ decimal_coercion(lhs, rhs).map(Ok).unwrap_or_else(|| {
+ exec_err!(
Review Comment:
in `coerce_types` the errors should be plan errors
##########
datafusion/functions/src/utils.rs:
##########
@@ -133,6 +134,69 @@ pub fn calculate_binary_math<L, R, O, F>(
right: &ColumnarValue,
fun: F,
) -> Result<Arc<PrimitiveArray<O>>>
+where
+ L: ArrowPrimitiveType,
+ R: ArrowPrimitiveType,
+ O: ArrowPrimitiveType,
+ F: Fn(L::Native, R::Native) -> Result<O::Native, ArrowError>,
+ R::Native: TryFrom<ScalarValue>,
+{
+ calculate_binary_math_cast::<L, R, O, F>(left, right, fun, &R::DATA_TYPE)
+}
+
+/// Computes a binary math function for input arrays using a specified function
+/// and applies rescaling to given precision and scale.
+/// Deprecated, use [`calculate_binary_decimal_math_cast`] instead.
Review Comment:
Instead of just stating it's deprecated, it might be better to explicitly mark
it as such with `#[deprecated]`.
##########
datafusion/functions/src/math/gcd.rs:
##########
@@ -141,44 +227,271 @@ fn compute_gcd_with_scalar(arr: &ArrayRef, scalar:
Option<i64>) -> Result<Column
}
Some(scalar_value) => {
let result: PrimitiveArray<Int64Type> =
- prim.try_unary(|val| compute_gcd(val, scalar_value))?;
+ prim.try_unary(|val| gcd_signed_int(val, scalar_value))?;
Ok(ColumnarValue::Array(Arc::new(result) as ArrayRef))
}
None => Ok(ColumnarValue::Scalar(ScalarValue::Int64(None))),
}
}
-/// Computes gcd of two unsigned integers using Binary GCD algorithm.
-pub(super) fn unsigned_gcd(mut a: u64, mut b: u64) -> u64 {
- if a == 0 {
- return b;
+#[cfg(test)]
+mod tests {
+ use super::*;
+ use crate::math::common::gcd_signed;
+ use arrow::array::{Array, Decimal128Array, Int64Array};
+ use arrow::datatypes::{DECIMAL128_MAX_PRECISION, Field};
+ use arrow_buffer::i256;
+ use datafusion_common::ScalarValue;
+ use datafusion_common::cast::{as_decimal128_array, as_int64_array};
+ use datafusion_common::config::ConfigOptions;
+ use std::sync::Arc;
+
+ #[test]
+ fn test_i64_array() {
+ let arg_fields = vec![
+ Field::new("a", DataType::Int64, true).into(),
+ Field::new("b", DataType::Int64, true).into(),
+ ];
+ let args = ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Array(Arc::new(Int64Array::from(vec![
+ 0, 2, 0, 2, 15, 20,
+ ]))),
+ ColumnarValue::Array(Arc::new(Int64Array::from(vec![
+ 0, 0, 2, 3, 10, 1000,
+ ]))),
+ ],
+ arg_fields,
+ number_rows: 6,
+ return_field: Field::new("f", DataType::Int64, true).into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+ let result = GcdFunc::new()
+ .invoke_with_args(args)
+ .expect("failed to initialize function");
+
+ match result {
+ ColumnarValue::Array(arr) => {
+ let values =
+ as_int64_array(&arr).expect("failed to convert result to
an array");
+ assert_eq!(values.len(), 6);
+ assert_eq!(values.value(0), 0);
+ assert_eq!(values.value(1), 2);
+ assert_eq!(values.value(2), 2);
+ assert_eq!(values.value(3), 1);
+ assert_eq!(values.value(4), 5);
+ assert_eq!(values.value(5), 20);
+ }
+ ColumnarValue::Scalar(_) => {
+ panic!("Expected an array value")
+ }
+ }
}
- if b == 0 {
- return a;
+
+ #[test]
+ fn test_decimal_scalar() {
+ let arg_fields = vec![
+ Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
true)
+ .into(),
+ Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
true)
+ .into(),
+ ];
+ let args = ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Scalar(ScalarValue::Decimal128(
+ Some(i128::from(2)),
+ DECIMAL128_MAX_PRECISION,
+ 0,
+ )),
+ ColumnarValue::Scalar(ScalarValue::Decimal128(
+ Some(i128::from(3)),
+ DECIMAL128_MAX_PRECISION,
+ 0,
+ )),
+ ],
+ arg_fields,
+ number_rows: 1,
+ return_field: Field::new(
+ "f",
+ DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
+ true,
+ )
+ .into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+ let result = GcdFunc::new()
+ .invoke_with_args(args)
+ .expect("failed to initialize function power");
+
+ match result {
+ ColumnarValue::Array(arr) => {
+ let ints = as_decimal128_array(&arr)
+ .expect("failed to convert result to an array");
+
+ assert_eq!(ints.len(), 1);
+ assert_eq!(ints.value(0), i128::from(1));
+ // Signature stays the same as input
+ assert_eq!(
+ *arr.data_type(),
+ DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0)
+ );
+ }
+ ColumnarValue::Scalar(_) => {
+ panic!("Expected an array value")
+ }
+ }
}
- let shift = (a | b).trailing_zeros();
- a >>= a.trailing_zeros();
- loop {
- b >>= b.trailing_zeros();
- if a > b {
- swap(&mut a, &mut b);
+ #[test]
+ fn test_decimal_array_scalar() {
+ let arg_fields = vec![
+ Field::new("a", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
true)
+ .into(),
+ Field::new("b", DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
true)
+ .into(),
+ ];
+ let args = ScalarFunctionArgs {
+ args: vec![
+ ColumnarValue::Array(Arc::new(
+ Decimal128Array::from(vec![2, 15])
+ .with_precision_and_scale(DECIMAL128_MAX_PRECISION, 0)
+ .unwrap(),
+ )),
+ ColumnarValue::Scalar(ScalarValue::Decimal128(
+ Some(i128::from(3)),
+ DECIMAL128_MAX_PRECISION,
+ 0,
+ )),
+ ],
+ arg_fields,
+ number_rows: 2,
+ return_field: Field::new(
+ "f",
+ DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0),
+ true,
+ )
+ .into(),
+ config_options: Arc::new(ConfigOptions::default()),
+ };
+ let result = GcdFunc::new()
+ .invoke_with_args(args)
+ .expect("failed to initialize function power");
+
+ match result {
+ ColumnarValue::Array(arr) => {
+ let ints = as_decimal128_array(&arr)
+ .expect("failed to convert result to an array");
+
+ assert_eq!(ints.len(), 2);
+ assert_eq!(ints.value(0), i128::from(1));
+ assert_eq!(ints.value(1), i128::from(3));
+ // Signature stays the same as input
+ assert_eq!(
+ *arr.data_type(),
+ DataType::Decimal128(DECIMAL128_MAX_PRECISION, 0)
+ );
+ }
+ ColumnarValue::Scalar(_) => {
+ panic!("Expected an array value")
+ }
}
- b -= a;
- if b == 0 {
- return a << shift;
+ }
+
+ #[test]
+ fn test_coercion() {
+ let mut coerced = GcdFunc::new()
+ .coerce_types(&[DataType::Int64, DataType::Int32])
+ .expect("coercion should succeed");
+ assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]);
+
+ coerced = GcdFunc::new()
+ .coerce_types(&[DataType::Decimal128(10, 2), DataType::Int32])
+ .expect("coercion should succeed");
+
+ assert_eq!(
+ coerced,
+ vec![DataType::Decimal128(12, 2), DataType::Decimal128(12, 2)]
+ );
+
+ coerced = GcdFunc::new()
+ .coerce_types(&[DataType::Decimal128(10, 2), DataType::Null])
+ .expect("coercion should succeed");
+
+ assert_eq!(coerced, vec![DataType::Int64, DataType::Int64]);
+ }
+
+ const GCD_COMMON_TEST_CASES: [(i64, i64, i64); 18] = [
+ // Basic cases
+ (48, 18, 6),
+ (54, 24, 6),
+ (100, 50, 50),
+ (17, 19, 1),
+ (21, 14, 7),
+ // Edge cases with 0
+ (0, 0, 0),
+ (0, 5, 5),
+ (10, 0, 10),
+ // Same numbers
+ (7, 7, 7),
+ (100, 100, 100),
+ // One is 1
+ (1, 1, 1),
+ (1, 100, 1),
+ (999, 1, 1),
+ // Large numbers
+ (1000000, 500000, 500000),
+ (123456, 789012, 12),
+ (999999, 111111, 111111),
+ // Powers of 2
+ (64, 128, 64),
+ (1024, 2048, 1024),
+ ];
+
+ #[test]
+ fn test_gcd_i64() {
+ let test_cases: Vec<(i64, i64, i64)> = [
+ GCD_COMMON_TEST_CASES.into(),
+ vec![
+ // Max value cases
+ (1, i64::MAX, 1),
+ (i64::MAX, 1, 1),
+ (i64::MAX, i64::MAX, i64::MAX),
+ ],
+ ]
+ .concat();
+
+ // Success cases
+ for (a, b, expected) in test_cases {
+ let actual = gcd_signed(a, b).expect("should succeed");
+ assert_eq!(
+ actual, expected,
+ "euclid_gcd({a}, {b}) expected {expected}, actual {actual}"
+ );
}
}
-}
-/// Computes greatest common divisor using Binary GCD algorithm.
-pub fn compute_gcd(x: i64, y: i64) -> Result<i64, ArrowError> {
- let a = x.unsigned_abs();
- let b = y.unsigned_abs();
- let r = unsigned_gcd(a, b);
- // The result can be up to 2^63 (e.g. gcd(i64::MIN, 0) or
- // gcd(i64::MIN, i64::MIN)), which does not fit into i64.
- r.try_into().map_err(|_| {
- ArrowError::ComputeError(format!("Signed integer overflow in GCD({x},
{y})"))
- })
+ #[test]
+ fn test_gcd_decimal128() {
Review Comment:
These tests for `gcd_signed` seem to duplicate the existing ones in
`common.rs` (which are already located closer to the source of `gcd_signed`).
##########
datafusion/functions/src/math/lcm.rs:
##########
@@ -15,25 +15,22 @@
// specific language governing permissions and limitations
Review Comment:
similar comments as gcd (where applicable)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]