This is an automated email from the ASF dual-hosted git repository.
kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new a2d3244c fix(rust/sedona-spatial-join): Fix several bugs related to
KNN join (#508)
a2d3244c is described below
commit a2d3244c7f79d6dd964ad198699bac67c39dc51f
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Jan 14 13:29:29 2026 +0800
fix(rust/sedona-spatial-join): Fix several bugs related to KNN join (#508)
Fixes three critical bugs in the KNN join implementation: mishandling of
joined relations that have different partition counts, incorrect column
projection when the left and right sides have different column counts, and
incorrect K-nearest-neighbor search results for non-point query objects when
tie-breakers are enabled.
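**Changes:**
- Fixed build/probe side determination for KNN joins in the join stream to use
`knn.probe_side.negate()` instead of a hardcoded `JoinSide::Right`
- Made the output partitioning of KNN joins follow the probe side's
partitioning, so it stays correct when the inputs have different partition counts
- Corrected the tie-breaker search envelope to expand the probe geometry's
bounding box by the maximum distance instead of expanding around its centroid
- Removed incorrect column index swapping logic for KNN joins
- Added comprehensive test coverage for KNN join correctness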
---
rust/sedona-spatial-join/src/exec.rs | 203 +++++++++++++++++++--
.../sedona-spatial-join/src/index/spatial_index.rs | 76 +++++---
rust/sedona-spatial-join/src/stream.rs | 6 +-
3 files changed, 240 insertions(+), 45 deletions(-)
diff --git a/rust/sedona-spatial-join/src/exec.rs
b/rust/sedona-spatial-join/src/exec.rs
index c8929bff..e3440cd1 100644
--- a/rust/sedona-spatial-join/src/exec.rs
+++ b/rust/sedona-spatial-join/src/exec.rs
@@ -171,6 +171,7 @@ impl SpatialJoinExec {
let cache = Self::compute_properties(
&left,
&right,
+ &on,
Arc::clone(&join_schema),
*join_type,
projection.as_ref(),
@@ -236,9 +237,11 @@ impl SpatialJoinExec {
///
/// When converted from HashJoin, we preserve HashJoin's equivalence
properties by extracting
/// equality conditions from the filter.
+ #[allow(clippy::too_many_arguments)]
fn compute_properties(
left: &Arc<dyn ExecutionPlan>,
right: &Arc<dyn ExecutionPlan>,
+ on: &SpatialPredicate,
schema: SchemaRef,
join_type: JoinType,
projection: Option<&Vec<usize>>,
@@ -265,7 +268,13 @@ impl SpatialJoinExec {
// Use symmetric partitioning (like HashJoin) when converted from
HashJoin
// Otherwise use asymmetric partitioning (like NestedLoopJoin)
- let mut output_partitioning = if converted_from_hash_join {
+ let mut output_partitioning = if let
SpatialPredicate::KNearestNeighbors(knn) = on {
+ match knn.probe_side {
+ JoinSide::Left => left.output_partitioning().clone(),
+ JoinSide::Right => right.output_partitioning().clone(),
+ _ => asymmetric_join_output_partitioning(left, right,
&join_type),
+ }
+ } else if converted_from_hash_join {
// Replicate HashJoin's symmetric partitioning logic
// HashJoin preserves partitioning from both sides for inner joins
// and from one side for outer joins
@@ -467,7 +476,6 @@ impl ExecutionPlan for SpatialJoinExec {
})?
};
- // Column indices for regular joins - no swapping needed
let column_indices_after_projection = match &self.projection {
Some(projection) => projection
.iter()
@@ -559,8 +567,7 @@ impl SpatialJoinExec {
})?
};
- // Handle column indices for KNN - need to swap if we swapped
execution plans
- let mut column_indices_after_projection = match &self.projection {
+ let column_indices_after_projection = match &self.projection {
Some(projection) => projection
.iter()
.map(|i| self.column_indices[*i].clone())
@@ -568,21 +575,6 @@ impl SpatialJoinExec {
None => self.column_indices.clone(),
};
- // If we swapped execution plans for KNN, we need to swap the column
indices too
- if !actual_probe_plan_is_left {
- for col_idx in &mut column_indices_after_projection {
- match col_idx.side {
- datafusion_common::JoinSide::Left => {
- col_idx.side = datafusion_common::JoinSide::Right
- }
- datafusion_common::JoinSide::Right => {
- col_idx.side = datafusion_common::JoinSide::Left
- }
- datafusion_common::JoinSide::None => {} // No change needed
- }
- }
- }
-
let join_metrics = SpatialJoinProbeMetrics::new(partition,
&self.metrics);
let probe_stream = probe_plan.execute(partition,
Arc::clone(&context))?;
@@ -614,7 +606,7 @@ impl SpatialJoinExec {
#[cfg(test)]
mod tests {
- use arrow_array::RecordBatch;
+ use arrow_array::{Array, RecordBatch};
use arrow_schema::{DataType, Field, Schema};
use datafusion::{
catalog::{MemTable, TableProvider},
@@ -622,8 +614,11 @@ mod tests {
prelude::{SessionConfig, SessionContext},
};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
+ use datafusion_expr::ColumnarValue;
+ use geo::{Distance, Euclidean};
use geo_types::{Coord, Rect};
use rstest::rstest;
+ use sedona_geo::to_geo::GeoTypesExecutor;
use sedona_geometry::types::GeometryTypeId;
use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
use sedona_testing::datagen::RandomPartitionedDataBuilder;
@@ -691,6 +686,40 @@ mod tests {
Ok((left_data, right_data))
}
+    /// Creates randomly generated Point-Point test data for KNN joins.
+    ///
+    /// Both sides are generated over the same bounds (`0..100` on each axis), with the given
+    /// `size_range`, a null rate of 0.1 and 30 rows per batch. They differ only in seed and
+    /// layout: the left side uses seed 1 with 2 partitions of 2 batches, the right side uses
+    /// seed 2 with 4 partitions of 4 batches, so the two inputs have different partition counts.
+    fn create_knn_test_data(
+        size_range: (f64, f64),
+        sedona_type: SedonaType,
+    ) -> Result<(TestPartitions, TestPartitions)> {
+        let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 100.0 });
+
+        // Shared generator configuration. Only the seed and partition layout differ between
+        // the sides; the SedonaType is a parameter so the final use can move it.
+        let random_points = |seed, num_partitions, batches_per_partition, storage_type| {
+            RandomPartitionedDataBuilder::new()
+                .seed(seed)
+                .num_partitions(num_partitions)
+                .batches_per_partition(batches_per_partition)
+                .rows_per_batch(30)
+                .geometry_type(GeometryTypeId::Point)
+                .sedona_type(storage_type)
+                .bounds(bounds)
+                .size_range(size_range)
+                .null_rate(0.1)
+                .build()
+        };
+
+        let left_data = random_points(1, 2, 2, sedona_type.clone())?;
+        let right_data = random_points(2, 4, 4, sedona_type)?;
+
+        Ok((left_data, right_data))
+    }
+
fn setup_context(
options: Option<SpatialJoinOptions>,
batch_size: usize,
@@ -1173,4 +1202,138 @@ mod tests {
})?;
Ok(spatial_join_execs)
}
+
+    /// Collects every `(id, geometry)` pair from the given partitions.
+    ///
+    /// Rows whose `id` or `geometry` value is null are skipped.
+    ///
+    /// # Panics
+    ///
+    /// Panics if a batch has no Int32 `id` column, has no `geometry` column, or its geometry
+    /// field cannot be interpreted as a `SedonaType`, or if geometry extraction fails.
+    fn extract_geoms_and_ids(
+        partitions: &[Vec<RecordBatch>],
+    ) -> Vec<(i32, geo::Geometry<f64>)> {
+        let mut pairs = Vec::new();
+        for batch in partitions.iter().flatten() {
+            let schema = batch.schema();
+            let id_column = batch
+                .column(schema.index_of("id").expect("Id column not found"))
+                .as_any()
+                .downcast_ref::<arrow_array::Int32Array>()
+                .expect("Column 'id' should be Int32");
+
+            let geometry_index = schema
+                .index_of("geometry")
+                .expect("Geometry column not found");
+            let geometry_type = SedonaType::from_storage_field(schema.field(geometry_index))
+                .expect("Failed to get SedonaType from geometry field");
+
+            let arg_types = [geometry_type];
+            let arg_values = [ColumnarValue::Array(Arc::clone(batch.column(geometry_index)))];
+
+            // The executor yields one (possibly null) geometry per row, in row order, so the
+            // id iterator is advanced in lock-step with it.
+            let mut ids = id_column.iter();
+            GeoTypesExecutor::new(&arg_types, &arg_values)
+                .execute_wkb_void(|maybe_geom| {
+                    if let (Some(Some(id)), Some(geom)) = (ids.next(), maybe_geom) {
+                        pairs.push((id, geom));
+                    }
+                    Ok(())
+                })
+                .expect("Failed to extract geoms and ids from RecordBatch");
+        }
+        pairs
+    }
+
+    /// Computes the exact KNN join result by brute force, for use as a test oracle.
+    ///
+    /// For every non-null `(id, geometry)` row on the left, all non-null right rows are ranked
+    /// by Euclidean distance, with ties broken by ascending right id, and the first `k` are
+    /// kept. Each match is reported as `(left_id, right_id, distance)`. The output is sorted by
+    /// `(left_id, right_id)`.
+    ///
+    /// NOTE(review): the id-based tie-break only needs to agree with the join when several
+    /// right rows lie at exactly the same distance. With randomly generated coordinates that
+    /// is presumably rare; confirm if this oracle is reused with gridded or duplicated data.
+    fn compute_knn_ground_truth(
+        left_partitions: &[Vec<RecordBatch>],
+        right_partitions: &[Vec<RecordBatch>],
+        k: usize,
+    ) -> Vec<(i32, i32, f64)> {
+        let left_data = extract_geoms_and_ids(left_partitions);
+        let right_data = extract_geoms_and_ids(right_partitions);
+
+        // Each left row contributes at most min(k, |right|) matches.
+        let mut results = Vec::with_capacity(left_data.len() * k.min(right_data.len()));
+
+        for (l_id, l_geom) in left_data {
+            let mut distances: Vec<(i32, f64)> = right_data
+                .iter()
+                .map(|(r_id, r_geom)| (*r_id, Euclidean.distance(&l_geom, r_geom)))
+                .collect();
+
+            // Sort by distance, then by right id so the chosen neighbors are deterministic.
+            distances.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
+
+            // `take` already stops early when there are fewer than `k` candidates.
+            results.extend(
+                distances
+                    .into_iter()
+                    .take(k)
+                    .map(|(r_id, dist)| (l_id, r_id, dist)),
+            );
+        }
+
+        // Sort results by (left id, right id) to match the query's ORDER BY.
+        results.sort_by(|a, b| a.0.cmp(&b.0).then_with(|| a.1.cmp(&b.1)));
+        results
+    }
+
+    /// Checks KNN join output against a brute-force oracle, with each side acting as the
+    /// KNN query side in turn.
+    #[tokio::test]
+    async fn test_knn_join_correctness() -> Result<()> {
+        // Point data whose left (2 partitions) and right (4 partitions) inputs differ in
+        // partition count.
+        let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+            create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?;
+
+        let options = SpatialJoinOptions::default();
+        let k = 3;
+
+        // Brute-force answer reduced to (query id, neighbor id) pairs.
+        let expected_pairs = |queries: &[Vec<RecordBatch>], candidates: &[Vec<RecordBatch>]| {
+            compute_knn_ground_truth(queries, candidates, k)
+                .into_iter()
+                .map(|(query_id, neighbor_id, _distance)| (query_id, neighbor_id))
+                .collect::<Vec<_>>()
+        };
+
+        // First L is the KNN query side, then R is.
+        let cases = [
+            (
+                format!(
+                    "SELECT L.id, R.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {}, false) ORDER BY L.id, R.id",
+                    k
+                ),
+                expected_pairs(&left_partitions, &right_partitions),
+            ),
+            (
+                format!(
+                    "SELECT R.id, L.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(R.geometry, L.geometry, {}, false) ORDER BY R.id, L.id",
+                    k
+                ),
+                expected_pairs(&right_partitions, &left_partitions),
+            ),
+        ];
+
+        for (sql, expected_results) in &cases {
+            let batches = run_spatial_join_query(
+                &left_schema,
+                &right_schema,
+                left_partitions.clone(),
+                right_partitions.clone(),
+                Some(options.clone()),
+                10,
+                sql,
+            )
+            .await?;
+
+            // Only the two id columns are compared; the distance column is not checked.
+            let combined_batch =
+                arrow::compute::concat_batches(&batches.schema(), &[batches])?;
+            let query_ids = combined_batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<arrow_array::Int32Array>()
+                .unwrap();
+            let neighbor_ids = combined_batch
+                .column(1)
+                .as_any()
+                .downcast_ref::<arrow_array::Int32Array>()
+                .unwrap();
+
+            let mut actual_results: Vec<(i32, i32)> = (0..combined_batch.num_rows())
+                .map(|row| (query_ids.value(row), neighbor_ids.value(row)))
+                .collect();
+            // Tuples order lexicographically, i.e. by (query id, neighbor id).
+            actual_results.sort_unstable();
+
+            assert_eq!(actual_results, *expected_results);
+        }
+
+        Ok(())
+    }
}
diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs
b/rust/sedona-spatial-join/src/index/spatial_index.rs
index 0733571b..64f2cce7 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index.rs
@@ -24,14 +24,18 @@ use arrow_array::RecordBatch;
use arrow_schema::SchemaRef;
use datafusion_common::Result;
use datafusion_execution::memory_pool::{MemoryPool, MemoryReservation};
-use geo_index::rtree::distance::{DistanceMetric, GeometryAccessor};
+use float_next_after::NextAfter;
+use geo::BoundingRect;
+use geo_index::rtree::{
+ distance::{DistanceMetric, GeometryAccessor},
+ util::f64_box_to_f32,
+};
use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex};
use geo_index::IndexableNum;
-use geo_types::{Point, Rect};
+use geo_types::Rect;
use parking_lot::Mutex;
use sedona_expr::statistics::GeoStatistics;
use sedona_geo::to_geo::item_to_geometry;
-use sedona_geo_generic_alg::algorithm::Centroid;
use wkb::reader::Wkb;
use crate::{
@@ -318,25 +322,28 @@ impl SpatialIndex {
// For tie-breakers, create spatial envelope around probe
centroid and use rtree.search()
- let probe_centroid =
probe_geom.centroid().unwrap_or(Point::new(0.0, 0.0));
- let probe_x = probe_centroid.x() as f32;
- let probe_y = probe_centroid.y() as f32;
- let max_distance_f32 = match f32::from_f64(max_distance) {
- Some(val) => val,
- None => {
- // If conversion fails, return empty results for this
probe
- return Ok(QueryResultMetrics {
- count: 0,
- candidate_count: 0,
- });
- }
+ // Create envelope bounds by expanding the probe bounding box
by max_distance
+ let Some(rect) = probe_geom.bounding_rect() else {
+ // If bounding rectangle cannot be computed, return empty
results
+ return Ok(QueryResultMetrics {
+ count: 0,
+ candidate_count: 0,
+ });
};
- // Create envelope bounds around probe centroid
- let min_x = probe_x - max_distance_f32;
- let min_y = probe_y - max_distance_f32;
- let max_x = probe_x + max_distance_f32;
- let max_y = probe_y + max_distance_f32;
+ let min = rect.min();
+ let max = rect.max();
+ let (min_x, min_y, max_x, max_y) = f64_box_to_f32(min.x,
min.y, max.x, max.y);
+ let mut distance_f32 = max_distance as f32;
+ if (distance_f32 as f64) < max_distance {
+ distance_f32 = distance_f32.next_after(f32::INFINITY);
+ }
+ let (min_x, min_y, max_x, max_y) = (
+ min_x - distance_f32,
+ min_y - distance_f32,
+ max_x + distance_f32,
+ max_y + distance_f32,
+ );
// Use rtree.search() with envelope bounds (like the old code)
let expanded_results = self.rtree.search(min_x, min_y, max_x,
max_y);
@@ -1407,8 +1414,33 @@ mod tests {
)
.unwrap();
- // Should return more than 2 results because of ties (all 4 points at
distance sqrt(2))
- assert!(result_with_ties.count >= 2);
+ // Should return 4 results because of ties (all 4 points at distance
sqrt(2))
+ assert!(result_with_ties.count == 4);
+
+ // Query using a box centered at the origin
+ let query_geom = create_array(
+ &[Some(
+ "POLYGON ((-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5
-0.5))",
+ )],
+ &WKB_GEOMETRY,
+ );
+ let query_array = EvaluatedGeometryArray::try_new(query_geom,
&WKB_GEOMETRY).unwrap();
+ let query_wkb = &query_array.wkbs()[0].as_ref().unwrap();
+
+ // This query should return 4 points
+ let mut build_positions_with_ties = Vec::new();
+ let result_with_ties = index
+ .query_knn(
+ query_wkb,
+ 2, // k=2
+ false, // use_spheroid
+ true, // include_tie_breakers=true
+ &mut build_positions_with_ties,
+ )
+ .unwrap();
+
+ // Should return 4 results because of ties (all 4 points at distance
sqrt(2))
+ assert!(result_with_ties.count == 4);
}
#[test]
diff --git a/rust/sedona-spatial-join/src/stream.rs
b/rust/sedona-spatial-join/src/stream.rs
index cb8a4e3d..7fa42231 100644
--- a/rust/sedona-spatial-join/src/stream.rs
+++ b/rust/sedona-spatial-join/src/stream.rs
@@ -244,10 +244,10 @@ impl SpatialJoinStream {
// Extract the necessary data first to avoid borrowing conflicts
let (batch_opt, is_complete) = match &mut self.state {
SpatialJoinStreamState::ProcessProbeBatch(iterator) => {
- // For KNN joins, we swapped build/probe sides, so build_side
should be Right
- // For regular joins, build_side is Left
+ // For KNN joins, we may have swapped build/probe sides, so
build_side might be Right;
+ // For regular joins, build_side is always Left.
let build_side = match &self.spatial_predicate {
- SpatialPredicate::KNearestNeighbors(_) => JoinSide::Right,
+ SpatialPredicate::KNearestNeighbors(knn) =>
knn.probe_side.negate(),
_ => JoinSide::Left,
};