This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 71742fb feat: Implement native st_pointz, st_pointm, st_pointzm (#13)
71742fb is described below
commit 71742fb788a78f225d5978361467193ae734b00b
Author: Peter Nguyen <[email protected]>
AuthorDate: Wed Sep 3 14:24:56 2025 -0700
feat: Implement native st_pointz, st_pointm, st_pointzm (#13)
* Initial implementation of st_point z, m, zm
* Add more null tests
* Add array and null tests
* Improve readability through comments and adding WKB_HEADER_SIZE variable
* Add benches
* Change sedona_internal_err! to unreachable!() since doc() can't propagate
a Result
* Remove unused import
* Add support for ArrayArrayArray and ArrayArrayArrayArray benchmarkargs
* Fix benches
* Add python integration tests
* Remove unnecessary num_coords field
* Fix python integration tests
* Add docstring to invoke_arrays
* Separate into separate doc functions
* Temporarily comment out new benches for convenient comparison
* Save progress (revert this)
* Revert st_point.rs to the original
* Add new st_pointzm.rs
* Re-enable rust benchs
* Re-use function
---
python/sedonadb/tests/functions/test_functions.py | 64 +++
rust/sedona-functions/benches/native-functions.rs | 37 ++
rust/sedona-functions/src/lib.rs | 1 +
rust/sedona-functions/src/register.rs | 3 +
rust/sedona-functions/src/st_pointzm.rs | 556 ++++++++++++++++++++++
rust/sedona-testing/src/benchmark_util.rs | 104 +++-
rust/sedona-testing/src/testers.rs | 3 +-
7 files changed, 765 insertions(+), 3 deletions(-)
diff --git a/python/sedonadb/tests/functions/test_functions.py
b/python/sedonadb/tests/functions/test_functions.py
index 306e0c5..870b960 100644
--- a/python/sedonadb/tests/functions/test_functions.py
+++ b/python/sedonadb/tests/functions/test_functions.py
@@ -588,6 +588,70 @@ def test_st_point(eng, x, y, expected):
)
[email protected]("eng", [SedonaDB, PostGIS])
[email protected](
+ ("x", "y", "z", "expected"),
+ [
+ (None, None, None, None),
+ (1, None, None, None),
+ (None, 1, None, None),
+ (None, None, 1, None),
+ (1, 1, 1, "POINT Z (1 1 1)"),
+ (1.0, 1.0, 1.0, "POINT Z (1 1 1)"),
+ (10, -1.5, 1.0, "POINT Z (10 -1.5 1)"),
+ ],
+)
+def test_st_pointz(eng, x, y, z, expected):
+ eng = eng.create_or_skip()
+ eng.assert_query_result(
+ f"SELECT ST_PointZ({val_or_null(x)}, {val_or_null(y)},
{val_or_null(z)})",
+ expected,
+ )
+
+
[email protected]("eng", [SedonaDB, PostGIS])
[email protected](
+ ("x", "y", "m", "expected"),
+ [
+ (None, None, None, None),
+ (1, None, None, None),
+ (None, 1, None, None),
+ (None, None, 1, None),
+ (1, 1, 1, "POINT M (1 1 1)"),
+ (1.0, 1.0, 1.0, "POINT M (1 1 1)"),
+ (10, -1.5, 1.0, "POINT M (10 -1.5 1)"),
+ ],
+)
+def test_st_pointm(eng, x, y, m, expected):
+ eng = eng.create_or_skip()
+ eng.assert_query_result(
+ f"SELECT ST_PointM({val_or_null(x)}, {val_or_null(y)},
{val_or_null(m)})",
+ expected,
+ )
+
+
[email protected]("eng", [SedonaDB, PostGIS])
[email protected](
+ ("x", "y", "z", "m", "expected"),
+ [
+ (None, None, None, None, None),
+ (1, None, None, None, None),
+ (None, 1, None, None, None),
+ (None, None, 1, None, None),
+ (None, None, None, 1, None),
+ (1, 1, 1, 1, "POINT ZM (1 1 1 1)"),
+ (1.0, 1.0, 1.0, 1.0, "POINT ZM (1 1 1 1)"),
+ (10, -1.5, 1.0, 1.0, "POINT ZM (10 -1.5 1 1)"),
+ ],
+)
+def test_st_pointzm(eng, x, y, z, m, expected):
+ eng = eng.create_or_skip()
+ eng.assert_query_result(
+ f"SELECT ST_PointZM({val_or_null(x)}, {val_or_null(y)},
{val_or_null(z)}, {val_or_null(m)})",
+ expected,
+ )
+
+
@pytest.mark.parametrize("eng", [SedonaDB, PostGIS])
@pytest.mark.parametrize(
("geom", "expected"),
diff --git a/rust/sedona-functions/benches/native-functions.rs
b/rust/sedona-functions/benches/native-functions.rs
index 4ea0bde..0aa4c1a 100644
--- a/rust/sedona-functions/benches/native-functions.rs
+++ b/rust/sedona-functions/benches/native-functions.rs
@@ -66,6 +66,43 @@ fn criterion_benchmark(c: &mut Criterion) {
BenchmarkArgs::ArrayArray(Float64(0.0, 100.0), Float64(0.0, 100.0)),
);
+ benchmark::scalar(
+ c,
+ &f,
+ "native",
+ "st_pointz",
+ BenchmarkArgs::ArrayArrayArray(
+ Float64(0.0, 100.0),
+ Float64(0.0, 100.0),
+ Float64(0.0, 100.0),
+ ),
+ );
+
+ benchmark::scalar(
+ c,
+ &f,
+ "native",
+ "st_pointm",
+ BenchmarkArgs::ArrayArrayArray(
+ Float64(0.0, 100.0),
+ Float64(0.0, 100.0),
+ Float64(0.0, 100.0),
+ ),
+ );
+
+ benchmark::scalar(
+ c,
+ &f,
+ "native",
+ "st_pointzm",
+ BenchmarkArgs::ArrayArrayArrayArray(
+ Float64(0.0, 100.0),
+ Float64(0.0, 100.0),
+ Float64(0.0, 100.0),
+ Float64(0.0, 100.0),
+ ),
+ );
+
benchmark::scalar(c, &f, "native", "st_hasz", Point);
benchmark::scalar(c, &f, "native", "st_hasz", LineString(10));
diff --git a/rust/sedona-functions/src/lib.rs b/rust/sedona-functions/src/lib.rs
index f633d09..c8d95db 100644
--- a/rust/sedona-functions/src/lib.rs
+++ b/rust/sedona-functions/src/lib.rs
@@ -41,6 +41,7 @@ mod st_isempty;
mod st_length;
mod st_perimeter;
mod st_point;
+mod st_pointzm;
mod st_setsrid;
mod st_transform;
pub mod st_union_aggr;
diff --git a/rust/sedona-functions/src/register.rs
b/rust/sedona-functions/src/register.rs
index b6a7cfc..3f7f931 100644
--- a/rust/sedona-functions/src/register.rs
+++ b/rust/sedona-functions/src/register.rs
@@ -81,6 +81,9 @@ pub fn default_function_set() -> FunctionSet {
crate::st_perimeter::st_perimeter_udf,
crate::st_point::st_geogpoint_udf,
crate::st_point::st_point_udf,
+ crate::st_pointzm::st_pointz_udf,
+ crate::st_pointzm::st_pointm_udf,
+ crate::st_pointzm::st_pointzm_udf,
crate::st_transform::st_transform_udf,
crate::st_setsrid::st_set_srid_udf,
crate::st_xyzm::st_m_udf,
diff --git a/rust/sedona-functions/src/st_pointzm.rs
b/rust/sedona-functions/src/st_pointzm.rs
new file mode 100644
index 0000000..2899d77
--- /dev/null
+++ b/rust/sedona-functions/src/st_pointzm.rs
@@ -0,0 +1,556 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use std::{
+ io::{Cursor, Write},
+ sync::Arc,
+ vec,
+};
+
+use arrow_array::{builder::BinaryBuilder, Array};
+use arrow_schema::DataType;
+use datafusion_common::cast::as_float64_array;
+use datafusion_common::error::Result;
+use datafusion_common::scalar::ScalarValue;
+use datafusion_common::DataFusionError;
+use datafusion_expr::{
+ scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation,
Volatility,
+};
+use geo_traits::Dimensions;
+use sedona_expr::scalar_udf::{ArgMatcher, SedonaScalarKernel, SedonaScalarUDF};
+use sedona_geometry::{
+ error::SedonaGeometryError,
+ wkb_factory::{write_wkb_coord, write_wkb_point_header},
+};
+use sedona_schema::datatypes::{SedonaType, WKB_GEOMETRY};
+
+use crate::executor::WkbExecutor;
+
+// 1 byte for endian(1) + 4 bytes for type(4)
+const WKB_HEADER_SIZE: usize = 5;
+
+/// ST_PointZ() scalar UDF implementation
+///
+/// Native implementation to create Z geometries from coordinates.
+pub fn st_pointz_udf() -> SedonaScalarUDF {
+ SedonaScalarUDF::new(
+ "st_pointz",
+ vec![Arc::new(STGeoFromPointZm {
+ out_type: WKB_GEOMETRY,
+ dim: Dimensions::Xyz,
+ })],
+ Volatility::Immutable,
+ Some(three_coord_point_doc("ST_PointZ", "Geometry", "Z")),
+ )
+}
+
+/// ST_PointM() scalar UDF implementation
+///
+/// Native implementation to create M geometries from coordinates.
+pub fn st_pointm_udf() -> SedonaScalarUDF {
+ SedonaScalarUDF::new(
+ "st_pointm",
+ vec![Arc::new(STGeoFromPointZm {
+ out_type: WKB_GEOMETRY,
+ dim: Dimensions::Xym,
+ })],
+ Volatility::Immutable,
+ Some(three_coord_point_doc("ST_PointM", "Geometry", "M")),
+ )
+}
+
+/// ST_PointZM() scalar UDF implementation
+///
+/// Native implementation to create ZM geometries from coordinates.
+pub fn st_pointzm_udf() -> SedonaScalarUDF {
+ SedonaScalarUDF::new(
+ "st_pointzm",
+ vec![Arc::new(STGeoFromPointZm {
+ out_type: WKB_GEOMETRY,
+ dim: Dimensions::Xyzm,
+ })],
+ Volatility::Immutable,
+ Some(xyzm_point_doc("ST_PointZM", "Geometry")),
+ )
+}
+
+fn three_coord_point_doc(name: &str, out_type_name: &str, third_dim: &str) ->
Documentation {
+ Documentation::builder(
+ DOC_SECTION_OTHER,
+ format!(
+ "Construct a Point {} from X, Y and {}",
+ out_type_name.to_lowercase(),
+ third_dim
+ ),
+ format!("{name} (x: Double, y: Double, z: Double)"),
+ )
+ .with_argument("x", "double: X value")
+ .with_argument("y", "double: Y value")
+ .with_argument(
+ third_dim.to_lowercase(),
+ format!("double: {} value", third_dim),
+ )
+ .with_sql_example(format!("{name}(-64.36, 45.09, 100.0)"))
+ .build()
+}
+
+fn xyzm_point_doc(name: &str, out_type_name: &str) -> Documentation {
+ Documentation::builder(
+ DOC_SECTION_OTHER,
+ format!(
+ "Construct a Point {} from X, Y, Z and M",
+ out_type_name.to_lowercase()
+ ),
+ format!("{name} (x: Double, y: Double, z: Double)"),
+ )
+ .with_argument("x", "double: X value")
+ .with_argument("y", "double: Y value")
+ .with_argument("z", "double: Z value")
+ .with_argument("m", "double: M value")
+ .with_sql_example(format!("{name}(-64.36, 45.09, 100.0, 50.0)"))
+ .build()
+}
+
+#[derive(Debug)]
+struct STGeoFromPointZm {
+ out_type: SedonaType,
+ dim: Dimensions,
+}
+
+impl SedonaScalarKernel for STGeoFromPointZm {
+ fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
+ let num_coords = self.dim.size();
+ let expected_args = vec![ArgMatcher::is_numeric(); num_coords];
+ let matcher = ArgMatcher::new(expected_args, self.out_type.clone());
+ matcher.match_args(args)
+ }
+
+ fn invoke_batch(
+ &self,
+ arg_types: &[SedonaType],
+ args: &[ColumnarValue],
+ ) -> Result<ColumnarValue> {
+ let num_coords = self.dim.size();
+ let executor = WkbExecutor::new(arg_types, args);
+
+ // Cast all arguments to Float64
+ let coord_values: Result<Vec<_>> = args
+ .iter()
+ .map(|arg| arg.cast_to(&DataType::Float64, None))
+ .collect();
+ let coord_values = coord_values?;
+
+ // Check if all arguments are scalars
+ let all_scalars = coord_values
+ .iter()
+ .all(|v| matches!(v, ColumnarValue::Scalar(_)));
+
+ if all_scalars {
+ let scalar_coords: Result<Vec<_>> = coord_values
+ .iter()
+ .map(|v| match v {
+ ColumnarValue::Scalar(ScalarValue::Float64(val)) =>
Ok(*val),
+ _ => Err(datafusion_common::DataFusionError::Internal(
+ "Expected Float64 scalar".to_string(),
+ )),
+ })
+ .collect();
+ let scalar_coords = scalar_coords?;
+
+ // Check if any coordinate is null
+ if scalar_coords.iter().any(|coord| coord.is_none()) {
+ return Ok(ScalarValue::Binary(None).into());
+ }
+
+ // Populate WKB with coordinates
+ let coord_values: Vec<f64> = scalar_coords.into_iter().map(|c|
c.unwrap()).collect();
+ let mut buffer = Vec::new();
+ let mut cursor = Cursor::new(&mut buffer);
+ write_wkb_pointzm(&mut cursor, &coord_values, self.dim)
+ .map_err(|err| -> DataFusionError {
DataFusionError::External(Box::new(err)) })?;
+ return Ok(ScalarValue::Binary(Some(buffer)).into());
+ }
+
+ // Handle array case
+ let coord_arrays: Result<Vec<_>> = coord_values
+ .iter()
+ .map(|v| v.to_array(executor.num_iterations()))
+ .collect();
+ let coord_arrays = coord_arrays?;
+
+ let coord_f64_arrays: Result<Vec<_>> = coord_arrays
+ .iter()
+ .map(|array| as_float64_array(array))
+ .collect();
+ let coord_f64_arrays = coord_f64_arrays?;
+
+ // Calculate WKB item size based on coordinates: endian(1) + type(4) +
coords(8 each)
+ let wkb_size = WKB_HEADER_SIZE + (num_coords * 8);
+ let mut builder = BinaryBuilder::with_capacity(
+ executor.num_iterations(),
+ wkb_size * executor.num_iterations(),
+ );
+
+ for i in 0..executor.num_iterations() {
+ let num_dimensions = self.dim.size();
+ let arrays = (0..num_dimensions)
+ .map(|j| coord_f64_arrays[j])
+ .collect::<Vec<_>>();
+ let any_null = arrays.iter().any(|&v| v.is_null(i));
+ let values = arrays.iter().map(|v| v.value(i)).collect::<Vec<_>>();
+ if !any_null {
+ write_wkb_pointzm(&mut builder, &values, self.dim).map_err(|_|
{
+ datafusion_common::DataFusionError::Internal(
+ "Failed to write WKB point header".to_string(),
+ )
+ })?;
+ builder.append_value([]);
+ } else {
+ builder.append_null();
+ }
+ }
+
+ let new_array = builder.finish();
+ Ok(ColumnarValue::Array(Arc::new(new_array)))
+ }
+}
+
+fn write_wkb_pointzm(
+ buf: &mut impl Write,
+ coords: &[f64],
+ dim: Dimensions,
+) -> Result<(), SedonaGeometryError> {
+ let values = coords;
+ write_wkb_point_header(buf, dim)?;
+ match dim.size() {
+ 3 => {
+ let coord = (values[0], values[1], values[2]);
+ write_wkb_coord(buf, coord)
+ }
+ 4 => {
+ let coord = (values[0], values[1], values[2], values[3]);
+ write_wkb_coord(buf, coord)
+ }
+ _ => Err(SedonaGeometryError::Invalid(
+ "Unsupported number of dimensions".to_string(),
+ )),
+ }
+}
+
+#[cfg(test)]
+mod tests {
+ use arrow_array::create_array;
+ use arrow_schema::DataType;
+ use datafusion_expr::ScalarUDF;
+ use rstest::rstest;
+
+ use sedona_testing::{
+ compare::assert_value_equal,
+ create::{create_array_value, create_scalar_value},
+ };
+
+ use super::*;
+
+ #[test]
+ fn udf_metadata() {
+ let pointz: ScalarUDF = st_pointz_udf().into();
+ assert_eq!(pointz.name(), "st_pointz");
+ assert!(pointz.documentation().is_some());
+
+ let pointm: ScalarUDF = st_pointm_udf().into();
+ assert_eq!(pointm.name(), "st_pointm");
+ assert!(pointm.documentation().is_some());
+
+ let pointzm: ScalarUDF = st_pointzm_udf().into();
+ assert_eq!(pointzm.name(), "st_pointzm");
+ assert!(pointzm.documentation().is_some());
+ }
+
+ #[rstest]
+ #[case(DataType::Float64, DataType::Float64)]
+ #[case(DataType::Float32, DataType::Float64)]
+ #[case(DataType::Float64, DataType::Float32)]
+ #[case(DataType::Float32, DataType::Float32)]
+ fn udf_invoke(#[case] lhs_type: DataType, #[case] rhs_type: DataType) {
+ // Just test one of the UDFs
+ // We have other functions to ensure the logic works for all Z, M, and
ZM
+ let udf = st_pointz_udf();
+
+ let scalar_null_1 =
ScalarValue::Float64(None).cast_to(&lhs_type).unwrap();
+ let scalar_1 =
ScalarValue::Float64(Some(1.0)).cast_to(&lhs_type).unwrap();
+ let scalar_null_2 =
ScalarValue::Float64(None).cast_to(&rhs_type).unwrap();
+ let scalar_2 =
ScalarValue::Float64(Some(2.0)).cast_to(&rhs_type).unwrap();
+ let scalar_3 =
ScalarValue::Float64(Some(3.0)).cast_to(&lhs_type).unwrap();
+ let array_1 =
+ ColumnarValue::Array(create_array!(Float64, [Some(1.0), Some(2.0),
None, None]))
+ .cast_to(&lhs_type, None)
+ .unwrap();
+ let array_2 =
+ ColumnarValue::Array(create_array!(Float64, [Some(5.0), None,
Some(7.0), None]))
+ .cast_to(&rhs_type, None)
+ .unwrap();
+
+ let array_3 = ColumnarValue::Array(create_array!(
+ Float64,
+ [Some(3.0), Some(3.0), Some(3.0), None]
+ ))
+ .cast_to(&lhs_type, None)
+ .unwrap();
+
+ // Check scalar
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ scalar_1.clone().into(),
+ scalar_2.clone().into(),
+ scalar_3.clone().into(),
+ ],
+ 1,
+ )
+ .unwrap(),
+ &create_scalar_value(Some("POINT Z (1 2 3)"), &WKB_GEOMETRY),
+ );
+
+ // Check scalar null combinations
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ scalar_1.clone().into(),
+ scalar_null_2.clone().into(),
+ scalar_3.clone().into(),
+ ],
+ 1,
+ )
+ .unwrap(),
+ &create_scalar_value(None, &WKB_GEOMETRY),
+ );
+
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ scalar_null_1.clone().into(),
+ scalar_2.clone().into(),
+ scalar_3.clone().into(),
+ ],
+ 1,
+ )
+ .unwrap(),
+ &create_scalar_value(None, &WKB_GEOMETRY),
+ );
+
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ scalar_null_1.clone().into(),
+ scalar_null_2.clone().into(),
+ scalar_3.clone().into(),
+ ],
+ 1,
+ )
+ .unwrap(),
+ &create_scalar_value(None, &WKB_GEOMETRY),
+ );
+
+ // Check array
+ assert_value_equal(
+ &udf.invoke_batch(&[array_1.clone(), array_2.clone(),
array_3.clone()], 4)
+ .unwrap(),
+ &create_array_value(&[Some("POINT Z (1 5 3)"), None, None, None],
&WKB_GEOMETRY),
+ );
+
+ // Check array/scalar combinations
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ array_1.clone(),
+ scalar_2.clone().into(),
+ scalar_3.clone().into(),
+ ],
+ 4,
+ )
+ .unwrap(),
+ &create_array_value(
+ &[Some("POINT Z (1 2 3)"), Some("POINT Z (2 2 3)"), None,
None],
+ &WKB_GEOMETRY,
+ ),
+ );
+
+ assert_value_equal(
+ &udf.invoke_batch(&[scalar_1.clone().into(), array_2,
array_3.clone()], 4)
+ .unwrap(),
+ &create_array_value(
+ &[Some("POINT Z (1 5 3)"), None, Some("POINT Z (1 7 3)"),
None],
+ &WKB_GEOMETRY,
+ ),
+ );
+ }
+
+ #[test]
+ fn test_pointz() {
+ let udf = st_pointz_udf();
+
+ // Test scalar case
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ ScalarValue::Float64(Some(1.0)).into(),
+ ScalarValue::Float64(Some(2.0)).into(),
+ ScalarValue::Float64(Some(3.0)).into(),
+ ],
+ 1,
+ )
+ .unwrap(),
+ &create_scalar_value(Some("POINT Z (1 2 3)"), &WKB_GEOMETRY),
+ );
+
+ // Test array and null cases
+ // Even if xy are valid, result is null if z is null
+ let x_array =
+ ColumnarValue::Array(create_array!(Float64, [Some(1.0), Some(2.0),
None, None]))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ let y_array = ColumnarValue::Array(create_array!(
+ Float64,
+ [Some(5.0), Some(1.0), Some(7.0), None]
+ ))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ let z_array =
+ ColumnarValue::Array(create_array!(Float64, [Some(10.0), None,
Some(12.0), None]))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ assert_value_equal(
+ &udf.invoke_batch(&[x_array.clone(), y_array.clone(),
z_array.clone()], 1)
+ .unwrap(),
+ &create_array_value(&[Some("POINT Z (1 5 10)"), None, None, None],
&WKB_GEOMETRY),
+ );
+ }
+
+ #[test]
+ fn test_pointm() {
+ let udf = st_pointm_udf();
+
+ // Test scalar case
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ ScalarValue::Float64(Some(1.0)).into(),
+ ScalarValue::Float64(Some(2.0)).into(),
+ ScalarValue::Float64(Some(4.0)).into(),
+ ],
+ 1,
+ )
+ .unwrap(),
+ &create_scalar_value(Some("POINT M (1 2 4)"), &WKB_GEOMETRY),
+ );
+
+ // Test array and null cases
+ // Even if xy are valid, result is null if z is null
+ let x_array =
+ ColumnarValue::Array(create_array!(Float64, [Some(1.0), Some(2.0),
None, None]))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ let y_array = ColumnarValue::Array(create_array!(
+ Float64,
+ [Some(5.0), Some(1.0), Some(7.0), None]
+ ))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ let m_array =
+ ColumnarValue::Array(create_array!(Float64, [Some(10.0), None,
Some(12.0), None]))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ assert_value_equal(
+ &udf.invoke_batch(&[x_array.clone(), y_array.clone(),
m_array.clone()], 1)
+ .unwrap(),
+ &create_array_value(&[Some("POINT M (1 5 10)"), None, None, None],
&WKB_GEOMETRY),
+ );
+ }
+
+ #[test]
+ fn test_pointzm() {
+ let udf = st_pointzm_udf();
+
+ // Test scalar case
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ ScalarValue::Float64(Some(1.0)).into(),
+ ScalarValue::Float64(Some(2.0)).into(),
+ ScalarValue::Float64(Some(3.0)).into(),
+ ScalarValue::Float64(Some(4.0)).into(),
+ ],
+ 1,
+ )
+ .unwrap(),
+ &create_scalar_value(Some("POINT ZM (1 2 3 4)"), &WKB_GEOMETRY),
+ );
+
+ // Even if xy are valid, result is null if z or m is null
+ // Test array and null cases
+ // Even if xy are valid, result is null if z is null
+ let x_array = ColumnarValue::Array(create_array!(
+ Float64,
+ [Some(1.0), Some(2.0), None, Some(1.0)]
+ ))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ let y_array = ColumnarValue::Array(create_array!(
+ Float64,
+ [Some(5.0), Some(1.0), Some(7.0), Some(2.0)]
+ ))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ let z_array = ColumnarValue::Array(create_array!(
+ Float64,
+ [Some(20.0), Some(1.0), Some(7.0), None]
+ ))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ let m_array = ColumnarValue::Array(create_array!(
+ Float64,
+ [Some(10.0), None, Some(12.0), Some(4.0)]
+ ))
+ .cast_to(&DataType::Float64, None)
+ .unwrap();
+
+ assert_value_equal(
+ &udf.invoke_batch(
+ &[
+ x_array.clone(),
+ y_array.clone(),
+ z_array.clone(),
+ m_array.clone(),
+ ],
+ 1,
+ )
+ .unwrap(),
+ &create_array_value(
+ &[Some("POINT ZM (1 5 20 10)"), None, None, None],
+ &WKB_GEOMETRY,
+ ),
+ );
+ }
+}
diff --git a/rust/sedona-testing/src/benchmark_util.rs
b/rust/sedona-testing/src/benchmark_util.rs
index ad4cff3..bd2d01d 100644
--- a/rust/sedona-testing/src/benchmark_util.rs
+++ b/rust/sedona-testing/src/benchmark_util.rs
@@ -171,6 +171,15 @@ pub enum BenchmarkArgs {
ArrayScalarScalar(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
/// Invoke a ternary function with two arrays and a scalar
ArrayArrayScalar(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
+ /// Invoke a ternary function with three arrays
+ ArrayArrayArray(BenchmarkArgSpec, BenchmarkArgSpec, BenchmarkArgSpec),
+ /// Invoke a quaternary function with four arrays
+ ArrayArrayArrayArray(
+ BenchmarkArgSpec,
+ BenchmarkArgSpec,
+ BenchmarkArgSpec,
+ BenchmarkArgSpec,
+ ),
}
impl From<BenchmarkArgSpec> for BenchmarkArgs {
@@ -190,7 +199,9 @@ impl BenchmarkArgs {
let array_configs = match self {
BenchmarkArgs::Array(_)
| BenchmarkArgs::ArrayArray(_, _)
- | BenchmarkArgs::ArrayArrayScalar(_, _, _) => self.specs(),
+ | BenchmarkArgs::ArrayArrayScalar(_, _, _)
+ | BenchmarkArgs::ArrayArrayArray(_, _, _)
+ | BenchmarkArgs::ArrayArrayArrayArray(_, _, _, _) => self.specs(),
BenchmarkArgs::ScalarArray(_, col)
| BenchmarkArgs::ArrayScalar(col, _)
| BenchmarkArgs::ArrayScalarScalar(col, _, _) => {
@@ -238,9 +249,13 @@ impl BenchmarkArgs {
vec![col0.clone(), col1.clone()]
}
BenchmarkArgs::ArrayScalarScalar(col0, col1, col2)
- | BenchmarkArgs::ArrayArrayScalar(col0, col1, col2) => {
+ | BenchmarkArgs::ArrayArrayScalar(col0, col1, col2)
+ | BenchmarkArgs::ArrayArrayArray(col0, col1, col2) => {
vec![col0.clone(), col1.clone(), col2.clone()]
}
+ BenchmarkArgs::ArrayArrayArrayArray(col0, col1, col2, col3) => {
+ vec![col0.clone(), col1.clone(), col2.clone(), col3.clone()]
+ }
}
}
}
@@ -466,6 +481,25 @@ impl BenchmarkData {
)?;
}
}
+ BenchmarkArgs::ArrayArrayArray(_, _, _) => {
+ for i in 0..self.num_batches {
+ tester.invoke_arrays(vec![
+ self.arrays[0][i].clone(),
+ self.arrays[1][i].clone(),
+ self.arrays[2][i].clone(),
+ ])?;
+ }
+ }
+ BenchmarkArgs::ArrayArrayArrayArray(_, _, _, _) => {
+ for i in 0..self.num_batches {
+ tester.invoke_arrays(vec![
+ self.arrays[0][i].clone(),
+ self.arrays[1][i].clone(),
+ self.arrays[2][i].clone(),
+ self.arrays[3][i].clone(),
+ ])?;
+ }
+ }
}
Ok(())
@@ -771,4 +805,70 @@ mod test {
assert_eq!(data.scalars[0].data_type(), DataType::Float64);
}
+
+ #[test]
+ fn args_array_array_array() {
+ let spec = BenchmarkArgs::ArrayArrayArray(
+ BenchmarkArgSpec::Point,
+ BenchmarkArgSpec::Point,
+ BenchmarkArgSpec::Float64(1.0, 2.0),
+ );
+ assert_eq!(
+ spec.sedona_types(),
+ [
+ WKB_GEOMETRY,
+ WKB_GEOMETRY,
+ DataType::Float64.try_into().unwrap()
+ ]
+ );
+
+ let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
+ assert_eq!(data.num_batches, 2);
+ assert_eq!(data.arrays.len(), 3);
+ assert_eq!(data.scalars.len(), 0);
+ assert_eq!(data.arrays[0].len(), 2);
+ assert_eq!(
+ WKB_GEOMETRY,
+ data.arrays[0][0].data_type().try_into().unwrap()
+ );
+ assert_eq!(data.arrays[1].len(), 2);
+ assert_eq!(
+ WKB_GEOMETRY,
+ data.arrays[1][0].data_type().try_into().unwrap()
+ );
+ assert_eq!(data.arrays[2].len(), 2);
+ assert_eq!(data.arrays[2][0].data_type(), &DataType::Float64);
+ }
+
+ #[test]
+ fn args_array_array_array_array() {
+ let spec = BenchmarkArgs::ArrayArrayArrayArray(
+ BenchmarkArgSpec::Float64(1.0, 2.0),
+ BenchmarkArgSpec::Float64(3.0, 4.0),
+ BenchmarkArgSpec::Float64(5.0, 6.0),
+ BenchmarkArgSpec::Float64(7.0, 8.0),
+ );
+ assert_eq!(
+ spec.sedona_types(),
+ [
+ DataType::Float64.try_into().unwrap(),
+ DataType::Float64.try_into().unwrap(),
+ DataType::Float64.try_into().unwrap(),
+ DataType::Float64.try_into().unwrap()
+ ]
+ );
+
+ let data = spec.build_data(2, ROWS_PER_BATCH).unwrap();
+ assert_eq!(data.num_batches, 2);
+ assert_eq!(data.arrays.len(), 4);
+ assert_eq!(data.scalars.len(), 0);
+ assert_eq!(data.arrays[0].len(), 2);
+ assert_eq!(data.arrays[0][0].data_type(), &DataType::Float64);
+ assert_eq!(data.arrays[1].len(), 2);
+ assert_eq!(data.arrays[1][0].data_type(), &DataType::Float64);
+ assert_eq!(data.arrays[2].len(), 2);
+ assert_eq!(data.arrays[2][0].data_type(), &DataType::Float64);
+ assert_eq!(data.arrays[3].len(), 2);
+ assert_eq!(data.arrays[3][0].data_type(), &DataType::Float64);
+ }
}
diff --git a/rust/sedona-testing/src/testers.rs
b/rust/sedona-testing/src/testers.rs
index aa3c91e..be084b4 100644
--- a/rust/sedona-testing/src/testers.rs
+++ b/rust/sedona-testing/src/testers.rs
@@ -368,7 +368,8 @@ impl ScalarUdfTester {
}
}
- fn invoke_arrays(&self, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
+ // Invoke a function with a set of arrays
+ pub fn invoke_arrays(&self, arrays: Vec<ArrayRef>) -> Result<ArrayRef> {
let args = zip(arrays, &self.arg_types)
.map(|(array, sedona_type)| {
ColumnarValue::Array(array).cast_to(&sedona_type.data_type(),
None)