This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new af678b1 Push down all spatial predicates when reading GeoParquet
files (#39)
af678b1 is described below
commit af678b1e4c46d70a120f566ff709be07b8dba513
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Sep 10 01:20:51 2025 +0800
Push down all spatial predicates when reading GeoParquet files (#39)
* Initial impl of bounds contains
* Add tests for bbox contains and interval contains
* Support various spatial range predicates and distance predicate
* Implement bounding box expansion
* Comment the distance predicate parsing util
* Add tests for PhysicalExpr to SpatialFilter conversion and SpatialFilter
evaluation function
* Fix conversion from st_contains/st_covers to SpatialFilter
* Extend geoparquet metadata pruning test for st_contains
* Fix review comments
---
Cargo.lock | 2 +
rust/sedona-expr/Cargo.toml | 1 +
rust/sedona-expr/src/lib.rs | 1 +
rust/sedona-expr/src/spatial_filter.rs | 479 +++++++++++++++++++---
rust/sedona-expr/src/utils.rs | 103 +++++
rust/sedona-geometry/src/bounding_box.rs | 169 ++++++++
rust/sedona-geometry/src/interval.rs | 202 +++++++++
rust/sedona-geoparquet/Cargo.toml | 1 +
rust/sedona-geoparquet/src/format.rs | 19 +-
rust/sedona-spatial-join/src/operand_evaluator.rs | 3 +
rust/sedona-spatial-join/src/optimizer.rs | 59 +--
11 files changed, 922 insertions(+), 117 deletions(-)
diff --git a/Cargo.lock b/Cargo.lock
index 999f2c0..4003484 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -4880,6 +4880,7 @@ dependencies = [
"datafusion-expr",
"datafusion-physical-expr",
"geo-traits 0.2.0",
+ "rstest",
"sedona-common",
"sedona-geometry",
"sedona-schema",
@@ -4992,6 +4993,7 @@ dependencies = [
"geo-traits 0.2.0",
"object_store",
"parquet",
+ "rstest",
"sedona-common",
"sedona-expr",
"sedona-geometry",
diff --git a/rust/sedona-expr/Cargo.toml b/rust/sedona-expr/Cargo.toml
index 2fef1be..0cb7260 100644
--- a/rust/sedona-expr/Cargo.toml
+++ b/rust/sedona-expr/Cargo.toml
@@ -29,6 +29,7 @@ result_large_err = "allow"
[dev-dependencies]
sedona-testing = { path = "../sedona-testing" }
+rstest = { workspace = true }
[dependencies]
arrow-array = { workspace = true }
diff --git a/rust/sedona-expr/src/lib.rs b/rust/sedona-expr/src/lib.rs
index d242625..c200b8e 100644
--- a/rust/sedona-expr/src/lib.rs
+++ b/rust/sedona-expr/src/lib.rs
@@ -19,3 +19,4 @@ pub mod function_set;
pub mod scalar_udf;
pub mod spatial_filter;
pub mod statistics;
+pub mod utils;
diff --git a/rust/sedona-expr/src/spatial_filter.rs
b/rust/sedona-expr/src/spatial_filter.rs
index 4923c41..71f2949 100644
--- a/rust/sedona-expr/src/spatial_filter.rs
+++ b/rust/sedona-expr/src/spatial_filter.rs
@@ -16,7 +16,7 @@
// under the License.
use std::sync::Arc;
-use arrow_schema::Schema;
+use arrow_schema::{DataType, Schema};
use datafusion_common::{DataFusionError, Result, ScalarValue};
use datafusion_expr::Operator;
use datafusion_physical_expr::{
@@ -28,7 +28,10 @@ use sedona_common::sedona_internal_err;
use sedona_geometry::{bounding_box::BoundingBox, bounds::wkb_bounds_xy,
interval::IntervalTrait};
use sedona_schema::datatypes::SedonaType;
-use crate::statistics::GeoStatistics;
+use crate::{
+ statistics::GeoStatistics,
+ utils::{parse_distance_predicate, ParsedDistancePredicate},
+};
/// Simplified parsed spatial filter
///
@@ -42,6 +45,8 @@ use crate::statistics::GeoStatistics;
pub enum SpatialFilter {
/// ST_Intersects(\<column\>, \<literal\>) or ST_Intersects(\<literal\>,
\<column\>)
Intersects(Column, BoundingBox),
+ /// ST_CoveredBy(\<column\>, \<literal\>) or ST_Covers(\<literal\>,
\<column\>)
+ CoveredBy(Column, BoundingBox),
/// ST_HasZ(\<column\>)
HasZ(Column),
/// Logical AND
@@ -64,6 +69,9 @@ impl SpatialFilter {
SpatialFilter::Intersects(column, bounds) => {
Self::evaluate_intersects_bbox(&table_stats[column.index()],
bounds)
}
+ SpatialFilter::CoveredBy(column, bounds) => {
+ Self::evaluate_covered_by_bbox(&table_stats[column.index()],
bounds)
+ }
SpatialFilter::HasZ(column) =>
Self::evaluate_has_z(&table_stats[column.index()]),
SpatialFilter::And(lhs, rhs) => Self::evaluate_and(lhs, rhs,
table_stats),
SpatialFilter::Or(lhs, rhs) => Self::evaluate_or(lhs, rhs,
table_stats),
@@ -80,6 +88,14 @@ impl SpatialFilter {
}
}
+ fn evaluate_covered_by_bbox(column_stats: &GeoStatistics, bounds:
&BoundingBox) -> bool {
+ if let Some(bbox) = column_stats.bbox() {
+ bounds.contains(bbox)
+ } else {
+ true
+ }
+ }
+
fn evaluate_has_z(column_stats: &GeoStatistics) -> bool {
if let Some(bbox) = column_stats.bbox() {
if let Some(z) = bbox.z() {
@@ -119,46 +135,10 @@ impl SpatialFilter {
///
/// Parses expr to extract known expressions we can evaluate against
statistics.
pub fn try_from_expr(expr: &Arc<dyn PhysicalExpr>) -> Result<Self> {
- if let Some(scalar_fun) =
expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
- let raw_args = scalar_fun.args();
- let args = parse_args(raw_args);
- match scalar_fun.fun().name() {
- "st_intersects" => {
- if args.len() != 2 {
- return sedona_internal_err!(
- "unexpected argument count in filter evaluation"
- );
- }
-
- match (&args[0], &args[1]) {
- (ArgRef::Col(column), ArgRef::Lit(literal))
- | (ArgRef::Lit(literal), ArgRef::Col(column)) => {
- match literal_bounds(literal) {
- Ok(literal_bounds) => {
- Ok(Self::Intersects(column.clone(),
literal_bounds))
- }
- Err(e) =>
Err(DataFusionError::External(Box::new(e))),
- }
- }
- // Not between a literal and a column
- _ => Ok(Self::Unknown),
- }
- }
- "st_hasz" => {
- if args.len() != 1 {
- return sedona_internal_err!(
- "unexpected argument count in filter evaluation"
- );
- }
-
- match &args[0] {
- ArgRef::Col(column) => Ok(Self::HasZ(column.clone())),
- _ => Ok(Self::Unknown),
- }
- }
- // Not a function we know about
- _ => Ok(Self::Unknown),
- }
+ if let Some(spatial_filter) = Self::try_from_range_predicate(expr)? {
+ Ok(spatial_filter)
+ } else if let Some(spatial_filter) =
Self::try_from_distance_predicate(expr)? {
+ Ok(spatial_filter)
} else if let Some(binary_expr) =
expr.as_any().downcast_ref::<BinaryExpr>() {
match binary_expr.op() {
Operator::And => Ok(Self::And(
@@ -187,6 +167,144 @@ impl SpatialFilter {
Ok(Self::Unknown)
}
}
+
+ fn try_from_range_predicate(expr: &Arc<dyn PhysicalExpr>) ->
Result<Option<Self>> {
+ let Some(scalar_fun) =
expr.as_any().downcast_ref::<ScalarFunctionExpr>() else {
+ return Ok(None);
+ };
+
+ let raw_args = scalar_fun.args();
+ let args = parse_args(raw_args);
+ let fun_name = scalar_fun.fun().name();
+ match fun_name {
+ "st_intersects" | "st_equals" | "st_touches" => {
+ if args.len() != 2 {
+ return sedona_internal_err!("unexpected argument count in
filter evaluation");
+ }
+
+ match (&args[0], &args[1]) {
+ (ArgRef::Col(column), ArgRef::Lit(literal))
+ | (ArgRef::Lit(literal), ArgRef::Col(column)) => {
+ match literal_bounds(literal) {
+ Ok(literal_bounds) => {
+ Ok(Some(Self::Intersects(column.clone(),
literal_bounds)))
+ }
+ Err(e) =>
Err(DataFusionError::External(Box::new(e))),
+ }
+ }
+ // Not between a literal and a column
+ _ => Ok(Some(Self::Unknown)),
+ }
+ }
+ "st_within" | "st_covered_by" | "st_coveredby" => {
+ if args.len() != 2 {
+ return sedona_internal_err!("unexpected argument count in
filter evaluation");
+ }
+
+ match (&args[0], &args[1]) {
+ (ArgRef::Col(column), ArgRef::Lit(literal)) => {
+ // column within/covered_by literal -> CoveredBy filter
+ match literal_bounds(literal) {
+ Ok(literal_bounds) => {
+ Ok(Some(Self::CoveredBy(column.clone(),
literal_bounds)))
+ }
+ Err(e) =>
Err(DataFusionError::External(Box::new(e))),
+ }
+ }
+ (ArgRef::Lit(literal), ArgRef::Col(column)) => {
+ // literal within/covered_by column -> Intersects
filter
+ match literal_bounds(literal) {
+ Ok(literal_bounds) => {
+ Ok(Some(Self::Intersects(column.clone(),
literal_bounds)))
+ }
+ Err(e) =>
Err(DataFusionError::External(Box::new(e))),
+ }
+ }
+ // Not between a literal and a column
+ _ => Ok(Some(Self::Unknown)),
+ }
+ }
+ "st_contains" | "st_covers" => {
+ if args.len() != 2 {
+ return sedona_internal_err!("unexpected argument count in
filter evaluation");
+ }
+
+ match (&args[0], &args[1]) {
+ (ArgRef::Col(column), ArgRef::Lit(literal)) => {
+ // column contains/covers literal -> Intersects filter
+ // (column must potentially intersect literal to
contain it)
+ match literal_bounds(literal) {
+ Ok(literal_bounds) => {
+ Ok(Some(Self::Intersects(column.clone(),
literal_bounds)))
+ }
+ Err(e) =>
Err(DataFusionError::External(Box::new(e))),
+ }
+ }
+ (ArgRef::Lit(literal), ArgRef::Col(column)) => {
+ // literal contains/covers column -> CoveredBy filter
+ // (equivalent to st_within(column, literal))
+ match literal_bounds(literal) {
+ Ok(literal_bounds) => {
+ Ok(Some(Self::CoveredBy(column.clone(),
literal_bounds)))
+ }
+ Err(e) =>
Err(DataFusionError::External(Box::new(e))),
+ }
+ }
+ // Not between a literal and a column
+ _ => Ok(Some(Self::Unknown)),
+ }
+ }
+ "st_hasz" => {
+ if args.len() != 1 {
+ return sedona_internal_err!("unexpected argument count in
filter evaluation");
+ }
+
+ match &args[0] {
+ ArgRef::Col(column) =>
Ok(Some(Self::HasZ(column.clone()))),
+ _ => Ok(Some(Self::Unknown)),
+ }
+ }
+ _ => Ok(None),
+ }
+ }
+
+ fn try_from_distance_predicate(expr: &Arc<dyn PhysicalExpr>) ->
Result<Option<Self>> {
+ let Some(ParsedDistancePredicate {
+ arg0,
+ arg1,
+ arg_distance,
+ }) = parse_distance_predicate(expr)
+ else {
+ return Ok(None);
+ };
+
+ let raw_args = [arg0, arg1, arg_distance];
+ let args = parse_args(&raw_args);
+
+ match (&args[0], &args[1], &args[2]) {
+ (ArgRef::Col(column), ArgRef::Lit(literal), ArgRef::Lit(distance))
+ | (ArgRef::Lit(literal), ArgRef::Col(column),
ArgRef::Lit(distance)) => {
+ match (
+ literal_bounds(literal),
+ distance.value().cast_to(&DataType::Float64)?,
+ ) {
+ (Ok(literal_bounds), distance_scalar_value) => {
+ let ScalarValue::Float64(Some(dist)) =
distance_scalar_value else {
+ return Ok(None);
+ };
+ if dist.is_nan() || dist < 0.0 {
+ return Ok(None);
+ }
+ let expanded_bounds = literal_bounds.expand_by(dist);
+ Ok(Some(Self::Intersects(column.clone(),
expanded_bounds)))
+ }
+ (Err(e), _) => Err(DataFusionError::External(Box::new(e))),
+ }
+ }
+ // Not between a literal and a column
+ _ => Ok(Some(Self::Unknown)),
+ }
+ }
}
/// Internal utility to help match physical expression types
@@ -233,15 +351,16 @@ mod test {
use arrow_schema::{DataType, Field};
use datafusion_expr::{ScalarUDF, Signature, SimpleScalarUDF, Volatility};
+ use rstest::rstest;
use sedona_geometry::{bounding_box::BoundingBox, interval::Interval};
use sedona_schema::datatypes::WKB_GEOMETRY;
use sedona_testing::create::create_scalar;
use super::*;
- fn dummy_st_intersects() -> ScalarUDF {
+ fn dummy_st_hasz() -> ScalarUDF {
SimpleScalarUDF::new_with_signature(
- "st_intersects",
+ "st_hasz",
Signature::any(2, Volatility::Immutable),
DataType::Boolean,
Arc::new(|_args| Ok(ScalarValue::Boolean(Some(true)).into())),
@@ -249,9 +368,9 @@ mod test {
.into()
}
- fn dummy_st_hasz() -> ScalarUDF {
+ fn dummy_unrelated() -> ScalarUDF {
SimpleScalarUDF::new_with_signature(
- "st_hasz",
+ "st_not_a_predicate",
Signature::any(2, Volatility::Immutable),
DataType::Boolean,
Arc::new(|_args| Ok(ScalarValue::Boolean(Some(true)).into())),
@@ -259,19 +378,16 @@ mod test {
.into()
}
- fn dummy_unrelated() -> ScalarUDF {
+ fn create_dummy_spatial_function(name: &str, arg_count: usize) ->
ScalarUDF {
SimpleScalarUDF::new_with_signature(
- "st_not_a_predicate",
- Signature::any(2, Volatility::Immutable),
+ name,
+ Signature::any(arg_count, Volatility::Immutable),
DataType::Boolean,
Arc::new(|_args| Ok(ScalarValue::Boolean(Some(true)).into())),
)
.into()
}
- #[test]
- fn spatial_filters() {}
-
#[test]
fn predicate_intersects() {
let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
@@ -307,6 +423,32 @@ mod test {
.contains("Unexpected scalar type in filter expression"));
}
+ #[test]
+ fn predicate_covered_by() {
+ let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
+ let literal = Literal::new_with_metadata(
+ create_scalar(Some("POLYGON ((0 0, 4 0, 4 4, 0 4, 0 0))"),
&WKB_GEOMETRY),
+ Some(storage_field.metadata().into()),
+ );
+ let bounds = literal_bounds(&literal).unwrap();
+
+ let stats_no_info = [GeoStatistics::unspecified()];
+ let stats_covered = [
+ GeoStatistics::unspecified().with_bbox(Some(BoundingBox::xy((1.0,
1.0), (2.0, 2.0))))
+ ];
+ let stats_not_covered = [
+ GeoStatistics::unspecified().with_bbox(Some(BoundingBox::xy((3.0,
3.0), (5.0, 5.0))))
+ ];
+ let col0 = Column::new("col0", 0);
+
+ // CoveredBy should return true when column bbox is fully contained in
literal bounds
+ assert!(SpatialFilter::CoveredBy(col0.clone(),
bounds.clone()).evaluate(&stats_no_info));
+ assert!(SpatialFilter::CoveredBy(col0.clone(),
bounds.clone()).evaluate(&stats_covered));
+ assert!(
+ !SpatialFilter::CoveredBy(col0.clone(),
bounds.clone()).evaluate(&stats_not_covered)
+ );
+ }
+
#[test]
fn predicate_has_z() {
let col0 = Column::new("col0", 0);
@@ -405,39 +547,248 @@ mod test {
));
}
- #[test]
- fn predicate_from_expr_intersects() {
+ #[rstest]
+ fn predicate_from_expr_commutative_functions(
+ #[values("st_intersects", "st_equals", "st_touches")] func_name: &str,
+ ) {
let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry",
0));
let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
let literal: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new_with_metadata(
- create_scalar(Some("POINT (1 2)"), &WKB_GEOMETRY),
+ create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"),
&WKB_GEOMETRY),
Some(storage_field.metadata().into()),
));
- let st_intersects = dummy_st_intersects();
+ // Test functions that should result in Intersects filter
+ let func = create_dummy_spatial_function(func_name, 2);
let expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
- "intersects",
- Arc::new(st_intersects.clone()),
+ func_name,
+ Arc::new(func.clone()),
vec![column.clone(), literal.clone()],
Arc::new(Field::new("", DataType::Boolean, true)),
));
let predicate = SpatialFilter::try_from_expr(&expr).unwrap();
- assert!(matches!(predicate, SpatialFilter::Intersects(_, _)));
+ assert!(
+ matches!(predicate, SpatialFilter::Intersects(_, _)),
+ "Function {} should produce Intersects filter",
+ func_name
+ );
+ // Test reversed argument order
+ let expr_reversed: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
+ func_name,
+ Arc::new(func),
+ vec![literal.clone(), column.clone()],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
+ let predicate_reversed =
SpatialFilter::try_from_expr(&expr_reversed).unwrap();
+ assert!(
+ matches!(predicate_reversed, SpatialFilter::Intersects(_, _)),
+ "Function {} with reversed args should produce Intersects filter",
+ func_name
+ );
+ }
+
+ #[rstest]
+ fn predicate_from_expr_within_covered_by_functions(
+ #[values("st_within", "st_covered_by", "st_coveredby")] func_name:
&str,
+ ) {
+ let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry",
0));
+ let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
+ let literal: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new_with_metadata(
+ create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"),
&WKB_GEOMETRY),
+ Some(storage_field.metadata().into()),
+ ));
+
+ // Test functions that should result in CoveredBy filter when column
is first arg
+ let func = create_dummy_spatial_function(func_name, 2);
let expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
- "intersects",
- Arc::new(st_intersects.clone()),
+ func_name,
+ Arc::new(func.clone()),
+ vec![column.clone(), literal.clone()],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
+ let predicate = SpatialFilter::try_from_expr(&expr).unwrap();
+ assert!(
+ matches!(predicate, SpatialFilter::CoveredBy(_, _)),
+ "Function {} should produce CoveredBy filter",
+ func_name
+ );
+
+ // Test reversed argument order: should be converted to Intersects
filter
+ let expr_reversed: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
+ func_name,
+ Arc::new(func),
vec![literal.clone(), column.clone()],
Arc::new(Field::new("", DataType::Boolean, true)),
));
+ let predicate_reversed =
SpatialFilter::try_from_expr(&expr_reversed).unwrap();
+ assert!(
+ matches!(predicate_reversed, SpatialFilter::Intersects(_, _)),
+ "Function {} with reversed args should produce Intersects filter",
+ func_name
+ );
+ }
+
+ #[rstest]
+ fn predicate_from_expr_contains_covers_functions(
+ #[values("st_contains", "st_covers")] func_name: &str,
+ ) {
+ let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry",
0));
+ let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
+ let literal: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new_with_metadata(
+ create_scalar(Some("POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0))"),
&WKB_GEOMETRY),
+ Some(storage_field.metadata().into()),
+ ));
+
+ // Test functions that should result in Intersects filter when column
is first arg
+ // (column contains/covers literal -> column must intersect literal)
+ let func = create_dummy_spatial_function(func_name, 2);
+ let expr: Arc<dyn PhysicalExpr> = Arc::new(ScalarFunctionExpr::new(
+ func_name,
+ Arc::new(func.clone()),
+ vec![column.clone(), literal.clone()],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
let predicate = SpatialFilter::try_from_expr(&expr).unwrap();
- assert!(matches!(predicate, SpatialFilter::Intersects(_, _)))
+ assert!(
+ matches!(predicate, SpatialFilter::Intersects(_, _)),
+ "Function {} should produce Intersects filter",
+ func_name
+ );
+
+ // Test reversed argument order: should be converted to CoveredBy
filter
+ // (literal contains/covers column -> equivalent to st_within(column,
literal))
+ let expr_reversed: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
+ func_name,
+ Arc::new(func),
+ vec![literal.clone(), column.clone()],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
+ let predicate_reversed =
SpatialFilter::try_from_expr(&expr_reversed).unwrap();
+ assert!(
+ matches!(predicate_reversed, SpatialFilter::CoveredBy(_, _)),
+ "Function {} with reversed args should produce CoveredBy filter",
+ func_name
+ );
}
#[test]
- fn predicate_from_intersects_errors() {
+ fn predicate_from_expr_distance_functions() {
+ let column: Arc<dyn PhysicalExpr> = Arc::new(Column::new("geometry",
0));
+ let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
+ let literal: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new_with_metadata(
+ create_scalar(Some("POINT (1 2)"), &WKB_GEOMETRY),
+ Some(storage_field.metadata().into()),
+ ));
+ let distance_literal: Arc<dyn PhysicalExpr> =
+ Arc::new(Literal::new(ScalarValue::Float64(Some(100.0))));
+
+ // Test ST_DWithin function
+ let st_dwithin = create_dummy_spatial_function("st_dwithin", 3);
+ let dwithin_expr: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
+ "st_dwithin",
+ Arc::new(st_dwithin.clone()),
+ vec![column.clone(), literal.clone(), distance_literal.clone()],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
+ let predicate = SpatialFilter::try_from_expr(&dwithin_expr).unwrap();
+ assert!(
+ matches!(predicate, SpatialFilter::Intersects(_, _)),
+ "ST_DWithin should produce Intersects filter with expanded bounds"
+ );
+
+ // Test ST_DWithin with reversed geometry arguments
+ let dwithin_expr_reversed: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
+ "st_dwithin",
+ Arc::new(st_dwithin),
+ vec![literal.clone(), column.clone(), distance_literal.clone()],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
+ let predicate_reversed =
SpatialFilter::try_from_expr(&dwithin_expr_reversed).unwrap();
+ assert!(
+ matches!(predicate_reversed, SpatialFilter::Intersects(_, _)),
+ "ST_DWithin with reversed args should produce Intersects filter"
+ );
+
+ // Test ST_Distance <= threshold
+ let st_distance = create_dummy_spatial_function("st_distance", 2);
+ let distance_expr: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
+ "st_distance",
+ Arc::new(st_distance.clone()),
+ vec![column.clone(), literal.clone()],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
+ let comparison_expr: Arc<dyn PhysicalExpr> = Arc::new(BinaryExpr::new(
+ distance_expr.clone(),
+ Operator::LtEq,
+ distance_literal.clone(),
+ ));
+ let predicate =
SpatialFilter::try_from_expr(&comparison_expr).unwrap();
+ assert!(
+ matches!(predicate, SpatialFilter::Intersects(_, _)),
+ "ST_Distance <= threshold should produce Intersects filter"
+ );
+
+ // Test threshold >= ST_Distance
+ let comparison_expr_reversed: Arc<dyn PhysicalExpr> =
Arc::new(BinaryExpr::new(
+ distance_literal.clone(),
+ Operator::GtEq,
+ distance_expr.clone(),
+ ));
+ let predicate_reversed =
SpatialFilter::try_from_expr(&comparison_expr_reversed).unwrap();
+ assert!(
+ matches!(predicate_reversed, SpatialFilter::Intersects(_, _)),
+ "threshold >= ST_Distance should produce Intersects filter"
+ );
+
+ // Test with negative distance (should be treated as Unknown)
+ let negative_distance: Arc<dyn PhysicalExpr> =
+ Arc::new(Literal::new(ScalarValue::Float64(Some(-10.0))));
+ let st_dwithin = create_dummy_spatial_function("st_dwithin", 3);
+ let dwithin_expr: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
+ "st_dwithin",
+ Arc::new(st_dwithin.clone()),
+ vec![column.clone(), literal.clone(), negative_distance],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
+ let predicate = SpatialFilter::try_from_expr(&dwithin_expr).unwrap();
+ assert!(
+ matches!(predicate, SpatialFilter::Unknown),
+ "Negative distance should result in Unknown filter"
+ );
+
+ // Test with NaN distance (should be treated as Unknown)
+ let nan_distance: Arc<dyn PhysicalExpr> =
+ Arc::new(Literal::new(ScalarValue::Float64(Some(f64::NAN))));
+ let dwithin_expr_nan: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
+ "st_dwithin",
+ Arc::new(st_dwithin),
+ vec![column.clone(), literal.clone(), nan_distance],
+ Arc::new(Field::new("", DataType::Boolean, true)),
+ ));
+ let predicate_nan =
SpatialFilter::try_from_expr(&dwithin_expr_nan).unwrap();
+ assert!(
+ matches!(predicate_nan, SpatialFilter::Unknown),
+ "NaN distance should result in Unknown filter"
+ );
+ }
+
+ #[rstest]
+ fn predicate_from_spatial_relation_function_errors(
+ #[values(
+ "st_intersects",
+ "st_equals",
+ "st_touches",
+ "st_contains",
+ "st_covers",
+ "st_within",
+ "st_covered_by",
+ "st_coveredby"
+ )]
+ func_name: &str,
+ ) {
let literal: Arc<dyn PhysicalExpr> =
Arc::new(Literal::new(ScalarValue::Null));
- let st_intersects = dummy_st_intersects();
+ let st_intersects = create_dummy_spatial_function(func_name, 2);
// Wrong number of args
let expr_no_args: Arc<dyn PhysicalExpr> =
Arc::new(ScalarFunctionExpr::new(
diff --git a/rust/sedona-expr/src/utils.rs b/rust/sedona-expr/src/utils.rs
new file mode 100644
index 0000000..25081a9
--- /dev/null
+++ b/rust/sedona-expr/src/utils.rs
@@ -0,0 +1,103 @@
+use std::sync::Arc;
+
+use datafusion_expr::Operator;
+use datafusion_physical_expr::{expressions::BinaryExpr, PhysicalExpr,
ScalarFunctionExpr};
+
+/// Represents a parsed distance predicate with its constituent parts.
+///
+/// Distance predicates are spatial operations that determine whether two
geometries
+/// are within a specified distance of each other. This struct holds the parsed
+/// components of such predicates for further processing.
+///
+/// ## Supported Distance Predicate Forms
+///
+/// This struct can represent the parsed components from any of these distance
predicate forms:
+///
+/// 1. **Direct distance function**:
+/// - `st_dwithin(geom1, geom2, distance)` - Returns true if geometries are
within the distance
+///
+/// 2. **Distance comparison (left-to-right)**:
+/// - `st_distance(geom1, geom2) <= distance` - Distance is less than or
equal to threshold
+/// - `st_distance(geom1, geom2) < distance` - Distance is strictly less
than threshold
+///
+/// 3. **Distance comparison (right-to-left)**:
+/// - `distance >= st_distance(geom1, geom2)` - Threshold is greater than
or equal to distance
+/// - `distance > st_distance(geom1, geom2)` - Threshold is strictly
greater than distance
+///
+/// All forms express a distance threshold but may appear differently in SQL
queries. The parser
+/// normalizes them into this common structure; strict vs. inclusive comparison is not recorded.
+pub struct ParsedDistancePredicate {
+ /// The first geometry argument in the distance predicate
+ pub arg0: Arc<dyn PhysicalExpr>,
+ /// The second geometry argument in the distance predicate
+ pub arg1: Arc<dyn PhysicalExpr>,
+ /// The distance threshold argument (as a physical expression)
+ pub arg_distance: Arc<dyn PhysicalExpr>,
+}
+
+/// Parses a physical expression to extract distance predicate components.
+///
+/// This function recognizes and parses distance predicates in spatial queries.
+/// See [`ParsedDistancePredicate`] documentation for details on the supported
+/// distance predicate forms.
+///
+/// # Arguments
+///
+/// * `expr` - A physical expression that potentially represents a distance
predicate
+///
+/// # Returns
+///
+/// * `Some(ParsedDistancePredicate)` - If the expression is a recognized
distance predicate,
+/// returns the parsed components (two geometry arguments and the distance
threshold)
+/// * `None` - If the expression is not a distance predicate or cannot be
parsed
+///
+/// # Examples
+///
+/// The function can parse expressions like:
+/// - `st_dwithin(geometry_column, POINT(0 0), 100.0)`
+/// - `st_distance(geom_a, geom_b) <= 50.0`
+/// - `25.0 >= st_distance(geom_x, geom_y)`
+pub fn parse_distance_predicate(expr: &Arc<dyn PhysicalExpr>) ->
Option<ParsedDistancePredicate> {
+ if let Some(binary_expr) = expr.as_any().downcast_ref::<BinaryExpr>() {
+ let left = binary_expr.left();
+ let right = binary_expr.right();
+ let (st_distance_expr, distance_bound_expr) = match *binary_expr.op() {
+ Operator::Lt | Operator::LtEq => (left, right),
+ Operator::Gt | Operator::GtEq => (right, left),
+ _ => return None,
+ };
+
+ if let Some(st_distance_expr) = st_distance_expr
+ .as_any()
+ .downcast_ref::<ScalarFunctionExpr>()
+ {
+ if st_distance_expr.fun().name() != "st_distance" {
+ return None;
+ }
+
+ let args = st_distance_expr.args();
+ assert!(args.len() >= 2);
+ Some(ParsedDistancePredicate {
+ arg0: Arc::clone(&args[0]),
+ arg1: Arc::clone(&args[1]),
+ arg_distance: Arc::clone(distance_bound_expr),
+ })
+ } else {
+ None
+ }
+ } else if let Some(st_dwithin_expr) =
expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
+ if st_dwithin_expr.fun().name() != "st_dwithin" {
+ return None;
+ }
+
+ let args = st_dwithin_expr.args();
+ assert!(args.len() >= 3);
+ Some(ParsedDistancePredicate {
+ arg0: Arc::clone(&args[0]),
+ arg1: Arc::clone(&args[1]),
+ arg_distance: Arc::clone(&args[2]),
+ })
+ } else {
+ None
+ }
+}
diff --git a/rust/sedona-geometry/src/bounding_box.rs
b/rust/sedona-geometry/src/bounding_box.rs
index 4d50cd7..fb5b977 100644
--- a/rust/sedona-geometry/src/bounding_box.rs
+++ b/rust/sedona-geometry/src/bounding_box.rs
@@ -108,6 +108,39 @@ impl BoundingBox {
intersects_xy && may_intersect_z && may_intersect_m
}
+ /// Calculate whether this bounding box contains another BoundingBox
+ ///
+ /// Returns true if this bounding box contains other or false otherwise.
+ /// This method will consider Z and M dimension if and only if those
dimensions are present
+ /// in both bounding boxes.
+ pub fn contains(&self, other: &Self) -> bool {
+ let contains_xy = self.x.contains_interval(&other.x) &&
self.y.contains_interval(&other.y);
+ let may_contain_z = match (self.z, other.z) {
+ (Some(z), Some(other_z)) => z.contains_interval(&other_z),
+ _ => true,
+ };
+ let may_contain_m = match (self.m, other.m) {
+ (Some(m), Some(other_m)) => m.contains_interval(&other_m),
+ _ => true,
+ };
+
+ contains_xy && may_contain_z && may_contain_m
+ }
+
+ /// Expand this BoundingBox by a given distance in x and y dimensions only
+ ///
+ /// Returns a new BoundingBox where x and y intervals are expanded by the
given distance.
+ /// The x dimension (which may wrap around) is handled correctly.
+ /// Z and M dimensions are left unchanged.
+ pub fn expand_by(&self, distance: f64) -> Self {
+ Self {
+ x: self.x.expand_by(distance),
+ y: self.y.expand_by(distance),
+ z: self.z,
+ m: self.m,
+ }
+ }
+
/// Update this BoundingBox to include the bounds of another
///
/// This method will propagate missingness of Z or M dimensions from the
two boxes
@@ -188,6 +221,88 @@ mod test {
)));
}
+ #[test]
+ fn bounding_box_contains() {
+ let xyzm = BoundingBox::xyzm(
+ (10, 20),
+ (30, 40),
+ Some((50, 60).into()),
+ Some((70, 80).into()),
+ );
+
+ // Should contain a smaller box completely within bounds
+ assert!(xyzm.contains(&BoundingBox::xy((14, 16), (34, 36))));
+
+ // Should contain itself
+ assert!(xyzm.contains(&xyzm));
+
+ // Should contain a box without z or m information if xy is contained
+ assert!(xyzm.contains(&BoundingBox::xy((12, 18), (32, 38))));
+
+ // Should contain without z information but with contained m
+ assert!(xyzm.contains(&BoundingBox::xyzm(
+ (14, 16),
+ (34, 36),
+ None,
+ Some((74, 76).into())
+ )));
+
+ // Should contain without m information but with contained z
+ assert!(xyzm.contains(&BoundingBox::xyzm(
+ (14, 16),
+ (34, 36),
+ Some((54, 56).into()),
+ None,
+ )));
+
+ // Should contain boxes that touch the boundaries
+ assert!(xyzm.contains(&BoundingBox::xy((10, 20), (30, 40))));
+ assert!(xyzm.contains(&BoundingBox::xy((10, 15), (30, 35))));
+ assert!(xyzm.contains(&BoundingBox::xy((15, 20), (35, 40))));
+
+ // Should *not* contain if x or y extends beyond bounds
+ assert!(!xyzm.contains(&BoundingBox::xy((4, 16), (34, 36)))); // x
extends below
+ assert!(!xyzm.contains(&BoundingBox::xy((14, 26), (34, 36)))); // x
extends above
+ assert!(!xyzm.contains(&BoundingBox::xy((14, 16), (24, 36)))); // y
extends below
+ assert!(!xyzm.contains(&BoundingBox::xy((14, 16), (34, 46)))); // y
extends above
+
+ // Should *not* contain if z is provided but extends beyond bounds
+ assert!(!xyzm.contains(&BoundingBox::xyzm(
+ (14, 16),
+ (34, 36),
+ Some((44, 56).into()), // z extends below
+ None
+ )));
+
+ assert!(!xyzm.contains(&BoundingBox::xyzm(
+ (14, 16),
+ (34, 36),
+ Some((54, 66).into()), // z extends above
+ None
+ )));
+
+ // Should *not* contain if m is provided but extends beyond bounds
+ assert!(!xyzm.contains(&BoundingBox::xyzm(
+ (14, 16),
+ (34, 36),
+ None,
+ Some((64, 76).into()) // m extends below
+ )));
+
+ assert!(!xyzm.contains(&BoundingBox::xyzm(
+ (14, 16),
+ (34, 36),
+ None,
+ Some((74, 86).into()) // m extends above
+ )));
+
+ // Should *not* contain boxes that are completely outside
+ assert!(!xyzm.contains(&BoundingBox::xy((0, 5), (30, 40)))); // x
completely below
+ assert!(!xyzm.contains(&BoundingBox::xy((25, 30), (30, 40)))); // x
completely above
+ assert!(!xyzm.contains(&BoundingBox::xy((10, 20), (0, 25)))); // y
completely below
+ assert!(!xyzm.contains(&BoundingBox::xy((10, 20), (45, 50)))); // y
completely above
+ }
+
#[test]
fn bounding_box_update() {
let xyzm = BoundingBox::xyzm(
@@ -299,4 +414,58 @@ mod test {
assert!(bbox_nan2.y().lo().is_nan());
assert!(bbox_nan2.y().hi().is_nan());
}
+
+ #[test]
+ fn bounding_box_expand_by() {
+ let xyzm = BoundingBox::xyzm(
+ (10, 20),
+ (30, 40),
+ Some((50, 60).into()),
+ Some((70, 80).into()),
+ );
+
+ // Expand by a positive distance - only x and y should change
+ let expanded = xyzm.expand_by(5.0);
+ assert_eq!(expanded.x(), &WraparoundInterval::new(5.0, 25.0));
+ assert_eq!(expanded.y(), &Interval::new(25.0, 45.0));
+ assert_eq!(expanded.z(), &Some(Interval::new(50.0, 60.0))); // unchanged
+ assert_eq!(expanded.m(), &Some(Interval::new(70.0, 80.0))); // unchanged
+
+ // Expand by zero does nothing
+ let unchanged = xyzm.expand_by(0.0);
+ assert_eq!(unchanged, xyzm);
+
+ // Expand by negative distance does nothing
+ let unchanged_neg = xyzm.expand_by(-2.0);
+ assert_eq!(unchanged_neg, xyzm);
+
+ // Expand by NaN does nothing
+ let unchanged_nan = xyzm.expand_by(f64::NAN);
+ assert_eq!(unchanged_nan, xyzm);
+
+ // Test with missing z and m dimensions
+ let xy_only = BoundingBox::xy((10, 20), (30, 40));
+ let expanded_xy = xy_only.expand_by(3.0);
+ assert_eq!(expanded_xy.x(), &WraparoundInterval::new(7.0, 23.0));
+ assert_eq!(expanded_xy.y(), &Interval::new(27.0, 43.0));
+ assert!(expanded_xy.z().is_none());
+ assert!(expanded_xy.m().is_none());
+
+ // Test with empty intervals
+ let bbox_with_empty = BoundingBox::xy((10, 20), Interval::empty());
+ let expanded_empty = bbox_with_empty.expand_by(5.0);
+ assert_eq!(expanded_empty.x(), &WraparoundInterval::new(5.0, 25.0));
+ assert_eq!(expanded_empty.y(), &Interval::empty());
+
+ // Test with wraparound x interval
+ let wraparound_x = BoundingBox::xy(WraparoundInterval::new(170.0,
-170.0), (30, 40));
+ let expanded_wraparound = wraparound_x.expand_by(10.0);
+ // Original excludes (-170, 170), expanding by 10 should exclude (-160, 160)
+ // So the new interval should be (160, -160)
+ assert_eq!(
+ expanded_wraparound.x(),
+ &WraparoundInterval::new(160.0, -160.0)
+ );
+ assert_eq!(expanded_wraparound.y(), &Interval::new(20.0, 50.0));
+ }
}
diff --git a/rust/sedona-geometry/src/interval.rs
b/rust/sedona-geometry/src/interval.rs
index b87d0e6..1037bf7 100644
--- a/rust/sedona-geometry/src/interval.rs
+++ b/rust/sedona-geometry/src/interval.rs
@@ -73,6 +73,15 @@ pub trait IntervalTrait: std::fmt::Debug + PartialEq {
/// `is_wraparound()` when not required for an implementation.
fn intersects_interval(&self, other: &Self) -> bool;
+ /// Check for potential containment of an interval
+ ///
+ /// Note that intervals always contain their endpoints (for both the wraparound and
+ /// non-wraparound case).
+ ///
+ /// This method accepts Self for performance reasons to prevent unnecessary checking of
+ /// `is_wraparound()` when not required for an implementation.
+ fn contains_interval(&self, other: &Self) -> bool;
+
/// The width of the interval
///
/// For the non-wraparound case, this is the distance between lo and hi.
For the wraparound
@@ -98,6 +107,13 @@ pub trait IntervalTrait: std::fmt::Debug + PartialEq {
///
/// When accumulating intervals in a loop, use [Interval::update_value].
fn merge_value(&self, other: f64) -> Self;
+
+ /// Expand this interval by a given distance
+ ///
+ /// Returns a new interval where both endpoints are moved outward by the given distance.
+ /// For regular intervals, this expands both lo and hi by the distance.
+ /// For wraparound intervals, this may result in the full interval if expansion is large enough.
+ /// Empty intervals, and NaN or negative distances, leave the interval unchanged.
+ fn expand_by(&self, distance: f64) -> Self;
}
/// 1D Interval that never wraps around
@@ -204,6 +220,10 @@ impl IntervalTrait for Interval {
self.lo <= other.hi && other.lo <= self.hi
}
+ fn contains_interval(&self, other: &Self) -> bool {
+ self.lo <= other.lo && self.hi >= other.hi
+ }
+
fn width(&self) -> f64 {
self.hi - self.lo
}
@@ -227,6 +247,14 @@ impl IntervalTrait for Interval {
out.update_value(other);
out
}
+
+ fn expand_by(&self, distance: f64) -> Self {
+ if self.is_empty() || distance.is_nan() || distance < 0.0 {
+ return *self;
+ }
+
+ Self::new(self.lo - distance, self.hi + distance)
+ }
}
#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
@@ -316,6 +344,12 @@ impl IntervalTrait for WraparoundInterval {
|| right.intersects_interval(&other_right)
}
+ fn contains_interval(&self, other: &Self) -> bool {
+ let (left, right) = self.split();
+ let (other_left, other_right) = other.split();
+ left.contains_interval(&other_left) &&
right.contains_interval(&other_right)
+ }
+
fn width(&self) -> f64 {
if self.is_wraparound() {
f64::INFINITY
@@ -423,6 +457,35 @@ impl IntervalTrait for WraparoundInterval {
}
}
}
+
+ fn expand_by(&self, distance: f64) -> Self {
+ if self.is_empty() || distance.is_nan() || distance < 0.0 {
+ return *self;
+ }
+
+ if !self.is_wraparound() {
+ // For non-wraparound, just expand the inner interval
+ return Self {
+ inner: self.inner.expand_by(distance),
+ };
+ }
+
+ // For wraparound intervals, expanding means including more values
+ // Wraparound interval (a, b) where a > b excludes the region (b, a)
+ // To expand by distance d, we shrink the excluded region from (b, a) to (b+d, a-d)
+ // This means the new wraparound interval becomes (a-d, b+d)
+ let excluded_lo = self.inner.hi + distance; // b + d
+ let excluded_hi = self.inner.lo - distance; // a - d
+
+ // If the excluded region disappears (excluded_lo >= excluded_hi), we get the full interval
+ if excluded_lo >= excluded_hi {
+ return Self::full();
+ }
+
+ // The new wraparound interval excludes (excluded_lo, excluded_hi)
+ // So the interval itself is (excluded_hi, excluded_lo)
+ Self::new(excluded_hi, excluded_lo)
+ }
}
#[cfg(test)]
@@ -451,6 +514,13 @@ mod test {
// ...except the full interval
assert!(empty.intersects_interval(&T::full()));
+ // Empty contains no intervals
+ assert!(!empty.contains_interval(&T::new(-10.0, 10.0)));
+ assert!(!empty.contains_interval(&T::full()));
+
+ // ...except empty itself (empty set is subset of itself)
+ assert!(empty.contains_interval(&T::empty()));
+
// Merging NaN is still empty
assert_eq!(empty.merge_value(f64::NAN), empty);
@@ -465,6 +535,12 @@ mod test {
empty.merge_interval(&T::new(10.0, 20.0)),
T::new(10.0, 20.0)
);
+
+ // Expanding empty interval keeps it empty
+ assert_eq!(empty.expand_by(5.0), empty);
+ assert_eq!(empty.expand_by(0.0), empty);
+ assert_eq!(empty.expand_by(-1.0), empty);
+ assert_eq!(empty.expand_by(f64::NAN), empty);
}
#[test]
@@ -528,6 +604,21 @@ mod test {
assert!(!finite.intersects_interval(&T::new(25.0, 30.0)));
assert!(!finite.intersects_interval(&T::empty()));
+ // Intervals that are contained
+ assert!(finite.contains_interval(&T::new(14.0, 16.0)));
+ assert!(finite.contains_interval(&T::new(10.0, 15.0)));
+ assert!(finite.contains_interval(&T::new(15.0, 20.0)));
+ assert!(finite.contains_interval(&T::new(10.0, 20.0))); // itself
+ assert!(finite.contains_interval(&T::empty()));
+
+ // Intervals that are not contained
+ assert!(!finite.contains_interval(&T::new(5.0, 15.0))); // extends below
+ assert!(!finite.contains_interval(&T::new(15.0, 25.0))); // extends above
+ assert!(!finite.contains_interval(&T::new(5.0, 25.0))); // extends both ways
+ assert!(!finite.contains_interval(&T::new(0.0, 5.0))); // completely below
+ assert!(!finite.contains_interval(&T::new(25.0, 30.0))); // completely above
+ assert!(!finite.contains_interval(&T::full())); // full interval is larger
+
// Merging NaN
assert_eq!(finite.merge_value(f64::NAN), finite);
@@ -579,6 +670,19 @@ mod test {
finite.merge_interval(&T::new(25.0, 30.0)),
T::new(10.0, 30.0)
);
+
+ // Expanding by positive distance
+ assert_eq!(finite.expand_by(2.0), T::new(8.0, 22.0));
+ assert_eq!(finite.expand_by(5.0), T::new(5.0, 25.0));
+
+ // Expanding by zero does nothing
+ assert_eq!(finite.expand_by(0.0), finite);
+
+ // Expanding by negative distance does nothing
+ assert_eq!(finite.expand_by(-1.0), finite);
+
+ // Expanding by NaN does nothing
+ assert_eq!(finite.expand_by(f64::NAN), finite);
}
#[test]
@@ -660,6 +764,47 @@ mod test {
assert!(wraparound.intersects_interval(&WraparoundInterval::new(30.0,
25.0)));
}
+ #[test]
+ fn wraparound_interval_actually_wraparound_contains_interval() {
+ // Everything *except* the interval (10, 20)
+ let wraparound = WraparoundInterval::new(20.0, 10.0);
+
+ // Contains itself
+ assert!(wraparound.contains_interval(&wraparound));
+
+ // Empty is contained by everything
+ assert!(wraparound.contains_interval(&WraparoundInterval::empty()));
+
+ // Does not contain the full interval
+ assert!(!wraparound.contains_interval(&WraparoundInterval::full()));
+
+ // Regular interval completely between endpoints is not contained
+ assert!(!wraparound.contains_interval(&WraparoundInterval::new(14.0,
16.0)));
+
+ // Wraparound intervals that exclude more (narrower included regions) are contained
+ assert!(wraparound.contains_interval(&WraparoundInterval::new(22.0, 8.0))); // excludes (8,22) which is larger than (10,20)
+ assert!(!wraparound.contains_interval(&WraparoundInterval::new(18.0, 12.0))); // excludes (12,18) which is smaller than (10,20)
+
+ // Regular intervals don't work the same way due to the split logic
+ // For a regular interval (a, b), split gives (left=(a,b), right=empty)
+ // For wraparound to contain it, we need both parts to be contained
+ // This means (-inf, 10] must contain (a,b) AND [20, inf) must contain empty
+ // The second is always true, but the first requires b <= 10
+ assert!(wraparound.contains_interval(&WraparoundInterval::new(0.0, 5.0))); // completely within left part
+ assert!(wraparound.contains_interval(&WraparoundInterval::new(-5.0, 10.0))); // fits in left part
+ assert!(!wraparound.contains_interval(&WraparoundInterval::new(25.0, 30.0))); // doesn't fit in left part
+ assert!(!wraparound.contains_interval(&WraparoundInterval::new(20.0, 25.0))); // doesn't fit in left part
+
+ // Regular intervals that overlap the excluded zone are not contained
+ assert!(!wraparound.contains_interval(&WraparoundInterval::new(5.0, 15.0))); // overlaps excluded zone
+ assert!(!wraparound.contains_interval(&WraparoundInterval::new(15.0, 25.0))); // overlaps excluded zone
+
+ // Wraparound intervals whose excluded region does not cover (10, 20) are not contained
+ assert!(!wraparound.contains_interval(&WraparoundInterval::new(15.0, 5.0))); // excludes (5,15), so it includes (15,20)
+ assert!(!wraparound.contains_interval(&WraparoundInterval::new(25.0, 15.0)));
+ // excludes (15,25), so it includes (10,15)
+ }
+
#[test]
fn wraparound_interval_actually_wraparound_merge_value() {
// Everything *except* the interval (10, 20)
@@ -834,6 +979,63 @@ mod test {
);
}
+ #[test]
+ fn wraparound_interval_actually_wraparound_expand_by() {
+ // Everything *except* the interval (10, 20)
+ let wraparound = WraparoundInterval::new(20.0, 10.0);
+
+ // Expanding by a small amount shrinks the excluded region
+ // Original excludes (10, 20), expanding by 2 should exclude (12, 18)
+ // So the new interval should be (18, 12) = everything except (12, 18)
+ assert_eq!(
+ wraparound.expand_by(2.0),
+ WraparoundInterval::new(18.0, 12.0)
+ ); // now excludes (12, 18)
+
+ // Expanding by 4 should exclude (14, 16)
+ assert_eq!(
+ wraparound.expand_by(4.0),
+ WraparoundInterval::new(16.0, 14.0)
+ ); // now excludes (14, 16)
+
+ // Expanding by 5.0 should exactly eliminate the excluded region
+ // excluded region (10, 20) shrinks to (15, 15) which is empty
+ assert_eq!(wraparound.expand_by(5.0), WraparoundInterval::full()); // excluded region disappears
+
+ // Any expansion greater than 5.0 should also give full interval
+ assert_eq!(wraparound.expand_by(6.0), WraparoundInterval::full());
+
+ assert_eq!(wraparound.expand_by(100.0), WraparoundInterval::full());
+
+ // Expanding by zero does nothing
+ assert_eq!(wraparound.expand_by(0.0), wraparound);
+
+ // Expanding by negative distance does nothing
+ assert_eq!(wraparound.expand_by(-1.0), wraparound);
+
+ // Expanding by NaN does nothing
+ assert_eq!(wraparound.expand_by(f64::NAN), wraparound);
+
+ // Test a finite (non-wraparound) wraparound interval
+ let non_wraparound = WraparoundInterval::new(10.0, 20.0);
+ assert!(!non_wraparound.is_wraparound());
+ assert_eq!(
+ non_wraparound.expand_by(2.0),
+ WraparoundInterval::new(8.0, 22.0)
+ );
+
+ // Test another wraparound case - excludes (5, 15) with width 10
+ let wraparound2 = WraparoundInterval::new(15.0, 5.0);
+ // Expanding by 3 should shrink excluded region from (5, 15) to (8, 12)
+ assert_eq!(
+ wraparound2.expand_by(3.0),
+ WraparoundInterval::new(12.0, 8.0)
+ );
+
+ // Expanding by 5 should make excluded region disappear: (5+5, 15-5) = (10, 10)
+ assert_eq!(wraparound2.expand_by(5.0), WraparoundInterval::full());
+ }
+
#[test]
fn wraparound_interval_actually_wraparound_convert() {
// Everything *except* the interval (10, 20)
diff --git a/rust/sedona-geoparquet/Cargo.toml
b/rust/sedona-geoparquet/Cargo.toml
index ee5ffe6..6c1ffa1 100644
--- a/rust/sedona-geoparquet/Cargo.toml
+++ b/rust/sedona-geoparquet/Cargo.toml
@@ -33,6 +33,7 @@ default = []
[dev-dependencies]
sedona-testing = { path = "../sedona-testing" }
url = { workspace = true }
+rstest = { workspace = true }
[dependencies]
async-trait = { workspace = true }
diff --git a/rust/sedona-geoparquet/src/format.rs
b/rust/sedona-geoparquet/src/format.rs
index 74c6fcf..93346d2 100644
--- a/rust/sedona-geoparquet/src/format.rs
+++ b/rust/sedona-geoparquet/src/format.rs
@@ -352,7 +352,7 @@ impl GeoParquetFileSource {
if let Some(parquet_source) =
inner.as_any().downcast_ref::<ParquetSource>() {
let mut parquet_source = parquet_source.clone();
- // Extract the precicate from the existing source if it exists so
we can keep a copy of it
+ // Extract the predicate from the existing source if it exists so
we can keep a copy of it
let new_predicate = match (parquet_source.predicate().cloned(),
predicate) {
(None, None) => None,
(None, Some(specified_predicate)) => Some(specified_predicate),
@@ -530,6 +530,7 @@ mod test {
use datafusion_physical_expr::expressions::{BinaryExpr, Column, Literal};
use datafusion_physical_expr::PhysicalExpr;
+ use rstest::rstest;
use sedona_schema::crs::lnglat;
use sedona_schema::datatypes::{Edges, SedonaType, WKB_GEOMETRY};
use sedona_testing::create::create_scalar;
@@ -675,21 +676,24 @@ mod test {
assert_eq!(total_size, 244);
}
+ #[rstest]
#[tokio::test]
- async fn pruning_geoparquet_metadata() {
+ async fn pruning_geoparquet_metadata(#[values("st_intersects",
"st_within")] udf_name: &str) {
let data_dir = geoarrow_data_dir().unwrap();
let ctx = setup_context();
let udf: ScalarUDF = SimpleScalarUDF::new_with_signature(
- "st_intersects",
+ udf_name,
Signature::any(2, Volatility::Immutable),
DataType::Boolean,
Arc::new(|_args| Ok(ScalarValue::Boolean(Some(true)).into())),
)
.into();
- let definitely_non_intersecting_scalar =
- create_scalar(Some("POINT (100 200)"), &WKB_GEOMETRY);
+ let definitely_non_intersecting_scalar = create_scalar(
+ Some("POLYGON ((100 200), (100 300), (200 300), (100 200))"),
+ &WKB_GEOMETRY,
+ );
let storage_field = WKB_GEOMETRY.to_storage_field("", true).unwrap();
let df = ctx
@@ -708,7 +712,10 @@ mod test {
let batches_out = df.collect().await.unwrap();
assert!(batches_out.is_empty());
- let definitely_intersecting_scalar = create_scalar(Some("POINT (30 10)"), &WKB_GEOMETRY);
+ let definitely_intersecting_scalar = create_scalar(
+ Some("POLYGON ((30 10), (30 20), (40 20), (40 10), (30 10))"),
+ &WKB_GEOMETRY,
+ );
let df = ctx
.table(format!("{data_dir}/example/files/*_geo.parquet"))
.await
diff --git a/rust/sedona-spatial-join/src/operand_evaluator.rs
b/rust/sedona-spatial-join/src/operand_evaluator.rs
index d945f71..56dca64 100644
--- a/rust/sedona-spatial-join/src/operand_evaluator.rs
+++ b/rust/sedona-spatial-join/src/operand_evaluator.rs
@@ -18,6 +18,7 @@ use core::fmt;
use std::{mem::transmute, sync::Arc};
use arrow_array::{Array, ArrayRef, Float64Array, RecordBatch};
+use arrow_schema::DataType;
use datafusion_common::{
utils::proxy::VecAllocExt, DataFusionError, JoinSide, Result, ScalarValue,
};
@@ -240,6 +241,8 @@ impl DistanceOperandEvaluator {
// Expand the vec by distance
let distance_columnar_value = self.inner.distance.evaluate(batch)?;
+ // Cast the distance to Float64 so the `ScalarValue::Float64` match below applies;
+ // `None` is passed for `cast_options` (presumably selecting the defaults - verify).
+ let distance_columnar_value =
distance_columnar_value.cast_to(&DataType::Float64, None)?;
match &distance_columnar_value {
ColumnarValue::Scalar(ScalarValue::Float64(Some(distance))) => {
result.rects.iter_mut().for_each(|(_, rect)| {
diff --git a/rust/sedona-spatial-join/src/optimizer.rs
b/rust/sedona-spatial-join/src/optimizer.rs
index b45d14c..a576014 100644
--- a/rust/sedona-spatial-join/src/optimizer.rs
+++ b/rust/sedona-spatial-join/src/optimizer.rs
@@ -40,6 +40,7 @@ use datafusion_physical_plan::joins::utils::ColumnIndex;
use datafusion_physical_plan::joins::{HashJoinExec, NestedLoopJoinExec};
use datafusion_physical_plan::{joins::utils::JoinFilter, ExecutionPlan};
use sedona_common::{option::SedonaOptions, sedona_internal_err};
+use sedona_expr::utils::{parse_distance_predicate, ParsedDistancePredicate};
/// Physical planner extension for spatial joins
///
@@ -594,60 +595,24 @@ fn match_distance_predicate(
expr: &Arc<dyn PhysicalExpr>,
column_indices: &[ColumnIndex],
) -> Option<DistancePredicate> {
- // There are 3 forms of distance predicates:
- // 1. st_dwithin(geom1, geom2, distance)
- // 2. st_distance(geom1, geom2) <= distance or st_distance(geom1, geom2) <
distance
- // 3. distance >= st_distance(geom1, geom2) or distance >
st_distance(geom1, geom2)
- let (arg0, arg1, distance_bound_expr) =
- if let Some(binary_expr) = expr.as_any().downcast_ref::<BinaryExpr>() {
- // handle case 2. and 3.
- let left = binary_expr.left();
- let right = binary_expr.right();
- let (st_distance_expr, distance_bound_expr) = match
*binary_expr.op() {
- Operator::Lt | Operator::LtEq => (left, right),
- Operator::Gt | Operator::GtEq => (right, left),
- _ => return None,
- };
-
- if let Some(st_distance_expr) = st_distance_expr
- .as_any()
- .downcast_ref::<ScalarFunctionExpr>()
- {
- if st_distance_expr.fun().name() != "st_distance" {
- return None;
- }
-
- let args = st_distance_expr.args();
- assert!(args.len() >= 2);
- (&args[0], &args[1], distance_bound_expr)
- } else {
- return None;
- }
- } else if let Some(st_dwithin_expr) =
expr.as_any().downcast_ref::<ScalarFunctionExpr>() {
- // handle case 1.
- if st_dwithin_expr.fun().name() != "st_dwithin" {
- return None;
- }
-
- let args = st_dwithin_expr.args();
- assert!(args.len() >= 3);
- (&args[0], &args[1], &args[2])
- } else {
- return None;
- };
+ let ParsedDistancePredicate {
+ arg0,
+ arg1,
+ arg_distance,
+ } = parse_distance_predicate(expr)?;
// Try to find the expressions that evaluates to the arguments of the
spatial function
- let arg0_refs = collect_column_references(arg0, column_indices);
- let arg1_refs = collect_column_references(arg1, column_indices);
- let arg_dist_refs = collect_column_references(distance_bound_expr,
column_indices);
+ let arg0_refs = collect_column_references(&arg0, column_indices);
+ let arg1_refs = collect_column_references(&arg1, column_indices);
+ let arg_dist_refs = collect_column_references(&arg_distance,
column_indices);
let arg_dist_side = side_of_column_references(&arg_dist_refs)?;
let (arg0_side, arg1_side) = resolve_column_reference_sides(&arg0_refs,
&arg1_refs)?;
- let arg0_reprojected = reproject_column_references_for_side(arg0,
column_indices, arg0_side);
- let arg1_reprojected = reproject_column_references_for_side(arg1,
column_indices, arg1_side);
+ let arg0_reprojected = reproject_column_references_for_side(&arg0,
column_indices, arg0_side);
+ let arg1_reprojected = reproject_column_references_for_side(&arg1,
column_indices, arg1_side);
let arg_dist_reprojected =
- reproject_column_references_for_side(distance_bound_expr,
column_indices, arg_dist_side);
+ reproject_column_references_for_side(&arg_distance, column_indices,
arg_dist_side);
match (arg0_side, arg1_side) {
(JoinSide::Left, JoinSide::Right) => Some(DistancePredicate::new(