This is an automated email from the ASF dual-hosted git repository.
kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 3f91e265 feat(rust/sedona-expr): Pass ConfigOptions into UDFs (#557)
3f91e265 is described below
commit 3f91e2657c2473dc06357999dec9282d5e45e35a
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Thu Jan 29 22:20:49 2026 +0800
feat(rust/sedona-expr): Pass ConfigOptions into UDFs (#557)
Closes #248
This will allow the behavior of UDFs to be altered through configuration options. For instance, it will allow
the size of the per-thread PROJ cache used by ST_Transform to be configured before
any CRS transformation functions are invoked.
---
c/sedona-extension/src/scalar_kernel.rs | 3 ++
c/sedona-proj/src/st_transform.rs | 2 ++
python/sedonadb/src/udf.rs | 2 ++
rust/sedona-expr/src/item_crs.rs | 21 +++++++++++--
rust/sedona-expr/src/scalar_udf.rs | 10 ++++++-
rust/sedona-functions/src/st_setsrid.rs | 3 ++
rust/sedona-geoparquet/src/format.rs | 11 +++++--
rust/sedona-geoparquet/src/writer.rs | 13 ++++----
rust/sedona-spatial-join/src/optimizer.rs | 2 --
rust/sedona-testing/src/testers.rs | 49 +++++++++++++++++++++++++++----
rust/sedona/src/show.rs | 21 +++++++++----
11 files changed, 114 insertions(+), 23 deletions(-)
diff --git a/c/sedona-extension/src/scalar_kernel.rs
b/c/sedona-extension/src/scalar_kernel.rs
index 17972356..04c9ef0b 100644
--- a/c/sedona-extension/src/scalar_kernel.rs
+++ b/c/sedona-extension/src/scalar_kernel.rs
@@ -20,6 +20,7 @@ use arrow_array::{
make_array, ArrayRef,
};
use arrow_schema::{ArrowError, Field};
+use datafusion_common::config::ConfigOptions;
use datafusion_common::{plan_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use sedona_common::sedona_internal_err;
@@ -101,6 +102,7 @@ impl SedonaScalarKernel for ImportedScalarKernel {
args: &[ColumnarValue],
return_type: &SedonaType,
num_rows: usize,
+ _config_options: Option<&ConfigOptions>,
) -> Result<ColumnarValue> {
let arg_scalars = args
.iter()
@@ -560,6 +562,7 @@ impl ExportedScalarKernelImpl {
&args,
return_type,
num_rows as usize,
+ None,
)?;
// Convert the result to an ArrayRef
diff --git a/c/sedona-proj/src/st_transform.rs
b/c/sedona-proj/src/st_transform.rs
index b7fd90a6..fcdf8e78 100644
--- a/c/sedona-proj/src/st_transform.rs
+++ b/c/sedona-proj/src/st_transform.rs
@@ -19,6 +19,7 @@ use arrow_array::builder::{BinaryBuilder, StringViewBuilder};
use arrow_array::ArrayRef;
use arrow_schema::DataType;
use datafusion_common::cast::{as_string_view_array, as_struct_array};
+use datafusion_common::config::ConfigOptions;
use datafusion_common::{exec_err, DataFusionError, Result, ScalarValue};
use datafusion_expr::ColumnarValue;
use sedona_common::sedona_internal_err;
@@ -99,6 +100,7 @@ impl SedonaScalarKernel for STTransform {
args: &[ColumnarValue],
_return_type: &SedonaType,
_num_rows: usize,
+ _config_options: Option<&ConfigOptions>,
) -> Result<ColumnarValue> {
let inputs = zip(arg_types, args)
.map(|(arg_type, arg)| ArgInput::from_arg(arg_type, arg))
diff --git a/python/sedonadb/src/udf.rs b/python/sedonadb/src/udf.rs
index eeb7cdd9..d731bdc2 100644
--- a/python/sedonadb/src/udf.rs
+++ b/python/sedonadb/src/udf.rs
@@ -22,6 +22,7 @@ use arrow_array::{
ArrayRef,
};
use arrow_schema::Field;
+use datafusion_common::config::ConfigOptions;
use datafusion_common::{Result, ScalarValue};
use datafusion_expr::{ColumnarValue, ScalarUDF, Volatility};
use datafusion_ffi::udf::FFI_ScalarUDF;
@@ -159,6 +160,7 @@ impl SedonaScalarKernel for PySedonaScalarKernel {
args: &[ColumnarValue],
return_type: &SedonaType,
num_rows: usize,
+ _config_options: Option<&ConfigOptions>,
) -> Result<ColumnarValue> {
let result = Python::with_gil(|py| -> Result<ArrayRef, PySedonaError> {
let py_values = zip(arg_types, args)
diff --git a/rust/sedona-expr/src/item_crs.rs b/rust/sedona-expr/src/item_crs.rs
index 0889622b..3bb73a16 100644
--- a/rust/sedona-expr/src/item_crs.rs
+++ b/rust/sedona-expr/src/item_crs.rs
@@ -20,6 +20,7 @@ use std::{fmt::Debug, iter::zip, sync::Arc};
use arrow_array::{Array, ArrayRef, StructArray};
use arrow_buffer::NullBuffer;
use arrow_schema::{DataType, Field, FieldRef};
+use datafusion_common::config::ConfigOptions;
use datafusion_common::{
cast::{as_string_view_array, as_struct_array},
exec_err, DataFusionError, Result, ScalarValue,
@@ -102,8 +103,16 @@ impl SedonaScalarKernel for ItemCrsKernel {
args: &[ColumnarValue],
return_type: &SedonaType,
num_rows: usize,
+ config_options: Option<&ConfigOptions>,
) -> Result<ColumnarValue> {
- invoke_handle_item_crs(self.inner.as_ref(), arg_types, args,
return_type, num_rows)
+ invoke_handle_item_crs(
+ self.inner.as_ref(),
+ arg_types,
+ args,
+ return_type,
+ num_rows,
+ config_options,
+ )
}
fn invoke_batch(
@@ -444,6 +453,7 @@ fn invoke_handle_item_crs(
args: &[ColumnarValue],
return_type: &SedonaType,
num_rows: usize,
+ config_options: Option<&ConfigOptions>,
) -> Result<ColumnarValue> {
// Separate the argument types into item and Option<crs>
// Don't strip the CRSes because we need them to compare with
@@ -485,8 +495,13 @@ fn invoke_handle_item_crs(
None => return sedona_internal_err!("Expected inner kernel to match
types {item_types:?}"),
};
- let item_result =
- kernel.invoke_batch_from_args(&item_types, &item_args, return_type,
num_rows)?;
+ let item_result = kernel.invoke_batch_from_args(
+ &item_types,
+ &item_args,
+ return_type,
+ num_rows,
+ config_options,
+ )?;
if ArgMatcher::is_geometry_or_geography().match_type(&out_item_type) {
make_item_crs(&out_item_type, item_result, crs_result, None)
diff --git a/rust/sedona-expr/src/scalar_udf.rs
b/rust/sedona-expr/src/scalar_udf.rs
index d5bc56ec..fac30e60 100644
--- a/rust/sedona-expr/src/scalar_udf.rs
+++ b/rust/sedona-expr/src/scalar_udf.rs
@@ -17,6 +17,7 @@
use std::{any::Any, fmt::Debug, sync::Arc};
use arrow_schema::{DataType, FieldRef};
+use datafusion_common::config::ConfigOptions;
use datafusion_common::{not_impl_err, Result, ScalarValue};
use datafusion_expr::{
ColumnarValue, Documentation, ReturnFieldArgs, ScalarFunctionArgs,
ScalarUDFImpl, Signature,
@@ -136,6 +137,7 @@ pub trait SedonaScalarKernel: Debug + Send + Sync {
args: &[ColumnarValue],
_return_type: &SedonaType,
_num_rows: usize,
+ _config_options: Option<&ConfigOptions>,
) -> Result<ColumnarValue> {
self.invoke_batch(arg_types, args)
}
@@ -323,7 +325,13 @@ impl ScalarUDFImpl for SedonaScalarUDF {
.collect::<Vec<_>>();
let (kernel, return_type) = self.return_type_impl(&arg_types,
&arg_scalars)?;
- kernel.invoke_batch_from_args(&arg_types, &args.args, &return_type,
args.number_rows)
+ kernel.invoke_batch_from_args(
+ &arg_types,
+ &args.args,
+ &return_type,
+ args.number_rows,
+ Some(&*args.config_options),
+ )
}
fn aliases(&self) -> &[String] {
diff --git a/rust/sedona-functions/src/st_setsrid.rs
b/rust/sedona-functions/src/st_setsrid.rs
index 3e9fd0b3..7c5b9713 100644
--- a/rust/sedona-functions/src/st_setsrid.rs
+++ b/rust/sedona-functions/src/st_setsrid.rs
@@ -25,6 +25,7 @@ use arrow_array::{
};
use arrow_buffer::NullBuffer;
use arrow_schema::DataType;
+use datafusion_common::config::ConfigOptions;
use datafusion_common::{
cast::{as_int64_array, as_string_view_array},
error::Result,
@@ -142,6 +143,7 @@ impl SedonaScalarKernel for STSetSRID {
args: &[ColumnarValue],
return_type: &SedonaType,
_num_rows: usize,
+ _config_options: Option<&ConfigOptions>,
) -> Result<ColumnarValue> {
let item_crs_matcher = ArgMatcher::is_item_crs();
if item_crs_matcher.match_type(return_type) {
@@ -198,6 +200,7 @@ impl SedonaScalarKernel for STSetCRS {
args: &[ColumnarValue],
return_type: &SedonaType,
_num_rows: usize,
+ _config_options: Option<&ConfigOptions>,
) -> Result<ColumnarValue> {
let item_crs_matcher = ArgMatcher::is_item_crs();
if item_crs_matcher.match_type(return_type) {
diff --git a/rust/sedona-geoparquet/src/format.rs
b/rust/sedona-geoparquet/src/format.rs
index 0da807b3..4fa966e4 100644
--- a/rust/sedona-geoparquet/src/format.rs
+++ b/rust/sedona-geoparquet/src/format.rs
@@ -302,11 +302,18 @@ impl FileFormat for GeoParquetFormat {
async fn create_writer_physical_plan(
&self,
input: Arc<dyn ExecutionPlan>,
- _state: &dyn Session,
+ session: &dyn Session,
conf: FileSinkConfig,
order_requirements: Option<LexRequirement>,
) -> Result<Arc<dyn ExecutionPlan>> {
- create_geoparquet_writer_physical_plan(input, conf,
order_requirements, &self.options)
+ let session_config_options = session.config().options();
+ create_geoparquet_writer_physical_plan(
+ input,
+ conf,
+ order_requirements,
+ &self.options,
+ session_config_options,
+ )
}
fn file_source(&self) -> Arc<dyn FileSource> {
diff --git a/rust/sedona-geoparquet/src/writer.rs
b/rust/sedona-geoparquet/src/writer.rs
index 8ea1a264..42c9de4d 100644
--- a/rust/sedona-geoparquet/src/writer.rs
+++ b/rust/sedona-geoparquet/src/writer.rs
@@ -69,6 +69,7 @@ pub fn create_geoparquet_writer_physical_plan(
mut conf: FileSinkConfig,
order_requirements: Option<LexRequirement>,
options: &TableGeoParquetOptions,
+ session_config_options: &Arc<ConfigOptions>,
) -> Result<Arc<dyn ExecutionPlan>> {
if conf.insert_op != InsertOp::Append {
return not_impl_err!("Overwrites are not implemented yet for Parquet");
@@ -93,8 +94,11 @@ pub fn create_geoparquet_writer_physical_plan(
}
GeoParquetVersion::V1_1 => {
metadata.version = "1.1.0".to_string();
- (bbox_projection, bbox_columns) =
- project_bboxes(&input, options.overwrite_bbox_columns)?;
+ (bbox_projection, bbox_columns) = project_bboxes(
+ &input,
+ options.overwrite_bbox_columns,
+ session_config_options,
+ )?;
parquet_output_schema = compute_final_schema(&bbox_projection,
&input.schema())?;
output_geometry_column_indices =
conf.output_schema.geometry_column_indices()?;
}
@@ -291,6 +295,7 @@ type ProjectBboxesResult = (
fn project_bboxes(
input: &Arc<dyn ExecutionPlan>,
overwrite_bbox_columns: bool,
+ session_config_options: &Arc<ConfigOptions>,
) -> Result<ProjectBboxesResult> {
let input_schema = input.schema();
let matcher = ArgMatcher::is_geometry();
@@ -310,14 +315,12 @@ fn project_bboxes(
column.return_field(&input_schema)?.as_ref(),
)?) {
let bbox_field_name = bbox_column_name(f.name());
- // TODO: Pipe actual ConfigOptions from session instead of using
defaults
- // See: https://github.com/apache/sedona-db/issues/248
let expr = Arc::new(ScalarFunctionExpr::new(
bbox_udf_name,
bbox_udf.clone(),
vec![column],
Arc::new(Field::new("", bbox_type(), true)),
- Arc::new(ConfigOptions::default()),
+ Arc::clone(session_config_options),
));
bbox_exprs.insert(i, (expr, bbox_field_name.clone()));
diff --git a/rust/sedona-spatial-join/src/optimizer.rs
b/rust/sedona-spatial-join/src/optimizer.rs
index a8c28167..a5a8baef 100644
--- a/rust/sedona-spatial-join/src/optimizer.rs
+++ b/rust/sedona-spatial-join/src/optimizer.rs
@@ -1176,8 +1176,6 @@ mod tests {
) -> Arc<ScalarFunctionExpr> {
let return_type = udf.return_type(&[]).unwrap();
let field = Arc::new(arrow::datatypes::Field::new("result",
return_type, false));
- // TODO: Pipe actual ConfigOptions from session instead of using
defaults
- // See: https://github.com/apache/sedona-db/issues/248
Arc::new(ScalarFunctionExpr::new(
udf.name(),
Arc::clone(&udf),
diff --git a/rust/sedona-testing/src/testers.rs
b/rust/sedona-testing/src/testers.rs
index 54b947a3..33476d49 100644
--- a/rust/sedona-testing/src/testers.rs
+++ b/rust/sedona-testing/src/testers.rs
@@ -27,7 +27,7 @@ use datafusion_expr::{
ReturnFieldArgs, ScalarFunctionArgs, ScalarUDF,
};
use datafusion_physical_expr::{expressions::Column, PhysicalExpr};
-use sedona_common::sedona_internal_err;
+use sedona_common::{sedona_internal_err, SedonaOptions};
use sedona_schema::datatypes::SedonaType;
use crate::{
@@ -240,12 +240,53 @@ impl AggregateUdfTester {
pub struct ScalarUdfTester {
udf: ScalarUDF,
arg_types: Vec<SedonaType>,
+ config_options: Arc<ConfigOptions>,
}
impl ScalarUdfTester {
/// Create a new tester
pub fn new(udf: ScalarUDF, arg_types: Vec<SedonaType>) -> Self {
- Self { udf, arg_types }
+ let mut config_options = ConfigOptions::default();
+ let sedona_options = SedonaOptions::default();
+ config_options.extensions.insert(sedona_options);
+ Self {
+ udf,
+ arg_types,
+ config_options: Arc::new(config_options),
+ }
+ }
+
+ /// Returns the [`ConfigOptions`] used when invoking the UDF.
+ ///
+ /// This is the same structure DataFusion threads through
[`ScalarFunctionArgs`].
+ /// Sedona-specific options are stored in `config_options.extensions`.
+ pub fn config_options(&self) -> &ConfigOptions {
+ &self.config_options
+ }
+
+ /// Returns a mutable reference to the [`ConfigOptions`] used when
invoking the UDF.
+ ///
+ /// Use this to tweak DataFusion options or to insert/update Sedona
options via
+ /// `config_options.extensions` before calling the tester's `invoke_*`
helpers.
+ pub fn config_options_mut(&mut self) -> &mut ConfigOptions {
+ // config_options can only be owned by this tester, so it's safe to
get a mutable reference.
+ Arc::get_mut(&mut self.config_options).expect("ConfigOptions is
shared")
+ }
+
+ /// Returns the [`SedonaOptions`] stored in `config_options.extensions`,
if present.
+ pub fn sedona_options(&self) -> &SedonaOptions {
+ self.config_options
+ .extensions
+ .get::<SedonaOptions>()
+ .expect("SedonaOptions does not exist")
+ }
+
+ /// Returns a mutable reference to the [`SedonaOptions`] stored in
`config_options.extensions`, if present.
+ pub fn sedona_options_mut(&mut self) -> &mut SedonaOptions {
+ self.config_options_mut()
+ .extensions
+ .get_mut::<SedonaOptions>()
+ .expect("SedonaOptions does not exist")
}
/// Assert the return type of the function for the argument types used
@@ -610,9 +651,7 @@ impl ScalarUdfTester {
arg_fields: self.arg_fields(),
number_rows,
return_field: return_type.to_storage_field("", true)?.into(),
- // TODO: Consider piping actual ConfigOptions for more realistic
testing
- // See: https://github.com/apache/sedona-db/issues/248
- config_options: Arc::new(ConfigOptions::default()),
+ config_options: Arc::clone(&self.config_options),
};
self.udf.invoke_with_args(args)
diff --git a/rust/sedona/src/show.rs b/rust/sedona/src/show.rs
index 4276a119..0d3b946b 100644
--- a/rust/sedona/src/show.rs
+++ b/rust/sedona/src/show.rs
@@ -50,7 +50,11 @@ pub fn show_batches<'a, W: std::io::Write>(
))?
.clone();
- let mut table = DisplayTable::try_new(schema, batches,
options)?.with_format_fn(format_fn);
+ let session_config = ctx.ctx.copied_config();
+ let session_config_options = session_config.options();
+
+ let mut table = DisplayTable::try_new(schema, batches, options,
session_config_options)?
+ .with_format_fn(format_fn);
table.negotiate_hidden_columns()?;
table.write(writer)
}
@@ -141,6 +145,7 @@ impl<'a> DisplayTable<'a> {
schema: &Schema,
batches: Vec<RecordBatch>,
options: DisplayTableOptions<'a>,
+ session_config_options: &Arc<ConfigOptions>,
) -> Result<Self> {
let num_rows = batches.iter().map(|batch| batch.num_rows()).sum();
@@ -155,6 +160,7 @@ impl<'a> DisplayTable<'a> {
.iter()
.map(|batch| batch.column(i).clone())
.collect(),
+ Arc::clone(session_config_options),
)
})
.collect::<Result<Vec<_>>>()?;
@@ -354,17 +360,23 @@ struct DisplayColumn {
raw_values: Vec<ArrayRef>,
format_fn: Option<SedonaScalarUDF>,
hidden: bool,
+ session_config_options: Arc<ConfigOptions>,
}
impl DisplayColumn {
/// Create a new display column
- pub fn try_new(field: &Field, raw_values: Vec<ArrayRef>) -> Result<Self> {
+ pub fn try_new(
+ field: &Field,
+ raw_values: Vec<ArrayRef>,
+ session_config_options: Arc<ConfigOptions>,
+ ) -> Result<Self> {
Ok(Self {
name: field.name().to_string(),
sedona_type: SedonaType::from_storage_field(field)?,
raw_values,
format_fn: None,
hidden: false,
+ session_config_options,
})
}
@@ -382,6 +394,7 @@ impl DisplayColumn {
raw_values: vec![Arc::new(raw_values)],
format_fn: None,
hidden: false,
+ session_config_options: Arc::new(ConfigOptions::default()),
}
}
@@ -495,9 +508,7 @@ impl DisplayColumn {
arg_fields,
number_rows: array.len(),
return_field,
- // TODO: Pipe actual ConfigOptions from SedonaContext instead
of using defaults
- // See: https://github.com/apache/sedona-db/issues/248
- config_options: Arc::new(ConfigOptions::default()),
+ config_options: Arc::clone(&self.session_config_options),
};
let format_proxy_value = format_udf.invoke_with_args(args)?;