This is an automated email from the ASF dual-hosted git repository.

kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new a2d3244c fix(rust/sedona-spatial-join): Fix several bugs related to 
KNN join (#508)
a2d3244c is described below

commit a2d3244c7f79d6dd964ad198699bac67c39dc51f
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Wed Jan 14 13:29:29 2026 +0800

    fix(rust/sedona-spatial-join): Fix several bugs related to KNN join (#508)
    
    Fixes three critical bugs in the KNN join implementation: handling 
different partition counts between joined relations, correcting column 
projection when left/right sides have different column counts, and fixing 
incorrect K-nearest neighbor search results for non-point query objects with 
tie-breakers enabled.
    
    **Changes:**
    - Fixed build/probe side determination for KNN joins to use 
`knn.probe_side.negate()` instead of hardcoded `JoinSide::Right`
    - Corrected tie-breaker envelope calculation to use bounding box expansion 
instead of centroid-based approach
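    - Removed incorrect column index swapping logic for KNN joins
    - Made KNN joins report the probe side's output partitioning, since each
      output partition executes the corresponding partition of the probe input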
    - Added comprehensive test coverage for KNN join correctness
---
 rust/sedona-spatial-join/src/exec.rs               | 203 +++++++++++++++++++--
 .../sedona-spatial-join/src/index/spatial_index.rs |  76 +++++---
 rust/sedona-spatial-join/src/stream.rs             |   6 +-
 3 files changed, 240 insertions(+), 45 deletions(-)

diff --git a/rust/sedona-spatial-join/src/exec.rs 
b/rust/sedona-spatial-join/src/exec.rs
index c8929bff..e3440cd1 100644
--- a/rust/sedona-spatial-join/src/exec.rs
+++ b/rust/sedona-spatial-join/src/exec.rs
@@ -171,6 +171,7 @@ impl SpatialJoinExec {
         let cache = Self::compute_properties(
             &left,
             &right,
+            &on,
             Arc::clone(&join_schema),
             *join_type,
             projection.as_ref(),
@@ -236,9 +237,11 @@ impl SpatialJoinExec {
     ///
     /// When converted from HashJoin, we preserve HashJoin's equivalence 
properties by extracting
     /// equality conditions from the filter.
+    #[allow(clippy::too_many_arguments)]
     fn compute_properties(
         left: &Arc<dyn ExecutionPlan>,
         right: &Arc<dyn ExecutionPlan>,
+        on: &SpatialPredicate,
         schema: SchemaRef,
         join_type: JoinType,
         projection: Option<&Vec<usize>>,
@@ -265,7 +268,13 @@ impl SpatialJoinExec {
 
         // Use symmetric partitioning (like HashJoin) when converted from 
HashJoin
         // Otherwise use asymmetric partitioning (like NestedLoopJoin)
-        let mut output_partitioning = if converted_from_hash_join {
+        let mut output_partitioning = if let 
SpatialPredicate::KNearestNeighbors(knn) = on {
+            match knn.probe_side {
+                JoinSide::Left => left.output_partitioning().clone(),
+                JoinSide::Right => right.output_partitioning().clone(),
+                _ => asymmetric_join_output_partitioning(left, right, 
&join_type),
+            }
+        } else if converted_from_hash_join {
             // Replicate HashJoin's symmetric partitioning logic
             // HashJoin preserves partitioning from both sides for inner joins
             // and from one side for outer joins
@@ -467,7 +476,6 @@ impl ExecutionPlan for SpatialJoinExec {
                         })?
                 };
 
-                // Column indices for regular joins - no swapping needed
                 let column_indices_after_projection = match &self.projection {
                     Some(projection) => projection
                         .iter()
@@ -559,8 +567,7 @@ impl SpatialJoinExec {
                 })?
         };
 
-        // Handle column indices for KNN - need to swap if we swapped 
execution plans
-        let mut column_indices_after_projection = match &self.projection {
+        let column_indices_after_projection = match &self.projection {
             Some(projection) => projection
                 .iter()
                 .map(|i| self.column_indices[*i].clone())
@@ -568,21 +575,6 @@ impl SpatialJoinExec {
             None => self.column_indices.clone(),
         };
 
-        // If we swapped execution plans for KNN, we need to swap the column 
indices too
-        if !actual_probe_plan_is_left {
-            for col_idx in &mut column_indices_after_projection {
-                match col_idx.side {
-                    datafusion_common::JoinSide::Left => {
-                        col_idx.side = datafusion_common::JoinSide::Right
-                    }
-                    datafusion_common::JoinSide::Right => {
-                        col_idx.side = datafusion_common::JoinSide::Left
-                    }
-                    datafusion_common::JoinSide::None => {} // No change needed
-                }
-            }
-        }
-
         let join_metrics = SpatialJoinProbeMetrics::new(partition, 
&self.metrics);
         let probe_stream = probe_plan.execute(partition, 
Arc::clone(&context))?;
 
@@ -614,7 +606,7 @@ impl SpatialJoinExec {
 
 #[cfg(test)]
 mod tests {
-    use arrow_array::RecordBatch;
+    use arrow_array::{Array, RecordBatch};
     use arrow_schema::{DataType, Field, Schema};
     use datafusion::{
         catalog::{MemTable, TableProvider},
@@ -622,8 +614,11 @@ mod tests {
         prelude::{SessionConfig, SessionContext},
     };
     use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
+    use datafusion_expr::ColumnarValue;
+    use geo::{Distance, Euclidean};
     use geo_types::{Coord, Rect};
     use rstest::rstest;
+    use sedona_geo::to_geo::GeoTypesExecutor;
     use sedona_geometry::types::GeometryTypeId;
     use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
     use sedona_testing::datagen::RandomPartitionedDataBuilder;
@@ -691,6 +686,40 @@ mod tests {
         Ok((left_data, right_data))
     }
 
+    /// Creates randomly generated point datasets for the KNN join tests.
+    ///
+    /// Both sides share the same bounds, size range, geometry type, and 10% null rate, but use
+    /// different seeds and partition layouts: the left side has 2 partitions x 2 batches and
+    /// the right side 4 partitions x 4 batches (30 rows per batch), so the joined relations
+    /// deliberately have different partition counts.
+    fn create_knn_test_data(
+        size_range: (f64, f64),
+        sedona_type: SedonaType,
+    ) -> Result<(TestPartitions, TestPartitions)> {
+        let bounds = Rect::new(Coord { x: 0.0, y: 0.0 }, Coord { x: 100.0, y: 100.0 });
+
+        // Tuple fields are evaluated left to right, so the left dataset is still built (and
+        // any error from it returned) before the right one.
+        Ok((
+            RandomPartitionedDataBuilder::new()
+                .seed(1)
+                .num_partitions(2)
+                .batches_per_partition(2)
+                .rows_per_batch(30)
+                .geometry_type(GeometryTypeId::Point)
+                .sedona_type(sedona_type.clone())
+                .bounds(bounds)
+                .size_range(size_range)
+                .null_rate(0.1)
+                .build()?,
+            RandomPartitionedDataBuilder::new()
+                .seed(2)
+                .num_partitions(4)
+                .batches_per_partition(4)
+                .rows_per_batch(30)
+                .geometry_type(GeometryTypeId::Point)
+                .sedona_type(sedona_type)
+                .bounds(bounds)
+                .size_range(size_range)
+                .null_rate(0.1)
+                .build()?,
+        ))
+    }
+
     fn setup_context(
         options: Option<SpatialJoinOptions>,
         batch_size: usize,
@@ -1173,4 +1202,138 @@ mod tests {
         })?;
         Ok(spatial_join_execs)
     }
+
+    /// Collects every `(id, geometry)` pair found in the given partitions.
+    ///
+    /// Each batch must contain an Int32 `id` column and a `geometry` column. Rows whose id or
+    /// geometry is null are skipped. Panics (via `expect`) if a required column is missing,
+    /// has an unexpected type, or the geometries cannot be decoded.
+    fn extract_geoms_and_ids(partitions: &[Vec<RecordBatch>]) -> Vec<(i32, geo::Geometry<f64>)> {
+        let mut pairs = Vec::new();
+        for batch in partitions.iter().flatten() {
+            let schema = batch.schema();
+
+            let id_col = schema.index_of("id").expect("Id column not found");
+            let id_array = batch
+                .column(id_col)
+                .as_any()
+                .downcast_ref::<arrow_array::Int32Array>()
+                .expect("Column 'id' should be Int32");
+
+            let geom_col = schema
+                .index_of("geometry")
+                .expect("Geometry column not found");
+            let sedona_type = SedonaType::from_storage_field(schema.field(geom_col))
+                .expect("Failed to get SedonaType from geometry field");
+
+            let arg_types = [sedona_type];
+            let arg_values = [ColumnarValue::Array(Arc::clone(batch.column(geom_col)))];
+            let executor = GeoTypesExecutor::new(&arg_types, &arg_values);
+
+            // The id iterator is advanced once per callback. This assumes the executor invokes
+            // the callback once per row, in row order (TODO confirm), so ids and geometries
+            // stay aligned.
+            let mut ids = id_array.iter();
+            executor
+                .execute_wkb_void(|maybe_geom| {
+                    if let (Some(Some(id)), Some(geom)) = (ids.next(), maybe_geom) {
+                        pairs.push((id, geom));
+                    }
+                    Ok(())
+                })
+                .expect("Failed to extract geoms and ids from RecordBatch");
+        }
+        pairs
+    }
+
+    /// Brute-force reference KNN used to validate the join.
+    ///
+    /// For every geometry on the query side (`left_partitions`), finds the `k` nearest
+    /// geometries on the other side (`right_partitions`) by planar Euclidean distance. Rows with
+    /// a null id or geometry are ignored. Exact distance ties are broken by the smaller
+    /// neighbor id. NOTE(review): the join under test may break exact ties differently; random
+    /// float coordinates make such ties unlikely.
+    ///
+    /// Returns `(query_id, neighbor_id, distance)` triples sorted by `(query_id, neighbor_id)`.
+    fn compute_knn_ground_truth(
+        left_partitions: &[Vec<RecordBatch>],
+        right_partitions: &[Vec<RecordBatch>],
+        k: usize,
+    ) -> Vec<(i32, i32, f64)> {
+        let left_data = extract_geoms_and_ids(left_partitions);
+        let right_data = extract_geoms_and_ids(right_partitions);
+
+        // Each query row contributes min(k, |right|) neighbors, so this capacity is exact.
+        let mut results = Vec::with_capacity(left_data.len() * k.min(right_data.len()));
+
+        for (l_id, l_geom) in left_data {
+            let mut distances: Vec<(i32, f64)> = right_data
+                .iter()
+                .map(|(r_id, r_geom)| (*r_id, Euclidean.distance(&l_geom, r_geom)))
+                .collect();
+
+            // Sort by distance, then by id, so exact ties resolve deterministically.
+            distances.sort_by(|a, b| a.1.total_cmp(&b.1).then_with(|| a.0.cmp(&b.0)));
+
+            // `take` already stops early when fewer than `k` candidates exist.
+            results.extend(
+                distances
+                    .into_iter()
+                    .take(k)
+                    .map(|(r_id, dist)| (l_id, r_id, dist)),
+            );
+        }
+
+        // Stable sort on (query id, neighbor id).
+        results.sort_by_key(|&(l_id, r_id, _)| (l_id, r_id));
+        results
+    }
+
+    /// End-to-end KNN join check on randomly generated points.
+    ///
+    /// Runs `ST_KNN` once with `L` as the query side and once with `R` as the query side. Each
+    /// run's `(query id, neighbor id)` pairs are compared with the brute-force result of
+    /// `compute_knn_ground_truth`. The distance column is selected but not checked.
+    #[tokio::test]
+    async fn test_knn_join_correctness() -> Result<()> {
+        // Random point data whose two sides have different partition counts.
+        let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+            create_knn_test_data((0.1, 10.0), WKB_GEOMETRY)?;
+
+        let options = SpatialJoinOptions::default();
+        let k = 3;
+
+        // Only id pairs are compared, so the reference distances are dropped.
+        let id_pairs = |triples: Vec<(i32, i32, f64)>| -> Vec<(i32, i32)> {
+            triples
+                .into_iter()
+                .map(|(query_id, neighbor_id, _)| (query_id, neighbor_id))
+                .collect()
+        };
+
+        // L is the query side: each left row is matched with its k nearest right rows.
+        let left_query_sql = format!(
+            "SELECT L.id, R.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(L.geometry, R.geometry, {k}, false) ORDER BY L.id, R.id"
+        );
+        let left_query_expected =
+            id_pairs(compute_knn_ground_truth(&left_partitions, &right_partitions, k));
+
+        // R is the query side: each right row is matched with its k nearest left rows.
+        let right_query_sql = format!(
+            "SELECT R.id, L.id, ST_Distance(L.geometry, R.geometry) FROM L JOIN R ON ST_KNN(R.geometry, L.geometry, {k}, false) ORDER BY R.id, L.id"
+        );
+        let right_query_expected =
+            id_pairs(compute_knn_ground_truth(&right_partitions, &left_partitions, k));
+
+        for (sql, expected) in [
+            (&left_query_sql, &left_query_expected),
+            (&right_query_sql, &right_query_expected),
+        ] {
+            let batches = run_spatial_join_query(
+                &left_schema,
+                &right_schema,
+                left_partitions.clone(),
+                right_partitions.clone(),
+                Some(options.clone()),
+                10,
+                sql,
+            )
+            .await?;
+
+            let combined_batch = arrow::compute::concat_batches(&batches.schema(), &[batches])?;
+            let query_ids = combined_batch
+                .column(0)
+                .as_any()
+                .downcast_ref::<arrow_array::Int32Array>()
+                .unwrap();
+            let neighbor_ids = combined_batch
+                .column(1)
+                .as_any()
+                .downcast_ref::<arrow_array::Int32Array>()
+                .unwrap();
+
+            // The SQL already orders rows; sort again so the comparison does not rely on it.
+            let mut actual: Vec<(i32, i32)> = (0..combined_batch.num_rows())
+                .map(|row| (query_ids.value(row), neighbor_ids.value(row)))
+                .collect();
+            actual.sort();
+
+            assert_eq!(actual, *expected);
+        }
+
+        Ok(())
+    }
 }
diff --git a/rust/sedona-spatial-join/src/index/spatial_index.rs 
b/rust/sedona-spatial-join/src/index/spatial_index.rs
index 0733571b..64f2cce7 100644
--- a/rust/sedona-spatial-join/src/index/spatial_index.rs
+++ b/rust/sedona-spatial-join/src/index/spatial_index.rs
@@ -24,14 +24,18 @@ use arrow_array::RecordBatch;
 use arrow_schema::SchemaRef;
 use datafusion_common::Result;
 use datafusion_execution::memory_pool::{MemoryPool, MemoryReservation};
-use geo_index::rtree::distance::{DistanceMetric, GeometryAccessor};
+use float_next_after::NextAfter;
+use geo::BoundingRect;
+use geo_index::rtree::{
+    distance::{DistanceMetric, GeometryAccessor},
+    util::f64_box_to_f32,
+};
 use geo_index::rtree::{sort::HilbertSort, RTree, RTreeBuilder, RTreeIndex};
 use geo_index::IndexableNum;
-use geo_types::{Point, Rect};
+use geo_types::Rect;
 use parking_lot::Mutex;
 use sedona_expr::statistics::GeoStatistics;
 use sedona_geo::to_geo::item_to_geometry;
-use sedona_geo_generic_alg::algorithm::Centroid;
 use wkb::reader::Wkb;
 
 use crate::{
@@ -318,25 +322,28 @@ impl SpatialIndex {
 
                 // For tie-breakers, create spatial envelope around probe 
centroid and use rtree.search()
 
-                let probe_centroid = 
probe_geom.centroid().unwrap_or(Point::new(0.0, 0.0));
-                let probe_x = probe_centroid.x() as f32;
-                let probe_y = probe_centroid.y() as f32;
-                let max_distance_f32 = match f32::from_f64(max_distance) {
-                    Some(val) => val,
-                    None => {
-                        // If conversion fails, return empty results for this 
probe
-                        return Ok(QueryResultMetrics {
-                            count: 0,
-                            candidate_count: 0,
-                        });
-                    }
+                // Create envelope bounds by expanding the probe bounding box 
by max_distance
+                let Some(rect) = probe_geom.bounding_rect() else {
+                    // If bounding rectangle cannot be computed, return empty 
results
+                    return Ok(QueryResultMetrics {
+                        count: 0,
+                        candidate_count: 0,
+                    });
                 };
 
-                // Create envelope bounds around probe centroid
-                let min_x = probe_x - max_distance_f32;
-                let min_y = probe_y - max_distance_f32;
-                let max_x = probe_x + max_distance_f32;
-                let max_y = probe_y + max_distance_f32;
+                let min = rect.min();
+                let max = rect.max();
+                let (min_x, min_y, max_x, max_y) = f64_box_to_f32(min.x, 
min.y, max.x, max.y);
+                let mut distance_f32 = max_distance as f32;
+                if (distance_f32 as f64) < max_distance {
+                    distance_f32 = distance_f32.next_after(f32::INFINITY);
+                }
+                let (min_x, min_y, max_x, max_y) = (
+                    min_x - distance_f32,
+                    min_y - distance_f32,
+                    max_x + distance_f32,
+                    max_y + distance_f32,
+                );
 
                 // Use rtree.search() with envelope bounds (like the old code)
                 let expanded_results = self.rtree.search(min_x, min_y, max_x, 
max_y);
@@ -1407,8 +1414,33 @@ mod tests {
             )
             .unwrap();
 
-        // Should return more than 2 results because of ties (all 4 points at 
distance sqrt(2))
-        assert!(result_with_ties.count >= 2);
+        // Should return 4 results because of ties (all 4 points at distance 
sqrt(2))
+        assert!(result_with_ties.count == 4);
+
+        // Query using a box centered at the origin
+        let query_geom = create_array(
+            &[Some(
+                "POLYGON ((-0.5 -0.5, -0.5 0.5, 0.5 0.5, 0.5 -0.5, -0.5 
-0.5))",
+            )],
+            &WKB_GEOMETRY,
+        );
+        let query_array = EvaluatedGeometryArray::try_new(query_geom, 
&WKB_GEOMETRY).unwrap();
+        let query_wkb = &query_array.wkbs()[0].as_ref().unwrap();
+
+        // This query should return 4 points
+        let mut build_positions_with_ties = Vec::new();
+        let result_with_ties = index
+            .query_knn(
+                query_wkb,
+                2,     // k=2
+                false, // use_spheroid
+                true,  // include_tie_breakers=true
+                &mut build_positions_with_ties,
+            )
+            .unwrap();
+
+        // Should return 4 results because of ties: the 4 points are also equally far from the
+        // box (at less than sqrt(2), since the box itself surrounds the origin)
+        assert!(result_with_ties.count == 4);
     }
 
     #[test]
diff --git a/rust/sedona-spatial-join/src/stream.rs 
b/rust/sedona-spatial-join/src/stream.rs
index cb8a4e3d..7fa42231 100644
--- a/rust/sedona-spatial-join/src/stream.rs
+++ b/rust/sedona-spatial-join/src/stream.rs
@@ -244,10 +244,10 @@ impl SpatialJoinStream {
         // Extract the necessary data first to avoid borrowing conflicts
         let (batch_opt, is_complete) = match &mut self.state {
             SpatialJoinStreamState::ProcessProbeBatch(iterator) => {
-                // For KNN joins, we swapped build/probe sides, so build_side 
should be Right
-                // For regular joins, build_side is Left
+                // For KNN joins the build side is the opposite of the KNN probe side, so it may
+                // be either Left or Right; for regular joins, build_side is always Left.
                 let build_side = match &self.spatial_predicate {
-                    SpatialPredicate::KNearestNeighbors(_) => JoinSide::Right,
+                    SpatialPredicate::KNearestNeighbors(knn) => 
knn.probe_side.negate(),
                     _ => JoinSide::Left,
                 };
 

Reply via email to