Kontinuation commented on code in PR #593:
URL: https://github.com/apache/sedona-db/pull/593#discussion_r2788908508


##########
rust/sedona-spatial-join/tests/spatial_join_integration.rs:
##########
@@ -0,0 +1,1291 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use std::sync::Arc;
+
+use arrow_array::{Array, RecordBatch};
+use arrow_schema::{DataType, Field, Schema, SchemaRef};
+use datafusion::{
+    catalog::{MemTable, TableProvider},
+    execution::SessionStateBuilder,
+    prelude::{SessionConfig, SessionContext},
+};
+use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
+use datafusion_common::Result;
+use datafusion_expr::{ColumnarValue, JoinType};
+use datafusion_physical_plan::joins::NestedLoopJoinExec;
+use datafusion_physical_plan::ExecutionPlan;
+use geo::{Distance, Euclidean};
+use geo_types::{Coord, Rect};
+use rstest::rstest;
+use sedona_common::SedonaOptions;
+use sedona_geo::to_geo::GeoTypesExecutor;
+use sedona_geometry::types::GeometryTypeId;
+use sedona_schema::datatypes::{SedonaType, WKB_GEOGRAPHY, WKB_GEOMETRY};
+use sedona_spatial_join::{register_spatial_join_optimizer, SpatialJoinExec};
+use sedona_testing::datagen::RandomPartitionedDataBuilder;
+use tokio::sync::OnceCell;
+
+use sedona_common::{
+    option::{add_sedona_option_extension, ExecutionMode, SpatialJoinOptions},
+    NumSpatialPartitionsConfig, SpatialJoinDebugOptions, SpatialLibrary,
+};
+
+/// A table schema together with its partitions; each inner `Vec` holds the
+/// record batches of one partition (the shape passed to `MemTable::try_new`).
+type TestPartitions = (SchemaRef, Vec<Vec<RecordBatch>>);
+
+/// Builds the default fixture: polygons on the left and points on the right,
+/// stored as `WKB_GEOMETRY`, with a generator size range of `(1.0, 10.0)`.
+fn create_default_test_data() -> Result<(TestPartitions, TestPartitions)> {
+    let default_size_range = (1.0, 10.0);
+    create_test_data_with_size_range(default_size_range, WKB_GEOMETRY)
+}
+
+/// Creates left (polygon) and right (point) test partitions.
+///
+/// `size_range` is forwarded to the random data generator for both sides and
+/// `sedona_type` selects the storage type of the geometry column. Both sides
+/// share the `(0, 0)`-`(100, 100)` bounds, a null rate of 0.1 and fixed seeds.
+fn create_test_data_with_size_range(
+    size_range: (f64, f64),
+    sedona_type: SedonaType,
+) -> Result<(TestPartitions, TestPartitions)> {
+    let lower_left = Coord { x: 0.0, y: 0.0 };
+    let upper_right = Coord { x: 100.0, y: 100.0 };
+    let bounds = Rect::new(lower_left, upper_right);
+
+    // Left side: 2 partitions x 2 batches x 30 rows of polygons.
+    let polygons = RandomPartitionedDataBuilder::new()
+        .seed(11584)
+        .num_partitions(2)
+        .batches_per_partition(2)
+        .rows_per_batch(30)
+        .geometry_type(GeometryTypeId::Polygon)
+        .sedona_type(sedona_type.clone())
+        .bounds(bounds)
+        .size_range(size_range)
+        .null_rate(0.1)
+        .build()?;
+
+    // Right side: 4 partitions x 4 batches x 30 rows of points.
+    let points = RandomPartitionedDataBuilder::new()
+        .seed(54843)
+        .num_partitions(4)
+        .batches_per_partition(4)
+        .rows_per_batch(30)
+        .geometry_type(GeometryTypeId::Point)
+        .sedona_type(sedona_type)
+        .bounds(bounds)
+        .size_range(size_range)
+        .null_rate(0.1)
+        .build()?;
+
+    Ok((polygons, points))
+}
+
+/// Returns the default fixture with one extra empty partition prepended and
+/// one appended on each side, to exercise partitions that contain no batches.
+fn create_test_data_with_empty_partitions() -> Result<(TestPartitions, TestPartitions)> {
+    let (mut left_data, mut right_data) = create_default_test_data()?;
+
+    for (_, partitions) in [&mut left_data, &mut right_data] {
+        partitions.insert(0, Vec::new());
+        partitions.push(Vec::new());
+    }
+
+    Ok((left_data, right_data))
+}
+
+/// Creates point-only test partitions for KNN joins (points on both sides).
+///
+/// Layout mirrors the range-join fixture (left: 2 partitions x 2 batches,
+/// right: 4 partitions x 4 batches, 30 rows per batch, null rate 0.1) but uses
+/// different fixed seeds. `size_range` and `sedona_type` are forwarded to the
+/// generator for both sides.
+fn create_knn_test_data(
+    size_range: (f64, f64),
+    sedona_type: SedonaType,
+) -> Result<(TestPartitions, TestPartitions)> {
+    let lower_left = Coord { x: 0.0, y: 0.0 };
+    let upper_right = Coord { x: 100.0, y: 100.0 };
+    let bounds = Rect::new(lower_left, upper_right);
+
+    let left_points = RandomPartitionedDataBuilder::new()
+        .seed(1)
+        .num_partitions(2)
+        .batches_per_partition(2)
+        .rows_per_batch(30)
+        .geometry_type(GeometryTypeId::Point)
+        .sedona_type(sedona_type.clone())
+        .bounds(bounds)
+        .size_range(size_range)
+        .null_rate(0.1)
+        .build()?;
+
+    let right_points = RandomPartitionedDataBuilder::new()
+        .seed(2)
+        .num_partitions(4)
+        .batches_per_partition(4)
+        .rows_per_batch(30)
+        .geometry_type(GeometryTypeId::Point)
+        .sedona_type(sedona_type)
+        .bounds(bounds)
+        .size_range(size_range)
+        .null_rate(0.1)
+        .build()?;
+
+    Ok((left_points, right_points))
+}
+
+/// Builds a `SessionContext` for the tests.
+///
+/// With `Some(options)` the spatial join optimizer is registered and `options`
+/// is stored as the session's `SedonaOptions::spatial_join`; with `None` the
+/// optimizer is not registered. In both cases the default Sedona function set
+/// and the GEOS scalar kernels are registered as UDFs on the context.
+fn setup_context(options: Option<SpatialJoinOptions>, batch_size: usize) -> Result<SessionContext> {
+    let mut session_config = add_sedona_option_extension(
+        SessionConfig::from_env()?
+            .with_information_schema(true)
+            .with_batch_size(batch_size),
+    );
+    let mut state_builder = SessionStateBuilder::new();
+    if let Some(spatial_join) = options {
+        state_builder = register_spatial_join_optimizer(state_builder);
+        let sedona_options = session_config
+            .options_mut()
+            .extensions
+            .get_mut::<SedonaOptions>()
+            .unwrap();
+        sedona_options.spatial_join = spatial_join;
+    }
+    let ctx = SessionContext::new_with_state(state_builder.with_config(session_config).build());
+
+    // Register the default scalar UDFs, then add each GEOS kernel to the
+    // function set and register the resulting UDF as well.
+    let mut function_set = sedona_functions::register::default_function_set();
+    let geos_kernels = sedona_geos::register::scalar_kernels();
+    for udf in function_set.scalar_udfs() {
+        ctx.register_udf(udf.clone().into());
+    }
+    for (name, kernel) in geos_kernels {
+        let udf = function_set.add_scalar_udf_impl(name, kernel)?;
+        ctx.register_udf(udf.clone().into());
+    }
+
+    Ok(ctx)
+}
+
+/// Joining tables whose partitions contain no batches must yield no rows.
+#[tokio::test]
+async fn test_empty_data() -> Result<()> {
+    let schema = Arc::new(Schema::new(vec![
+        Field::new("id", DataType::Int32, false),
+        Field::new("dist", DataType::Float64, false),
+        WKB_GEOMETRY.to_storage_field("geometry", true).unwrap(),
+    ]));
+
+    // One layout with a single empty partition, one with two empty partitions.
+    let partition_layouts: Vec<Vec<Vec<RecordBatch>>> = vec![vec![vec![]], vec![vec![], vec![]]];
+
+    let ctx = setup_context(Some(SpatialJoinOptions::default()), 10)?;
+    for partitions in partition_layouts {
+        // Both join inputs use the same (empty) layout.
+        let left: Arc<dyn TableProvider> =
+            Arc::new(MemTable::try_new(Arc::clone(&schema), partitions.clone())?);
+        let right: Arc<dyn TableProvider> =
+            Arc::new(MemTable::try_new(Arc::clone(&schema), partitions)?);
+
+        // Replace whatever the previous iteration registered.
+        ctx.deregister_table("L")?;
+        ctx.deregister_table("R")?;
+        ctx.register_table("L", left)?;
+        ctx.register_table("R", right)?;
+
+        let sql = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
+        let result_batches = ctx.sql(sql).await?.collect().await?;
+        for result_batch in &result_batches {
+            assert_eq!(result_batch.num_rows(), 0);
+        }
+    }
+
+    Ok(())
+}
+
+// Shared fixtures, computed lazily and only once across all parameterized test
+// cases. `tokio::sync::OnceCell` allows async initialization, so the expensive
+// test data generation and the nested loop join reference results are not
+// recomputed for every test parameter combination.
+
+/// Default left/right test partitions (built by `create_default_test_data`).
+static TEST_DATA: OnceCell<(TestPartitions, TestPartitions)> = OnceCell::const_new();
+/// Reference results for `RANGE_JOIN_SQLS`, in the same order as the queries.
+static RANGE_JOIN_EXPECTED_RESULTS: OnceCell<Vec<RecordBatch>> = OnceCell::const_new();
+/// Reference results for `DIST_JOIN_SQLS`, in the same order as the queries.
+static DIST_JOIN_EXPECTED_RESULTS: OnceCell<Vec<RecordBatch>> = OnceCell::const_new();
+
+// Range join queries on `ST_Intersects`: one projects only the ids, the other
+// projects every column of both sides.
+const RANGE_JOIN_SQL1: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
+const RANGE_JOIN_SQL2: &str =
+    "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY L.id, R.id";
+const RANGE_JOIN_SQLS: &[&str] = &[RANGE_JOIN_SQL1, RANGE_JOIN_SQL2];
+
+// Distance join queries: a constant threshold, thresholds computed from a
+// column of the left or the right side, and `ST_DWithin` with a constant distance.
+const DIST_JOIN_SQL1: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Distance(L.geometry, R.geometry) < 1.0 ORDER BY l_id, r_id";
+const DIST_JOIN_SQL2: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Distance(L.geometry, R.geometry) < L.dist / 10.0 ORDER BY l_id, r_id";
+const DIST_JOIN_SQL3: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Distance(L.geometry, R.geometry) < R.dist / 10.0 ORDER BY l_id, r_id";
+const DIST_JOIN_SQL4: &str = "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_DWithin(L.geometry, R.geometry, 1.0) ORDER BY l_id, r_id";
+const DIST_JOIN_SQLS: &[&str] = &[DIST_JOIN_SQL1, DIST_JOIN_SQL2, DIST_JOIN_SQL3, DIST_JOIN_SQL4];
+
+/// Returns the shared default test data, generating it on first use.
+///
+/// Panics if the test data cannot be created.
+async fn get_default_test_data() -> &'static (TestPartitions, TestPartitions) {
+    let init = || async { create_default_test_data().expect("Failed to create test data") };
+    TEST_DATA.get_or_init(init).await
+}
+
+/// Returns the reference (non-optimized) results for `RANGE_JOIN_SQLS`,
+/// computing them on first use.
+async fn get_expected_range_join_results() -> &'static Vec<RecordBatch> {
+    let cell = &RANGE_JOIN_EXPECTED_RESULTS;
+    get_or_init_expected_join_results(cell, RANGE_JOIN_SQLS).await
+}
+
+/// Returns the reference (non-optimized) results for `DIST_JOIN_SQLS`,
+/// computing them on first use.
+async fn get_expected_distance_join_results() -> &'static Vec<RecordBatch> {
+    let cell = &DIST_JOIN_EXPECTED_RESULTS;
+    get_or_init_expected_join_results(cell, DIST_JOIN_SQLS).await
+}
+
+/// Lazily computes (once) the reference results for `sql_queries`.
+///
+/// Each query is run on the shared default test data without spatial join
+/// options, so the plan contains no `SpatialJoinExec`. The results are stored
+/// in `lazy_init_results` in the same order as `sql_queries`.
+///
+/// # Panics
+/// Panics if any reference query fails. The message names the query and
+/// includes the underlying error.
+async fn get_or_init_expected_join_results<'a>(
+    lazy_init_results: &'a OnceCell<Vec<RecordBatch>>,
+    sql_queries: &[&str],
+) -> &'a Vec<RecordBatch> {
+    lazy_init_results
+        .get_or_init(|| async {
+            let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+                get_default_test_data().await;
+
+            let batch_size = 10;
+
+            // Run nested loop join to get expected results
+            let mut expected_results = Vec::with_capacity(sql_queries.len());
+            for (i, sql) in sql_queries.iter().enumerate() {
+                let result = run_spatial_join_query(
+                    left_schema,
+                    right_schema,
+                    left_partitions.clone(),
+                    right_partitions.clone(),
+                    None,
+                    batch_size,
+                    sql,
+                )
+                .await
+                // Keep the underlying error and the SQL so failures are diagnosable.
+                .unwrap_or_else(|e| {
+                    panic!("Failed to generate expected result {} ({sql}): {e}", i + 1)
+                });
+                expected_results.push(result);
+            }
+
+            expected_results
+        })
+        .await
+}
+
+/// Every range join query must match its reference result for all combinations
+/// of batch size, execution mode and spatial library.
+#[rstest]
+#[tokio::test]
+async fn test_range_join_with_conf(
+    #[values(10, 30, 1000)] max_batch_size: usize,
+    #[values(
+        ExecutionMode::PrepareNone,
+        ExecutionMode::PrepareBuild,
+        ExecutionMode::PrepareProbe,
+        ExecutionMode::Speculative(20)
+    )]
+    execution_mode: ExecutionMode,
+    #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, SpatialLibrary::Tg)]
+    spatial_library: SpatialLibrary,
+) -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        get_default_test_data().await;
+    let expected_results = get_expected_range_join_results().await;
+
+    let options = SpatialJoinOptions {
+        spatial_library,
+        execution_mode,
+        ..Default::default()
+    };
+
+    // Expected results were generated from the same query list, in order.
+    for (sql, expected) in RANGE_JOIN_SQLS.iter().zip(expected_results) {
+        let actual = run_spatial_join_query(
+            left_schema,
+            right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
+        assert_eq!(&actual, expected);
+    }
+
+    Ok(())
+}
+
+/// Every distance join query must match its reference result for all
+/// combinations of batch size and spatial library.
+#[rstest]
+#[tokio::test]
+async fn test_distance_join_with_conf(
+    #[values(30, 1000)] max_batch_size: usize,
+    #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, SpatialLibrary::Tg)]
+    spatial_library: SpatialLibrary,
+) -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        get_default_test_data().await;
+    let expected_results = get_expected_distance_join_results().await;
+
+    let options = SpatialJoinOptions {
+        spatial_library,
+        ..Default::default()
+    };
+
+    // Expected results were generated from the same query list, in order.
+    for (sql, expected) in DIST_JOIN_SQLS.iter().zip(expected_results) {
+        let actual = run_spatial_join_query(
+            left_schema,
+            right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
+        assert_eq!(&actual, expected);
+    }
+
+    Ok(())
+}
+
+/// Spatial predicates combined with an extra non-spatial join filter, with
+/// different projections, must match the non-optimized results.
+#[tokio::test]
+async fn test_spatial_join_with_filter() -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
+
+    let options = SpatialJoinOptions::default();
+    let queries = [
+        "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY L.id, R.id",
+        "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY l_id, r_id",
+        "SELECT L.id l_id, R.id r_id, L.dist l_dist, R.dist r_dist FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY l_id, r_id",
+    ];
+
+    for max_batch_size in [10, 30, 100] {
+        for sql in queries {
+            test_spatial_join_query(
+                &left_schema,
+                &right_schema,
+                left_partitions.clone(),
+                right_partitions.clone(),
+                &options,
+                max_batch_size,
+                sql,
+            )
+            .await?;
+        }
+    }
+
+    Ok(())
+}
+
+/// Range joins over inputs that include empty partitions must match the
+/// non-optimized results.
+#[tokio::test]
+async fn test_range_join_with_empty_partitions() -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_empty_partitions()?;
+
+    let options = SpatialJoinOptions::default();
+    let queries = [
+        "SELECT L.id l_id, R.id r_id FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
+        "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY L.id, R.id",
+    ];
+
+    for max_batch_size in [10, 30, 1000] {
+        for sql in queries {
+            test_spatial_join_query(
+                &left_schema,
+                &right_schema,
+                left_partitions.clone(),
+                right_partitions.clone(),
+                &options,
+                max_batch_size,
+                sql,
+            )
+            .await?;
+        }
+    }
+
+    Ok(())
+}
+
+/// Inner spatial join with default options and batch size 30.
+#[tokio::test]
+async fn test_inner_join() -> Result<()> {
+    test_with_join_types(JoinType::Inner, SpatialJoinOptions::default(), 30)
+        .await
+        .map(|_batches| ())
+}
+
+/// Left, left-semi and left-anti spatial joins with default options.
+#[rstest]
+#[tokio::test]
+async fn test_left_joins(
+    #[values(JoinType::Left, JoinType::LeftSemi, JoinType::LeftAnti)] join_type: JoinType,
+) -> Result<()> {
+    test_with_join_types(join_type, SpatialJoinOptions::default(), 30)
+        .await
+        .map(|_batches| ())
+}
+
+/// Right, right-semi and right-anti spatial joins with default options.
+#[rstest]
+#[tokio::test]
+async fn test_right_joins(
+    #[values(JoinType::Right, JoinType::RightSemi, JoinType::RightAnti)] join_type: JoinType,
+) -> Result<()> {
+    test_with_join_types(join_type, SpatialJoinOptions::default(), 30)
+        .await
+        .map(|_batches| ())
+}
+
+/// Full outer spatial join with default options and batch size 30.
+#[tokio::test]
+async fn test_full_outer_join() -> Result<()> {
+    test_with_join_types(JoinType::Full, SpatialJoinOptions::default(), 30)
+        .await
+        .map(|_batches| ())
+}
+
+#[tokio::test]
+async fn test_geography_join_is_not_optimized() -> Result<()> {
+    let options = SpatialJoinOptions::default();
+    let ctx = setup_context(Some(options), 10)?;
+
+    // Prepare geography tables
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((0.1, 10.0), WKB_GEOGRAPHY)?;
+    let mem_table_left: Arc<dyn TableProvider> =
+        Arc::new(MemTable::try_new(left_schema, left_partitions)?);
+    let mem_table_right: Arc<dyn TableProvider> =
+        Arc::new(MemTable::try_new(right_schema, right_partitions)?);
+    ctx.register_table("L", mem_table_left)?;
+    ctx.register_table("R", mem_table_right)?;
+
+    // Execute geography join query
+    let df = ctx
+        .sql("SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry)")
+        .await?;
+    let plan = df.create_physical_plan().await?;
+
+    // Verify that no SpatialJoinExec is present (geography join should not be 
optimized)
+    let spatial_joins = collect_spatial_join_exec(&plan)?;
+    assert!(
+        spatial_joins.is_empty(),
+        "Geography joins should not be optimized to SpatialJoinExec"
+    );
+
+    Ok(())
+}
+
+/// `ST_Intersects` against a geometry produced by a scalar subquery on `R`
+/// must match the non-optimized result.
+#[tokio::test]
+async fn test_query_window_in_subquery() -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((50.0, 60.0), WKB_GEOMETRY)?;
+    let options = SpatialJoinOptions::default();
+    let sql = "SELECT id FROM L WHERE ST_Intersects(L.geometry, (SELECT R.geometry FROM R WHERE R.id = 1))";
+
+    test_spatial_join_query(
+        &left_schema,
+        &right_schema,
+        left_partitions,
+        right_partitions,
+        &options,
+        10,
+        sql,
+    )
+    .await?;
+    Ok(())
+}
+
+/// Uses a small `parallel_refinement_chunk_size` (10) with larger geometries so
+/// candidate sets exceed a chunk, and checks results against the baseline.
+#[tokio::test]
+async fn test_parallel_refinement_for_large_candidate_set() -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((1.0, 50.0), WKB_GEOMETRY)?;
+
+    let options = SpatialJoinOptions {
+        parallel_refinement_chunk_size: 10,
+        ..Default::default()
+    };
+    let sql = "SELECT * FROM L JOIN R ON ST_Intersects(L.geometry, R.geometry) AND L.dist < R.dist ORDER BY L.id, R.id";
+
+    for max_batch_size in [10, 30, 100] {
+        test_spatial_join_query(
+            &left_schema,
+            &right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            &options,
+            max_batch_size,
+            sql,
+        )
+        .await?;
+    }
+
+    Ok(())
+}
+
+/// Range joins with 4 fixed spatial partitions and `force_spill` enabled must
+/// match the reference results for every batch size, execution mode and library.
+#[rstest]
+#[tokio::test]
+async fn test_spatial_partitioned_range_join(
+    #[values(10, 30, 1000)] max_batch_size: usize,
+    #[values(
+        ExecutionMode::PrepareNone,
+        ExecutionMode::PrepareBuild,
+        ExecutionMode::PrepareProbe,
+        ExecutionMode::Speculative(20)
+    )]
+    execution_mode: ExecutionMode,
+    #[values(SpatialLibrary::Geo, SpatialLibrary::Geos, SpatialLibrary::Tg)]
+    spatial_library: SpatialLibrary,
+) -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        get_default_test_data().await;
+    let expected_results = get_expected_range_join_results().await;
+
+    let options = SpatialJoinOptions {
+        spatial_library,
+        execution_mode,
+        debug: SpatialJoinDebugOptions {
+            num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
+            force_spill: true,
+            memory_for_intermittent_usage: None,
+            ..Default::default()
+        },
+        ..Default::default()
+    };
+
+    // Expected results were generated from the same query list, in order.
+    for (sql, expected) in RANGE_JOIN_SQLS.iter().zip(expected_results) {
+        let actual = run_spatial_join_query(
+            left_schema,
+            right_schema,
+            left_partitions.clone(),
+            right_partitions.clone(),
+            Some(options.clone()),
+            max_batch_size,
+            sql,
+        )
+        .await?;
+        assert_eq!(&actual, expected);
+    }
+
+    Ok(())
+}
+
+/// Outer, semi and anti joins with 4 fixed spatial partitions and
+/// `force_spill` enabled, across several batch sizes.
+#[rstest]
+#[tokio::test]
+async fn test_spatial_partitioned_outer_join(
+    #[values(10, 30, 1000)] batch_size: usize,
+    #[values(
+        JoinType::Left,
+        JoinType::Right,
+        JoinType::Full,
+        JoinType::LeftSemi,
+        JoinType::LeftAnti,
+        JoinType::RightSemi,
+        JoinType::RightAnti
+    )]
+    join_type: JoinType,
+) -> Result<()> {
+    let options = SpatialJoinOptions {
+        debug: SpatialJoinDebugOptions {
+            num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
+            force_spill: true,
+            memory_for_intermittent_usage: None,
+            ..Default::default()
+        },
+        ..Default::default()
+    };
+
+    test_with_join_types(join_type, options, batch_size)
+        .await
+        .map(|_batches| ())
+}
+
+/// Left/right mark joins with default options and batch size 10.
+#[rstest]
+#[tokio::test]
+async fn test_mark_joins(
+    #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType,
+) -> Result<()> {
+    test_mark_join(join_type, SpatialJoinOptions::default(), 10).await
+}
+
+/// Left/right mark joins with 4 fixed spatial partitions and `force_spill`.
+#[rstest]
+#[tokio::test]
+async fn test_spatial_partitioned_mark_joins(
+    #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType,
+) -> Result<()> {
+    let options = SpatialJoinOptions {
+        debug: SpatialJoinDebugOptions {
+            num_spatial_partitions: NumSpatialPartitionsConfig::Fixed(4),
+            force_spill: true,
+            memory_for_intermittent_usage: None,
+            ..Default::default()
+        },
+        ..Default::default()
+    };
+    test_mark_join(join_type, options, 10).await
+}
+
+/// Runs an `ST_Intersects` join of the given `join_type` on the fixture with
+/// empty partitions. Via `test_spatial_join_query` it checks that the optimized
+/// result equals the non-optimized one, then returns the optimized result.
+///
+/// For `Left`, `Right` and `Full` joins it also checks that the result has more
+/// rows than the corresponding inner join, so that unmatched rows are exercised.
+///
+/// # Panics
+/// Panics for `LeftMark`/`RightMark`, which cannot be written in SQL here.
+async fn test_with_join_types(
+    join_type: JoinType,
+    options: SpatialJoinOptions,
+    batch_size: usize,
+) -> Result<RecordBatch> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_empty_partitions()?;
+
+    let inner_sql = "SELECT L.id l_id, R.id r_id FROM L INNER JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id";
+    let sql = match join_type {
+        JoinType::Inner => inner_sql,
+        JoinType::Left => "SELECT L.id l_id, R.id r_id FROM L LEFT JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
+        JoinType::Right => "SELECT L.id l_id, R.id r_id FROM L RIGHT JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
+        JoinType::Full => "SELECT L.id l_id, R.id r_id FROM L FULL OUTER JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
+        JoinType::LeftSemi => "SELECT L.id l_id FROM L LEFT SEMI JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
+        JoinType::RightSemi => "SELECT R.id r_id FROM L RIGHT SEMI JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
+        JoinType::LeftAnti => "SELECT L.id l_id FROM L LEFT ANTI JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
+        JoinType::RightAnti => "SELECT R.id r_id FROM L RIGHT ANTI JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
+        JoinType::LeftMark => {
+            unreachable!("LeftMark is not directly supported in SQL, will be tested in other tests");
+        }
+        JoinType::RightMark => {
+            unreachable!("RightMark is not directly supported in SQL, will be tested in other tests");
+        }
+    };
+
+    let batches = test_spatial_join_query(
+        &left_schema,
+        &right_schema,
+        left_partitions.clone(),
+        right_partitions.clone(),
+        &options,
+        batch_size,
+        sql,
+    )
+    .await?;
+
+    if matches!(join_type, JoinType::Left | JoinType::Right | JoinType::Full) {
+        // Make sure that we are effectively testing outer joins. If an outer join
+        // produces the same result as the inner join, the test data is not
+        // suitable for testing outer joins.
+        let inner_batches = run_spatial_join_query(
+            &left_schema,
+            &right_schema,
+            left_partitions,
+            right_partitions,
+            Some(options),
+            batch_size,
+            inner_sql,
+        )
+        .await?;
+        assert!(
+            inner_batches.num_rows() < batches.num_rows(),
+            "{join_type:?} join produced {} rows but the inner join produced {}; \
+             the test data does not exercise unmatched rows",
+            batches.num_rows(),
+            inner_batches.num_rows()
+        );
+    }
+
+    Ok(batches)
+}
+
+/// Runs `sql` twice: once with `options` so the plan uses `SpatialJoinExec`, and
+/// once without options as the nested loop join baseline. Asserts that the
+/// baseline is non-empty and that both results are equal, then returns the
+/// `SpatialJoinExec` result.
+async fn test_spatial_join_query(
+    left_schema: &SchemaRef,
+    right_schema: &SchemaRef,
+    left_partitions: Vec<Vec<RecordBatch>>,
+    right_partitions: Vec<Vec<RecordBatch>>,
+    options: &SpatialJoinOptions,
+    batch_size: usize,
+    sql: &str,
+) -> Result<RecordBatch> {
+    // Run spatial join using SpatialJoinExec
+    let actual = run_spatial_join_query(
+        left_schema,
+        right_schema,
+        left_partitions.clone(),
+        right_partitions.clone(),
+        Some(options.clone()),
+        batch_size,
+        sql,
+    )
+    .await?;
+
+    // Run spatial join using NestedLoopJoinExec. This is the last use of the
+    // partitions, so they are moved rather than cloned again.
+    let expected = run_spatial_join_query(
+        left_schema,
+        right_schema,
+        left_partitions,
+        right_partitions,
+        None,
+        batch_size,
+        sql,
+    )
+    .await?;
+
+    // Should produce the same result. An empty baseline would make the
+    // comparison vacuous, so it is rejected explicitly.
+    assert!(
+        expected.num_rows() > 0,
+        "baseline (nested loop join) returned no rows for query: {sql}"
+    );
+    assert_eq!(
+        expected, actual,
+        "SpatialJoinExec result differs from the baseline for query: {sql}"
+    );
+
+    Ok(actual)
+}
+
+/// Registers the partitions as in-memory tables `L` and `R`, runs `sql`, and
+/// returns all result batches concatenated into a single `RecordBatch`.
+///
+/// With `Some(options)` the physical plan must contain exactly one
+/// `SpatialJoinExec`; with `None` it must contain none.
+async fn run_spatial_join_query(
+    left_schema: &SchemaRef,
+    right_schema: &SchemaRef,
+    left_partitions: Vec<Vec<RecordBatch>>,
+    right_partitions: Vec<Vec<RecordBatch>>,
+    options: Option<SpatialJoinOptions>,
+    batch_size: usize,
+    sql: &str,
+) -> Result<RecordBatch> {
+    let left_table: Arc<dyn TableProvider> =
+        Arc::new(MemTable::try_new(left_schema.to_owned(), left_partitions)?);
+    let right_table: Arc<dyn TableProvider> =
+        Arc::new(MemTable::try_new(right_schema.to_owned(), right_partitions)?);
+
+    let is_optimized_spatial_join = options.is_some();
+    let ctx = setup_context(options, batch_size)?;
+    ctx.register_table("L", left_table)?;
+    ctx.register_table("R", right_table)?;
+
+    let df = ctx.sql(sql).await?;
+    // Concatenate with the DataFrame's output schema, which is available even
+    // when the query returns no batches.
+    let output_schema = Arc::new(df.schema().as_arrow().clone());
+
+    let plan = df.clone().create_physical_plan().await?;
+    let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+    if is_optimized_spatial_join {
+        assert_eq!(spatial_join_execs.len(), 1);
+    } else {
+        assert!(spatial_join_execs.is_empty());
+    }
+
+    let result_batches = df.collect().await?;
+    Ok(arrow::compute::concat_batches(&output_schema, &result_batches)?)
+}
+
+/// Returns references to every `SpatialJoinExec` node found while visiting the
+/// whole physical plan tree rooted at `plan`.
+fn collect_spatial_join_exec(plan: &Arc<dyn ExecutionPlan>) -> Result<Vec<&SpatialJoinExec>> {
+    let mut found = Vec::new();
+    plan.apply(|node| {
+        // `downcast_ref` yields `None` for other operators, which adds nothing.
+        found.extend(node.as_any().downcast_ref::<SpatialJoinExec>());
+        Ok(TreeNodeRecursion::Continue)
+    })?;
+    Ok(found)
+}
+
+async fn test_mark_join(
+    join_type: JoinType,
+    options: SpatialJoinOptions,
+    batch_size: usize,
+) -> Result<()> {
+    let ((left_schema, left_partitions), (right_schema, right_partitions)) =
+        create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
+    let mem_table_left: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(
+        left_schema.clone(),
+        left_partitions.clone(),
+    )?);
+    let mem_table_right: Arc<dyn TableProvider> = Arc::new(MemTable::try_new(
+        right_schema.clone(),
+        right_partitions.clone(),
+    )?);
+
+    // We use a Left Join as a template to create the plan, then modify it to 
Mark Join
+    let sql = "SELECT * FROM L LEFT JOIN R ON ST_Intersects(L.geometry, 
R.geometry)";
+
+    // Create SpatialJoinExec plan
+    let ctx = setup_context(Some(options.clone()), batch_size)?;
+    ctx.register_table("L", mem_table_left.clone())?;
+    ctx.register_table("R", mem_table_right.clone())?;
+    let df = ctx.sql(sql).await?;
+    let plan = df.create_physical_plan().await?;
+    let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+    assert_eq!(spatial_join_execs.len(), 1);
+    let original_exec = spatial_join_execs[0];
+    let mark_exec = SpatialJoinExec::try_new(
+        original_exec.left.clone(),
+        original_exec.right.clone(),
+        original_exec.on.clone(),
+        original_exec.filter.clone(),
+        &join_type,
+        None,
+        &options,
+    )?;

Review Comment:
   Strange, actually the tests compiled fine.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to