andygrove commented on code in PR #15958:
URL: https://github.com/apache/datafusion/pull/15958#discussion_r2079751198


##########
datafusion/spark/src/function/math/ceil_floor.rs:
##########
@@ -0,0 +1,720 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{any::Any, sync::Arc};
+
+use arrow::{
+    array::{ArrayRef, ArrowNativeTypeOp, AsArray},
+    datatypes::{
+        DataType, Decimal128Type, Field, Float32Type, Float64Type, Int16Type, Int32Type,
+        Int64Type, Int8Type, DECIMAL128_MAX_PRECISION, DECIMAL128_MAX_SCALE,
+    },
+};
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{
+    ColumnarValue, ReturnFieldArgs, ScalarFunctionArgs, ScalarUDFImpl, Signature,
+    Volatility,
+};
+use num::integer::{div_ceil, div_floor};
+
+use crate::function::error_utils::{
+    generic_exec_err, generic_internal_err, invalid_arg_count_exec_err,
+    unsupported_data_type_exec_err, unsupported_data_types_exec_err,
+};
+
+use super::round_decimal_base;
+
+fn ceil_floor_coerce_types(name: &str, arg_types: &[DataType]) -> Result<Vec<DataType>> {
+    if arg_types.len() == 1 {
+        if arg_types[0].is_numeric() {
+            Ok(vec![ceil_floor_coerce_first_arg(name, &arg_types[0])?])
+        } else {
+            Err(unsupported_data_types_exec_err(
+                name,
+                "Numeric Type",
+                arg_types,
+            ))
+        }
+    } else if arg_types.len() == 2 {
+        if arg_types[0].is_numeric() && arg_types[1].is_integer() {
+            Ok(vec![
+                ceil_floor_coerce_first_arg(name, &arg_types[0])?,
+                DataType::Int32,
+            ])
+        } else {
+            Err(unsupported_data_types_exec_err(
+                name,
+                "Numeric Type for expr and Integer Type for target scale",
+                arg_types,
+            ))
+        }
+    } else {
+        Err(invalid_arg_count_exec_err(name, (1, 2), arg_types.len()))
+    }
+}
+
+fn ceil_floor_return_field_from_args(name: &str, args: ReturnFieldArgs) -> Result<Field> {
+    let arg_fields = args.arg_fields;
+    let scalar_arguments = args.scalar_arguments;
+    let return_type = if arg_fields.len() == 1 {
+        match &arg_fields[0].data_type() {
+            DataType::Decimal128(precision, scale) => {
+                let (precision, scale) =
+                    round_decimal_base(*precision as i32, *scale as i32, 0, true);
+                Ok(DataType::Decimal128(precision, scale))
+            }
+            DataType::Decimal256(precision, scale) => {
+                if *precision <= DECIMAL128_MAX_PRECISION
+                    && *scale <= DECIMAL128_MAX_SCALE
+                {
+                    let (precision, scale) =
+                        round_decimal_base(*precision as i32, *scale as i32, 0, false);
+                    Ok(DataType::Decimal128(precision, scale))
+                } else {
+                    Err(unsupported_data_type_exec_err(
+                        name,
+                        format!("Decimal Type must have precision <= {DECIMAL128_MAX_PRECISION} and scale <= {DECIMAL128_MAX_SCALE}").as_str(),
+                        arg_fields[0].data_type(),
+                    ))
+                }
+            }
+            _ => Ok(DataType::Int64),
+        }
+    } else if arg_fields.len() == 2 {
+        if let Some(target_scale) = scalar_arguments[1] {
+            let expr = &arg_fields[0].data_type();
+            let target_scale: i32 = match target_scale {
+                ScalarValue::Int8(Some(v)) => Ok(*v as i32),
+                ScalarValue::Int16(Some(v)) => Ok(*v as i32),
+                ScalarValue::Int32(Some(v)) => Ok(*v),
+                ScalarValue::Int64(Some(v)) => Ok(*v as i32),
+                ScalarValue::UInt8(Some(v)) => Ok(*v as i32),
+                ScalarValue::UInt16(Some(v)) => Ok(*v as i32),
+                ScalarValue::UInt32(Some(v)) => Ok(*v as i32),
+                ScalarValue::UInt64(Some(v)) => Ok(*v as i32),
+                _ => Err(unsupported_data_type_exec_err(
+                    name,
+                    "Target scale must be Integer literal",
+                    &target_scale.data_type(),
+                )),
+            }?;
+            if target_scale < -38 {
+                return Err(generic_exec_err(
+                    name,
+                    "Target scale must be greater than -38",
+                ));
+            }
+            let (precision, scale) = match expr {
+                DataType::Int8 => Ok((3, 0)),
+                DataType::UInt8 | DataType::Int16 => Ok((5, 0)),
+                DataType::UInt16 | DataType::Int32 => Ok((10, 0)),
+                DataType::UInt32 | DataType::UInt64 | DataType::Int64 => Ok((20, 0)),
+                DataType::Float32 => Ok((14, 7)),
+                DataType::Float64 => Ok((30, 15)),
+                DataType::Decimal128(precision, scale)
+                | DataType::Decimal256(precision, scale) => {
+                    if *precision <= DECIMAL128_MAX_PRECISION
+                        && *scale <= DECIMAL128_MAX_SCALE
+                    {
+                        Ok((*precision as i32, *scale as i32))
+                    } else {
+                        Err(unsupported_data_type_exec_err(
+                            name,
+                            format!("Decimal Type must have precision <= {DECIMAL128_MAX_PRECISION} and scale <= {DECIMAL128_MAX_SCALE}").as_str(),
+                            arg_fields[0].data_type(),
+                        ))
+                    }
+                }
+                _ => Err(unsupported_data_type_exec_err(
+                    name,
+                    "Numeric Type for expr",
+                    expr,
+                )),
+            }?;
+            let (precision, scale) =
+                round_decimal_base(precision, scale, target_scale, true);
+            Ok(DataType::Decimal128(precision, scale))
+        } else {
+            Err(generic_exec_err(
+                name,
+                "Target scale must be Integer literal, received: None",
+            ))
+        }
+    } else {
+        Err(invalid_arg_count_exec_err(name, (1, 2), arg_fields.len()))
+    }?;
+    Ok(Field::new(name.to_string(), return_type, true))
+}
+
+fn ceil_floor_coerce_first_arg(name: &str, arg_type: &DataType) -> Result<DataType> {
+    if arg_type.is_numeric() {
+        match arg_type {
+            DataType::UInt8 => Ok(DataType::Int16),
+            DataType::UInt16 => Ok(DataType::Int32),
+            DataType::UInt32 | DataType::UInt64 => Ok(DataType::Int64),
+            DataType::Decimal256(precision, scale) => {
+                if *precision <= DECIMAL128_MAX_PRECISION
+                    && *scale <= DECIMAL128_MAX_SCALE
+                {
+                    Ok(DataType::Decimal128(*precision, *scale))
+                } else {
+                    Err(unsupported_data_type_exec_err(
+                        name,
+                        format!("Decimal Type must have precision <= {DECIMAL128_MAX_PRECISION} and scale <= {DECIMAL128_MAX_SCALE}").as_str(),
+                        arg_type,
+                    ))
+                }
+            }
+            other => Ok(other.clone()),
+        }
+    } else {
+        Err(unsupported_data_type_exec_err(
+            name,
+            "First arg must be Numeric Type",
+            arg_type,
+        ))
+    }
+}
+
+#[inline]
+fn get_return_type_precision_scale(return_type: &DataType) -> Result<(u8, i8)> {
+    match return_type {
+        DataType::Decimal128(precision, scale) => Ok((*precision, *scale)),
+        other => Err(generic_internal_err(
+            "ceil",
+            format!("Expected return type to be Decimal128, got: {other}").as_str(),
+        )),
+    }
+}
+
+#[derive(Debug)]
+pub struct SparkCeil {
+    signature: Signature,
+}
+
+impl Default for SparkCeil {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl SparkCeil {
+    pub fn new() -> Self {
+        Self {
+            signature: Signature::user_defined(Volatility::Immutable),
+        }
+    }
+}
+
+impl ScalarUDFImpl for SparkCeil {
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn name(&self) -> &str {
+        "ceil"
+    }
+
+    fn signature(&self) -> &Signature {
+        &self.signature
+    }
+
+    fn return_type(&self, _arg_types: &[DataType]) -> Result<DataType> {
+        Err(generic_internal_err(
+            "ceil",
+            "`return_type` should not be called, call `return_type_from_args` instead",
+        ))
+    }
+
+    fn return_field_from_args(&self, args: ReturnFieldArgs) -> Result<Field> {
+        ceil_floor_return_field_from_args("ceil", args)
+    }
+
+    fn invoke_with_args(&self, args: ScalarFunctionArgs) -> Result<ColumnarValue> {
+        let arg_len = args.args.len();
+        let target_scale = if arg_len == 1 {
+            Ok(&None)
+        } else if arg_len == 2 {

Review Comment:
   I think it is fine to have both cases covered in a single expression in 
DataFusion. Sorry for the confusion.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to