This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 53e70bd2 Add RS_Example and support for raster scalars (#307)
53e70bd2 is described below

commit 53e70bd2ed484d547defcc690b1f4add6d744cc8
Author: jp <[email protected]>
AuthorDate: Fri Nov 14 22:21:07 2025 -0800

    Add RS_Example and support for raster scalars (#307)
---
 rust/sedona-raster-functions/src/executor.rs   | 148 ++++++++++++++++++------
 rust/sedona-raster-functions/src/lib.rs        |   1 +
 rust/sedona-raster-functions/src/register.rs   |   6 +-
 rust/sedona-raster-functions/src/rs_example.rs | 152 +++++++++++++++++++++++++
 4 files changed, 273 insertions(+), 34 deletions(-)

diff --git a/rust/sedona-raster-functions/src/executor.rs b/rust/sedona-raster-functions/src/executor.rs
index 75123e50..4cf0f801 100644
--- a/rust/sedona-raster-functions/src/executor.rs
+++ b/rust/sedona-raster-functions/src/executor.rs
@@ -54,7 +54,7 @@ impl<'a, 'b> RasterExecutor<'a, 'b> {
     ///
     /// This handles the common pattern of:
     /// 1. Downcasting array to StructArray
-    /// 2. Creating raster iterator
+    /// 2. Creating raster array
     /// 3. Iterating with null checks
     /// 4. Calling the provided function with each raster
     pub fn execute_raster_void<F>(&self, mut func: F) -> Result<()>
@@ -64,38 +64,46 @@ impl<'a, 'b> RasterExecutor<'a, 'b> {
         if self.arg_types[0] != RASTER {
             return sedona_internal_err!("First argument must be a raster 
type");
         }
-        let raster_array = match &self.args[0] {
-            ColumnarValue::Array(array) => array,
-            ColumnarValue::Scalar(_) => {
-                return Err(DataFusionError::NotImplemented(
-                    "Scalar raster input not yet supported".to_string(),
-                ));
-            }
-        };
 
-        // Downcast to StructArray (rasters are stored as structs)
-        let raster_struct = raster_array
-            .as_any()
-            .downcast_ref::<StructArray>()
-            .ok_or_else(|| {
-                DataFusionError::Internal("Expected StructArray for raster 
data".to_string())
-            })?;
-
-        // Create raster iterator
-        let raster_array = RasterStructArray::new(raster_struct);
-
-        // Iterate through each raster in the array
-        for i in 0..self.num_iterations {
-            if raster_array.is_null(i) {
-                func(i, None)?;
-                continue;
-            }
-            let raster = raster_array.get(i)?;
+        match &self.args[0] {
+            ColumnarValue::Array(array) => {
+                // Downcast to StructArray (rasters are stored as structs)
+                let raster_struct =
+                    array
+                        .as_any()
+                        .downcast_ref::<StructArray>()
+                        .ok_or_else(|| {
+                            DataFusionError::Internal(
+                                "Expected StructArray for raster 
data".to_string(),
+                            )
+                        })?;
 
-            func(i, Some(raster))?;
-        }
+                let raster_array = RasterStructArray::new(raster_struct);
 
-        Ok(())
+                // Iterate through each raster in the array
+                for i in 0..self.num_iterations {
+                    if raster_array.is_null(i) {
+                        func(i, None)?;
+                        continue;
+                    }
+                    let raster = raster_array.get(i)?;
+                    func(i, Some(raster))?;
+                }
+
+                Ok(())
+            }
+            ColumnarValue::Scalar(scalar_value) => match scalar_value {
+                ScalarValue::Struct(arc_struct) => {
+                    let raster_array = 
RasterStructArray::new(arc_struct.as_ref());
+                    let raster = raster_array.get(0)?;
+                    func(0, Some(raster))
+                }
+                ScalarValue::Null => func(0, None),
+                _ => Err(DataFusionError::Internal(
+                    "Expected Struct scalar for raster".to_string(),
+                )),
+            },
+        }
     }
 
     /// Finish an [ArrayRef] output as the appropriate [ColumnarValue]
@@ -113,7 +121,7 @@ impl<'a, 'b> RasterExecutor<'a, 'b> {
             }
         }
 
-        // For all scalar arguments, we return a scalar
+        // All arguments are scalars, return a scalar
         Ok(ColumnarValue::Scalar(ScalarValue::try_from_array(&out, 0)?))
     }
 
@@ -122,7 +130,7 @@ impl<'a, 'b> RasterExecutor<'a, 'b> {
     fn calc_num_iterations(args: &[ColumnarValue]) -> usize {
         for arg in args {
             match arg {
-                // If any argument is an array, we have to iterate array.len() 
times
+                // If any argument is an array, iterate array.len() times
                 ColumnarValue::Array(array) => {
                     return array.len();
                 }
@@ -130,7 +138,7 @@ impl<'a, 'b> RasterExecutor<'a, 'b> {
             }
         }
 
-        // All scalars: we iterate once
+        // All arguments are scalars, iterate once
         1
     }
 }
@@ -184,4 +192,78 @@ mod tests {
         assert!(width_array.is_null(1));
         assert_eq!(width_array.value(2), 3);
     }
+
+    #[test]
+    fn test_raster_executor_scalar_input() {
+        let rasters = generate_test_rasters(1, None).unwrap();
+        let raster_struct = 
rasters.as_any().downcast_ref::<StructArray>().unwrap();
+        let scalar_raster = 
ScalarValue::Struct(Arc::new(raster_struct.clone()));
+
+        let args = [ColumnarValue::Scalar(scalar_raster)];
+        let arg_types = vec![RASTER];
+
+        let executor = RasterExecutor::new(&arg_types, &args);
+        assert_eq!(executor.num_iterations(), 1);
+
+        let mut builder = 
UInt64Builder::with_capacity(executor.num_iterations());
+        executor
+            .execute_raster_void(|_i, raster_opt| {
+                match raster_opt {
+                    None => builder.append_null(),
+                    Some(raster) => {
+                        let width = raster.metadata().width();
+                        builder.append_value(width);
+                    }
+                }
+                Ok(())
+            })
+            .unwrap();
+
+        let result = executor.finish(Arc::new(builder.finish())).unwrap();
+
+        // With scalar input, result should be a scalar
+        let width_scalar = match &result {
+            ColumnarValue::Scalar(scalar) => scalar,
+            ColumnarValue::Array(_) => panic!("Expected scalar, got array"),
+        };
+
+        match width_scalar {
+            ScalarValue::UInt64(Some(width)) => assert_eq!(*width, 1),
+            _ => panic!("Expected UInt64 scalar"),
+        }
+    }
+
+    #[test]
+    fn test_raster_executor_null_scalar() {
+        // Test with a null scalar
+        let args = [ColumnarValue::Scalar(ScalarValue::Null)];
+        let arg_types = vec![RASTER];
+
+        let executor = RasterExecutor::new(&arg_types, &args);
+        assert_eq!(executor.num_iterations(), 1);
+
+        let mut builder = 
UInt64Builder::with_capacity(executor.num_iterations());
+        executor
+            .execute_raster_void(|_i, raster_opt| {
+                match raster_opt {
+                    None => builder.append_null(),
+                    Some(raster) => {
+                        let width = raster.metadata().width();
+                        builder.append_value(width);
+                    }
+                }
+                Ok(())
+            })
+            .unwrap();
+
+        let result = executor.finish(Arc::new(builder.finish())).unwrap();
+
+        // With null scalar input, result should be null scalar
+        let width_scalar = match &result {
+            ColumnarValue::Scalar(scalar) => scalar,
+            ColumnarValue::Array(_) => panic!("Expected scalar, got array"),
+        };
+
+        assert_eq!(width_scalar, &ScalarValue::UInt64(None));
+    }
 }
diff --git a/rust/sedona-raster-functions/src/lib.rs b/rust/sedona-raster-functions/src/lib.rs
index 86aea003..6f4077d7 100644
--- a/rust/sedona-raster-functions/src/lib.rs
+++ b/rust/sedona-raster-functions/src/lib.rs
@@ -17,4 +17,5 @@
 
 mod executor;
 pub mod register;
+pub mod rs_example;
 pub mod rs_size;
diff --git a/rust/sedona-raster-functions/src/register.rs b/rust/sedona-raster-functions/src/register.rs
index 7499892a..1945372d 100644
--- a/rust/sedona-raster-functions/src/register.rs
+++ b/rust/sedona-raster-functions/src/register.rs
@@ -36,7 +36,11 @@ pub fn default_function_set() -> FunctionSet {
         };
     }
 
-    register_scalar_udfs!(function_set, crate::rs_size::rs_width_udf,);
+    register_scalar_udfs!(
+        function_set,
+        crate::rs_size::rs_width_udf,
+        crate::rs_example::rs_example_udf,
+    );
 
     register_aggregate_udfs!(function_set,);
 
diff --git a/rust/sedona-raster-functions/src/rs_example.rs b/rust/sedona-raster-functions/src/rs_example.rs
new file mode 100644
index 00000000..d5bce5c8
--- /dev/null
+++ b/rust/sedona-raster-functions/src/rs_example.rs
@@ -0,0 +1,152 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+use std::{sync::Arc, vec};
+
+use crate::executor::RasterExecutor;
+use datafusion_common::error::Result;
+use datafusion_expr::{
+    scalar_doc_sections::DOC_SECTION_OTHER, ColumnarValue, Documentation, 
Volatility,
+};
+use sedona_expr::scalar_udf::{SedonaScalarKernel, SedonaScalarUDF};
+use sedona_raster::builder::RasterBuilder;
+use sedona_raster::traits::BandMetadata;
+use sedona_raster::traits::RasterMetadata;
+use sedona_schema::{
+    crs::lnglat,
+    datatypes::SedonaType,
+    matchers::ArgMatcher,
+    raster::{BandDataType, StorageType},
+};
+
+/// RS_Example() scalar UDF implementation
+///
+/// Creates a simple concrete example for testing purposes
+/// May expand with additional optional parameters in the future
+pub fn rs_example_udf() -> SedonaScalarUDF {
+    SedonaScalarUDF::new(
+        "rs_example",
+        vec![Arc::new(RsExample {})],
+        Volatility::Immutable,
+        Some(rs_example_doc()),
+    )
+}
+
+fn rs_example_doc() -> Documentation {
+    Documentation::builder(
+        DOC_SECTION_OTHER,
+        "Return an example raster".to_string(),
+        "RS_Example()".to_string(),
+    )
+    .with_sql_example("SELECT RS_Example()".to_string())
+    .build()
+}
+
+#[derive(Debug)]
+struct RsExample {}
+
+impl SedonaScalarKernel for RsExample {
+    fn return_type(&self, args: &[SedonaType]) -> Result<Option<SedonaType>> {
+        let matcher = ArgMatcher::new(vec![], SedonaType::Raster);
+
+        matcher.match_args(args)
+    }
+
+    fn invoke_batch(
+        &self,
+        arg_types: &[SedonaType],
+        args: &[ColumnarValue],
+    ) -> Result<ColumnarValue> {
+        let executor = RasterExecutor::new(arg_types, args);
+        let mut builder = RasterBuilder::new(1);
+
+        let raster_metadata = RasterMetadata {
+            width: 64,
+            height: 32,
+            upperleft_x: 43.08,
+            upperleft_y: 79.07,
+            scale_x: 2.0,
+            scale_y: 2.0,
+            skew_x: 1.0,
+            skew_y: 1.0,
+        };
+        let crs = lnglat().unwrap().to_json();
+        builder.start_raster(&raster_metadata, Some(&crs))?;
+        let nodata_value = 127u8;
+        for band_id in 1..=3 {
+            builder.start_band(BandMetadata {
+                datatype: BandDataType::UInt8,
+                nodata_value: Some(vec![nodata_value]),
+                storage_type: StorageType::InDb,
+                outdb_url: None,
+                outdb_band_id: None,
+            })?;
+
+            let mut band_data =
+                vec![band_id as u8; (raster_metadata.width * 
raster_metadata.height) as usize];
+            band_data[0] = nodata_value; // set the top corner to nodata
+
+            builder.band_data_writer().append_value(&band_data);
+            builder.finish_band()?;
+        }
+        builder.finish_raster()?;
+
+        executor.finish(Arc::new(builder.finish()?))
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion_common::ScalarValue;
+    use datafusion_expr::ScalarUDF;
+    use sedona_raster::array::RasterStructArray;
+    use sedona_raster::traits::RasterRef;
+
+    #[test]
+    fn udf_size() {
+        let udf: ScalarUDF = rs_example_udf().into();
+        assert_eq!(udf.name(), "rs_example");
+        assert!(udf.documentation().is_some());
+    }
+
+    #[test]
+    fn udf_invoke() {
+        let kernel = RsExample {};
+        let args = [];
+        let arg_types = vec![];
+
+        let result = kernel.invoke_batch(&arg_types, &args).unwrap();
+        if let ColumnarValue::Scalar(ScalarValue::Struct(arc_struct)) = result 
{
+            let raster_array = RasterStructArray::new(arc_struct.as_ref());
+
+            assert_eq!(raster_array.len(), 1);
+            let raster = raster_array.get(0).unwrap();
+            let metadata = raster.metadata();
+            assert_eq!(metadata.width(), 64);
+            assert_eq!(metadata.height(), 32);
+
+            let bands = raster.bands();
+            let band = bands.band(1).unwrap();
+            let band_metadata = band.metadata();
+            assert_eq!(band_metadata.data_type(), BandDataType::UInt8);
+            assert_eq!(band_metadata.nodata_value(), Some(&[127u8][..]));
+            assert_eq!(band_metadata.storage_type(), StorageType::InDb);
+        } else {
+            panic!("Expected scalar struct result");
+        }
+    }
+}

Reply via email to