Copilot commented on code in PR #228:
URL: https://github.com/apache/sedona-db/pull/228#discussion_r2449537140


##########
python/sedonadb/src/udf.rs:
##########
@@ -0,0 +1,341 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{ffi::CString, iter::zip, sync::Arc};
+
+use arrow_array::{
+    ffi::{FFI_ArrowArray, FFI_ArrowSchema},
+    ArrayRef,
+};
+use arrow_schema::Field;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDF, Volatility};
+use datafusion_ffi::udf::FFI_ScalarUDF;
+use pyo3::{
+    pyclass, pyfunction, pymethods,
+    types::{PyAnyMethods, PyCapsule, PyTuple},
+    Bound, PyObject, Python,
+};
+use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
+use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
+
+use crate::{
+    error::PySedonaError,
+    import_from::{check_pycapsule, import_arg_matcher, import_arrow_array, 
import_sedona_type},
+    schema::PySedonaType,
+};
+
+#[pyfunction]
+pub fn sedona_scalar_udf<'py>(
+    py: Python<'py>,
+    py_invoke_batch: PyObject,
+    py_return_type: PyObject,
+    py_input_types: Option<Vec<PyObject>>,
+    volatility: &str,
+    name: &str,
+) -> Result<PySedonaScalarUdf, PySedonaError> {
+    let volatility = match volatility {
+        "immutable" => Volatility::Immutable,
+        "stable" => Volatility::Stable,
+        "volatile" => Volatility::Volatile,
+        v => {
+            return Err(PySedonaError::SedonaPython(format!(
+                "Expected one of 'immutable', 'stable', or 'volatile' but got 
'{v}'"
+            )));
+        }
+    };
+
+    let scalar_kernel = sedona_scalar_kernel(py, py_input_types, 
py_return_type, py_invoke_batch)?;
+    let sedona_scalar_udf =
+        SedonaScalarUDF::new(name, vec![Arc::new(scalar_kernel)], volatility, 
None);
+
+    Ok(PySedonaScalarUdf {
+        inner: sedona_scalar_udf,
+    })
+}
+
+fn sedona_scalar_kernel<'py>(
+    py: Python<'py>,
+    input_types: Option<Vec<PyObject>>,
+    py_return_field: PyObject,
+    py_invoke_batch: PyObject,
+) -> Result<PySedonaScalarKernel, PySedonaError> {
+    let matcher = if let Some(input_types) = input_types {
+        let arg_matchers = input_types
+            .iter()
+            .map(|obj| import_arg_matcher(obj.bind(py)))
+            .collect::<Result<Vec<_>, _>>()?;
+        let return_type = import_sedona_type(py_return_field.bind(py))?;
+        Some(ArgMatcher::new(arg_matchers, return_type))
+    } else {
+        None
+    };
+
+    let kernel_impl = PySedonaScalarKernel {
+        matcher,
+        py_return_field,
+        py_invoke_batch,
+    };
+
+    Ok(kernel_impl)
+}
+
+#[derive(Debug)]
+struct PySedonaScalarKernel {
+    matcher: Option<ArgMatcher>,
+    py_return_field: PyObject,
+    py_invoke_batch: PyObject,
+}
+
+impl SedonaScalarKernel for PySedonaScalarKernel {
+    fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
+        Err(PySedonaError::SedonaPython("Unexpected call to 
return_type()".to_string()).into())
+    }
+
+    fn invoke_batch(
+        &self,
+        _arg_types: &[SedonaType],
+        _args: &[ColumnarValue],
+    ) -> Result<ColumnarValue> {
+        Err(PySedonaError::SedonaPython("Unexpected call to 
invoke_batch()".to_string()).into())
+    }
+
+    fn return_type_from_args_and_scalars(
+        &self,
+        args: &[SedonaType],
+        scalar_args: &[Option<&ScalarValue>],
+    ) -> Result<Option<SedonaType>> {
+        if let Some(matcher) = &self.matcher {
+            let return_type = matcher.match_args(args)?;
+            return Ok(return_type);
+        }
+
+        let return_type = Python::with_gil(|py| -> Result<Option<SedonaType>, 
PySedonaError> {
+            let py_sedona_types = args
+                .iter()
+                .map(|arg| -> Result<_, PySedonaError> { 
Ok(PySedonaType::new(arg.clone())) })
+                .collect::<Result<Vec<_>, _>>()?;
+            let py_scalar_values = zip(&py_sedona_types, scalar_args)
+                .map(|(sedona_type, maybe_arg)| {
+                    maybe_arg.map(|arg| PySedonaValue {
+                        sedona_type: sedona_type.clone(),
+                        value: ColumnarValue::Scalar(arg.clone()),
+                        num_rows: 1,
+                    })
+                })
+                .collect::<Vec<_>>();
+
+            let py_return_field =
+                self.py_return_field
+                    .call(py, (py_sedona_types, py_scalar_values), None)?;
+            if py_return_field.is_none(py) {
+                return Ok(None);
+            }
+
+            let return_type = import_sedona_type(py_return_field.bind(py))?;
+            Ok(Some(return_type))
+        })?;
+
+        Ok(return_type)
+    }
+
+    fn invoke_batch_from_args(
+        &self,
+        arg_types: &[SedonaType],
+        args: &[ColumnarValue],
+        return_type: &SedonaType,
+        num_rows: usize,
+    ) -> Result<ColumnarValue> {
+        let result = Python::with_gil(|py| -> Result<ArrayRef, PySedonaError> {
+            let py_values = zip(arg_types, args)
+                .map(|(sedona_type, arg)| PySedonaValue {
+                    sedona_type: PySedonaType::new(sedona_type.clone()),
+                    value: arg.clone(),
+                    num_rows,
+                })
+                .collect::<Vec<_>>();
+
+            let py_return_type = PySedonaType::new(return_type.clone());
+            let py_args = PyTuple::new(py, py_values)?;
+
+            let result =
+                self.py_invoke_batch
+                    .call(py, (py_args, py_return_type, num_rows), None)?;
+            let result_bound = result.bind(py);
+            if !result_bound.hasattr("__arrow_c_array__")? {
+                return Err(
+                    PySedonaError::SedonaPython(
+                        "Expected result of user-defined function to return an 
object implementing __arrow_c_array__()".to_string()
+                    )
+                );
+            }
+
+            let (result_field, result_array) = 
import_arrow_array(result_bound)?;
+            let result_sedona_type = 
SedonaType::from_storage_field(&result_field)?;
+
+            if return_type != &result_sedona_type {
+                let return_type_storage = 
SedonaType::Arrow(return_type.storage_type().clone());
+                if return_type_storage != result_sedona_type {
+                    return Err(PySedonaError::SedonaPython(format!(
+                        "Expected result of user-defined function to return 
array of type {return_type} or its storage but got {result_sedona_type}"
+                    )));
+                }
+            }
+
+            if result_array.len() != num_rows {
+                return Err(PySedonaError::SedonaPython(format!(
+                    "Expected result of user-defined function to return array 
of length {num_rows} but got {}",
+                    result_array.len()
+                )));
+            }
+
+            Ok(result_array)
+        })?;
+
+        if args.is_empty() {
+            return Ok(ColumnarValue::Array(result));
+        }
+
+        for arg in args {
+            match arg {
+                ColumnarValue::Array(_) => return 
Ok(ColumnarValue::Array(result)),
+                ColumnarValue::Scalar(_) => {}
+            }
+        }
+
+        Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+            &result, 0,
+        )?))
+    }
+}
+
+#[pyclass]
+#[derive(Clone)]
+pub struct PySedonaScalarUdf {
+    pub inner: SedonaScalarUDF,
+}
+
+#[pymethods]
+impl PySedonaScalarUdf {
+    fn __datafusion_scalar_udf__<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
+        let capsule_name = CString::new("datafusion_scalar_udf").unwrap();
+        let scalar_udf: ScalarUDF = self.inner.clone().into();
+        let ffi_scalar_udf = FFI_ScalarUDF::from(Arc::new(scalar_udf));
+        Ok(PyCapsule::new(py, ffi_scalar_udf, Some(capsule_name))?)
+    }
+}
+
+#[pyclass]
+#[derive(Debug)]
+pub struct PySedonaValue {
+    pub sedona_type: PySedonaType,
+    pub value: ColumnarValue,
+    pub num_rows: usize,
+}
+
+#[pymethods]
+impl PySedonaValue {
+    #[getter]
+    fn r#type(&self) -> Result<PySedonaType, PySedonaError> {
+        Ok(self.sedona_type.clone())
+    }
+
+    fn is_scalar(&self) -> bool {
+        matches!(&self.value, ColumnarValue::Scalar(_))
+    }
+
+    #[getter]
+    fn storage(&self) -> Result<Self, PySedonaError> {
+        Ok(PySedonaValue {
+            sedona_type: PySedonaType {
+                inner: 
SedonaType::Arrow(self.sedona_type.inner.storage_type().clone()),
+            },
+            value: self.value.clone(),
+            num_rows: self.num_rows,
+        })
+    }
+
+    fn to_array(&self) -> Result<Self, PySedonaError> {
+        Ok(PySedonaValue {
+            sedona_type: self.sedona_type.clone(),
+            value: ColumnarValue::Array(self.value.to_array(self.num_rows)?),
+            num_rows: self.num_rows,
+        })
+    }
+
+    fn __arrow_c_schema__<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
+        let schema_capsule_name = CString::new("arrow_schema").unwrap();
+        let storage_field = self.sedona_type.inner.to_storage_field("", true)?;
+        let ffi_schema = FFI_ArrowSchema::try_from(storage_field)?;
+        Ok(PyCapsule::new(py, ffi_schema, Some(schema_capsule_name))?)
+    }
+
+    #[pyo3(signature = (requsted_schema=None))]
+    fn __arrow_c_array__<'py>(
+        &self,
+        py: Python<'py>,
+        requsted_schema: Option<Bound<PyCapsule>>,
+    ) -> Result<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>), PySedonaError> 
{
+        if let Some(requested_schema) = requsted_schema {

Review Comment:
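   Corrected spelling of 'requsted_schema' to 'requested_schema'. The name in
   `#[pyo3(signature = ...)]` is the Python-visible keyword, and the Arrow
   PyCapsule Interface calls this parameter `requested_schema`, so with the
   misspelling a caller passing `requested_schema=` by keyword would get a
   `TypeError`.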
   ```suggestion
       #[pyo3(signature = (requested_schema=None))]
       fn __arrow_c_array__<'py>(
           &self,
           py: Python<'py>,
           requested_schema: Option<Bound<PyCapsule>>,
       ) -> Result<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>), PySedonaError> {
           if let Some(requested_schema) = requested_schema {
   ```



##########
python/sedonadb/src/udf.rs:
##########
@@ -0,0 +1,341 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::{ffi::CString, iter::zip, sync::Arc};
+
+use arrow_array::{
+    ffi::{FFI_ArrowArray, FFI_ArrowSchema},
+    ArrayRef,
+};
+use arrow_schema::Field;
+use datafusion_common::{Result, ScalarValue};
+use datafusion_expr::{ColumnarValue, ScalarUDF, Volatility};
+use datafusion_ffi::udf::FFI_ScalarUDF;
+use pyo3::{
+    pyclass, pyfunction, pymethods,
+    types::{PyAnyMethods, PyCapsule, PyTuple},
+    Bound, PyObject, Python,
+};
+use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
+use sedona_schema::{datatypes::SedonaType, matchers::ArgMatcher};
+
+use crate::{
+    error::PySedonaError,
+    import_from::{check_pycapsule, import_arg_matcher, import_arrow_array, 
import_sedona_type},
+    schema::PySedonaType,
+};
+
+#[pyfunction]
+pub fn sedona_scalar_udf<'py>(
+    py: Python<'py>,
+    py_invoke_batch: PyObject,
+    py_return_type: PyObject,
+    py_input_types: Option<Vec<PyObject>>,
+    volatility: &str,
+    name: &str,
+) -> Result<PySedonaScalarUdf, PySedonaError> {
+    let volatility = match volatility {
+        "immutable" => Volatility::Immutable,
+        "stable" => Volatility::Stable,
+        "volatile" => Volatility::Volatile,
+        v => {
+            return Err(PySedonaError::SedonaPython(format!(
+                "Expected one of 'immutable', 'stable', or 'volatile' but got 
'{v}'"
+            )));
+        }
+    };
+
+    let scalar_kernel = sedona_scalar_kernel(py, py_input_types, 
py_return_type, py_invoke_batch)?;
+    let sedona_scalar_udf =
+        SedonaScalarUDF::new(name, vec![Arc::new(scalar_kernel)], volatility, 
None);
+
+    Ok(PySedonaScalarUdf {
+        inner: sedona_scalar_udf,
+    })
+}
+
+fn sedona_scalar_kernel<'py>(
+    py: Python<'py>,
+    input_types: Option<Vec<PyObject>>,
+    py_return_field: PyObject,
+    py_invoke_batch: PyObject,
+) -> Result<PySedonaScalarKernel, PySedonaError> {
+    let matcher = if let Some(input_types) = input_types {
+        let arg_matchers = input_types
+            .iter()
+            .map(|obj| import_arg_matcher(obj.bind(py)))
+            .collect::<Result<Vec<_>, _>>()?;
+        let return_type = import_sedona_type(py_return_field.bind(py))?;
+        Some(ArgMatcher::new(arg_matchers, return_type))
+    } else {
+        None
+    };
+
+    let kernel_impl = PySedonaScalarKernel {
+        matcher,
+        py_return_field,
+        py_invoke_batch,
+    };
+
+    Ok(kernel_impl)
+}
+
+#[derive(Debug)]
+struct PySedonaScalarKernel {
+    matcher: Option<ArgMatcher>,
+    py_return_field: PyObject,
+    py_invoke_batch: PyObject,
+}
+
+impl SedonaScalarKernel for PySedonaScalarKernel {
+    fn return_type(&self, _args: &[SedonaType]) -> Result<Option<SedonaType>> {
+        Err(PySedonaError::SedonaPython("Unexpected call to 
return_type()".to_string()).into())
+    }
+
+    fn invoke_batch(
+        &self,
+        _arg_types: &[SedonaType],
+        _args: &[ColumnarValue],
+    ) -> Result<ColumnarValue> {
+        Err(PySedonaError::SedonaPython("Unexpected call to 
invoke_batch()".to_string()).into())
+    }
+
+    fn return_type_from_args_and_scalars(
+        &self,
+        args: &[SedonaType],
+        scalar_args: &[Option<&ScalarValue>],
+    ) -> Result<Option<SedonaType>> {
+        if let Some(matcher) = &self.matcher {
+            let return_type = matcher.match_args(args)?;
+            return Ok(return_type);
+        }
+
+        let return_type = Python::with_gil(|py| -> Result<Option<SedonaType>, 
PySedonaError> {
+            let py_sedona_types = args
+                .iter()
+                .map(|arg| -> Result<_, PySedonaError> { 
Ok(PySedonaType::new(arg.clone())) })
+                .collect::<Result<Vec<_>, _>>()?;
+            let py_scalar_values = zip(&py_sedona_types, scalar_args)
+                .map(|(sedona_type, maybe_arg)| {
+                    maybe_arg.map(|arg| PySedonaValue {
+                        sedona_type: sedona_type.clone(),
+                        value: ColumnarValue::Scalar(arg.clone()),
+                        num_rows: 1,
+                    })
+                })
+                .collect::<Vec<_>>();
+
+            let py_return_field =
+                self.py_return_field
+                    .call(py, (py_sedona_types, py_scalar_values), None)?;
+            if py_return_field.is_none(py) {
+                return Ok(None);
+            }
+
+            let return_type = import_sedona_type(py_return_field.bind(py))?;
+            Ok(Some(return_type))
+        })?;
+
+        Ok(return_type)
+    }
+
+    fn invoke_batch_from_args(
+        &self,
+        arg_types: &[SedonaType],
+        args: &[ColumnarValue],
+        return_type: &SedonaType,
+        num_rows: usize,
+    ) -> Result<ColumnarValue> {
+        let result = Python::with_gil(|py| -> Result<ArrayRef, PySedonaError> {
+            let py_values = zip(arg_types, args)
+                .map(|(sedona_type, arg)| PySedonaValue {
+                    sedona_type: PySedonaType::new(sedona_type.clone()),
+                    value: arg.clone(),
+                    num_rows,
+                })
+                .collect::<Vec<_>>();
+
+            let py_return_type = PySedonaType::new(return_type.clone());
+            let py_args = PyTuple::new(py, py_values)?;
+
+            let result =
+                self.py_invoke_batch
+                    .call(py, (py_args, py_return_type, num_rows), None)?;
+            let result_bound = result.bind(py);
+            if !result_bound.hasattr("__arrow_c_array__")? {
+                return Err(
+                    PySedonaError::SedonaPython(
+                        "Expected result of user-defined function to return an 
object implementing __arrow_c_array__()".to_string()
+                    )
+                );
+            }
+
+            let (result_field, result_array) = 
import_arrow_array(result_bound)?;
+            let result_sedona_type = 
SedonaType::from_storage_field(&result_field)?;
+
+            if return_type != &result_sedona_type {
+                let return_type_storage = 
SedonaType::Arrow(return_type.storage_type().clone());
+                if return_type_storage != result_sedona_type {
+                    return Err(PySedonaError::SedonaPython(format!(
+                        "Expected result of user-defined function to return 
array of type {return_type} or its storage but got {result_sedona_type}"
+                    )));
+                }
+            }
+
+            if result_array.len() != num_rows {
+                return Err(PySedonaError::SedonaPython(format!(
+                    "Expected result of user-defined function to return array 
of length {num_rows} but got {}",
+                    result_array.len()
+                )));
+            }
+
+            Ok(result_array)
+        })?;
+
+        if args.is_empty() {
+            return Ok(ColumnarValue::Array(result));
+        }
+
+        for arg in args {
+            match arg {
+                ColumnarValue::Array(_) => return 
Ok(ColumnarValue::Array(result)),
+                ColumnarValue::Scalar(_) => {}
+            }
+        }
+
+        Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(
+            &result, 0,
+        )?))
+    }
+}
+
+#[pyclass]
+#[derive(Clone)]
+pub struct PySedonaScalarUdf {
+    pub inner: SedonaScalarUDF,
+}
+
+#[pymethods]
+impl PySedonaScalarUdf {
+    fn __datafusion_scalar_udf__<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
+        let capsule_name = CString::new("datafusion_scalar_udf").unwrap();
+        let scalar_udf: ScalarUDF = self.inner.clone().into();
+        let ffi_scalar_udf = FFI_ScalarUDF::from(Arc::new(scalar_udf));
+        Ok(PyCapsule::new(py, ffi_scalar_udf, Some(capsule_name))?)
+    }
+}
+
+#[pyclass]
+#[derive(Debug)]
+pub struct PySedonaValue {
+    pub sedona_type: PySedonaType,
+    pub value: ColumnarValue,
+    pub num_rows: usize,
+}
+
+#[pymethods]
+impl PySedonaValue {
+    #[getter]
+    fn r#type(&self) -> Result<PySedonaType, PySedonaError> {
+        Ok(self.sedona_type.clone())
+    }
+
+    fn is_scalar(&self) -> bool {
+        matches!(&self.value, ColumnarValue::Scalar(_))
+    }
+
+    #[getter]
+    fn storage(&self) -> Result<Self, PySedonaError> {
+        Ok(PySedonaValue {
+            sedona_type: PySedonaType {
+                inner: 
SedonaType::Arrow(self.sedona_type.inner.storage_type().clone()),
+            },
+            value: self.value.clone(),
+            num_rows: self.num_rows,
+        })
+    }
+
+    fn to_array(&self) -> Result<Self, PySedonaError> {
+        Ok(PySedonaValue {
+            sedona_type: self.sedona_type.clone(),
+            value: ColumnarValue::Array(self.value.to_array(self.num_rows)?),
+            num_rows: self.num_rows,
+        })
+    }
+
+    fn __arrow_c_schema__<'py>(
+        &self,
+        py: Python<'py>,
+    ) -> Result<Bound<'py, PyCapsule>, PySedonaError> {
+        let schema_capsule_name = CString::new("arrow_schema").unwrap();
+        let storage_field = self.sedona_type.inner.to_storage_field("", true)?;
+        let ffi_schema = FFI_ArrowSchema::try_from(storage_field)?;
+        Ok(PyCapsule::new(py, ffi_schema, Some(schema_capsule_name))?)
+    }
+
+    #[pyo3(signature = (requsted_schema=None))]
+    fn __arrow_c_array__<'py>(
+        &self,
+        py: Python<'py>,
+        requsted_schema: Option<Bound<PyCapsule>>,
+    ) -> Result<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>), PySedonaError> 
{
+        if let Some(requested_schema) = requsted_schema {

Review Comment:
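   Corrected spelling of 'requsted_schema' to 'requested_schema'. The name in
   `#[pyo3(signature = ...)]` is the Python-visible keyword, and the Arrow
   PyCapsule Interface calls this parameter `requested_schema`, so with the
   misspelling a caller passing `requested_schema=` by keyword would get a
   `TypeError`.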
   ```suggestion
       #[pyo3(signature = (requested_schema=None))]
       fn __arrow_c_array__<'py>(
           &self,
           py: Python<'py>,
           requested_schema: Option<Bound<PyCapsule>>,
       ) -> Result<(Bound<'py, PyCapsule>, Bound<'py, PyCapsule>), PySedonaError> {
           if let Some(requested_schema) = requested_schema {
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to