This is an automated email from the ASF dual-hosted git repository.

kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git


The following commit(s) were added to refs/heads/main by this push:
     new 931558e3 chore(rust/sedona-spatial-join): Support right mark joins and enable tests for all kinds of spatial joins (#514)
931558e3 is described below

commit 931558e3f710330f35103b7ef61a6bd580d435ba
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Thu Jan 15 23:56:58 2026 +0800

    chore(rust/sedona-spatial-join): Support right mark joins and enable tests for all kinds of spatial joins (#514)
    
    This patch updates join_utils.rs to support RightMark joins. Spatial join tests for all join types are now enabled.
---
 python/sedonadb/tests/test_sjoin.py              | 167 +++++++++++++++++++++
 rust/sedona-spatial-join/src/exec.rs             | 183 +++++++++++++++++++++--
 rust/sedona-spatial-join/src/stream.rs           |   2 +
 rust/sedona-spatial-join/src/utils/join_utils.rs | 108 ++++++++-----
 4 files changed, 412 insertions(+), 48 deletions(-)

diff --git a/python/sedonadb/tests/test_sjoin.py b/python/sedonadb/tests/test_sjoin.py
index beb412ce..9169fbfb 100644
--- a/python/sedonadb/tests/test_sjoin.py
+++ b/python/sedonadb/tests/test_sjoin.py
@@ -79,6 +79,173 @@ def test_spatial_join(join_type, on):
         eng_postgis.assert_query_result(sql, sedonadb_results)
 
 
+@pytest.mark.parametrize(
+    "join_type",
+    [
+        "LEFT SEMI JOIN",
+        "LEFT ANTI JOIN",
+        "RIGHT SEMI JOIN",
+        "RIGHT ANTI JOIN",
+    ],
+)
+@pytest.mark.parametrize(
+    "on",
+    [
+        "ST_Intersects(sjoin_point.geometry, sjoin_polygon.geometry)",
+        "ST_Within(sjoin_point.geometry, sjoin_polygon.geometry)",
+        "ST_Contains(sjoin_polygon.geometry, sjoin_point.geometry)",
+        "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, 1.0)",
+        "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, 
sjoin_point.dist / 100)",
+        "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, 
sjoin_polygon.dist / 100)",
+    ],
+)
+def test_spatial_join_semi_anti(join_type, on):
+    with (
+        SedonaDB.create_or_skip() as eng_sedonadb,
+        PostGIS.create_or_skip() as eng_postgis,
+    ):
+        options = json.dumps(
+            {
+                "geom_type": "Point",
+                "polygon_hole_rate": 0.5,
+                "num_parts_range": [2, 10],
+                "vertices_per_linestring_range": [2, 10],
+                "seed": 42,
+            }
+        )
+        df_point = eng_sedonadb.execute_and_collect(
+            f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100"
+        )
+        options = json.dumps(
+            {
+                "geom_type": "Polygon",
+                "polygon_hole_rate": 0.5,
+                "num_parts_range": [2, 10],
+                "vertices_per_linestring_range": [2, 10],
+                "seed": 43,
+            }
+        )
+        df_polygon = eng_sedonadb.execute_and_collect(
+            f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100"
+        )
+        eng_sedonadb.create_table_arrow("sjoin_point", df_point)
+        eng_sedonadb.create_table_arrow("sjoin_polygon", df_polygon)
+        eng_postgis.create_table_arrow("sjoin_point", df_point)
+        eng_postgis.create_table_arrow("sjoin_polygon", df_polygon)
+
+        is_left = join_type.startswith("LEFT")
+        is_semi = "SEMI" in join_type
+
+        if is_left:
+            sedona_sql = f"""
+                SELECT sjoin_point.id id0
+                FROM sjoin_point {join_type} sjoin_polygon
+                ON {on}
+                ORDER BY id0
+            """
+            exists = f"EXISTS (SELECT 1 FROM sjoin_polygon WHERE {on})"
+            where = exists if is_semi else f"NOT {exists}"
+            postgis_sql = f"""
+                SELECT sjoin_point.id id0
+                FROM sjoin_point
+                WHERE {where}
+                ORDER BY id0
+            """
+        else:
+            sedona_sql = f"""
+                SELECT sjoin_polygon.id id1
+                FROM sjoin_point {join_type} sjoin_polygon
+                ON {on}
+                ORDER BY id1
+            """
+            exists = f"EXISTS (SELECT 1 FROM sjoin_point WHERE {on})"
+            where = exists if is_semi else f"NOT {exists}"
+            postgis_sql = f"""
+                SELECT sjoin_polygon.id id1
+                FROM sjoin_polygon
+                WHERE {where}
+                ORDER BY id1
+            """
+
+        sedonadb_results = 
eng_sedonadb.execute_and_collect(sedona_sql).to_pandas()
+        assert len(sedonadb_results) > 0
+        eng_postgis.assert_query_result(postgis_sql, sedonadb_results)
+
+
+@pytest.mark.parametrize(
+    "outer",
+    ["point", "polygon"],
+)
+@pytest.mark.parametrize(
+    "on",
+    [
+        "ST_Intersects(sjoin_point.geometry, sjoin_polygon.geometry)",
+        "ST_Within(sjoin_point.geometry, sjoin_polygon.geometry)",
+        "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, 1.0)",
+    ],
+)
+def test_spatial_mark_join_via_correlated_exists(outer, on):
+    with (
+        SedonaDB.create_or_skip() as eng_sedonadb,
+        PostGIS.create_or_skip() as eng_postgis,
+    ):
+        options = json.dumps(
+            {
+                "geom_type": "Point",
+                "polygon_hole_rate": 0.5,
+                "num_parts_range": [2, 10],
+                "vertices_per_linestring_range": [2, 10],
+                "seed": 42,
+            }
+        )
+        df_point = eng_sedonadb.execute_and_collect(
+            f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100"
+        )
+        options = json.dumps(
+            {
+                "geom_type": "Polygon",
+                "polygon_hole_rate": 0.5,
+                "num_parts_range": [2, 10],
+                "vertices_per_linestring_range": [2, 10],
+                "seed": 43,
+            }
+        )
+        df_polygon = eng_sedonadb.execute_and_collect(
+            f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100"
+        )
+        eng_sedonadb.create_table_arrow("sjoin_point", df_point)
+        eng_sedonadb.create_table_arrow("sjoin_polygon", df_polygon)
+        eng_postgis.create_table_arrow("sjoin_point", df_point)
+        eng_postgis.create_table_arrow("sjoin_polygon", df_polygon)
+
+        if outer == "point":
+            sql = f"""
+                SELECT sjoin_point.id id0
+                FROM sjoin_point
+                WHERE sjoin_point.id = 1 OR EXISTS (SELECT 1 FROM 
sjoin_polygon WHERE {on})
+                ORDER BY id0
+            """
+        else:
+            sql = f"""
+                SELECT sjoin_polygon.id id1, 
ST_AsBinary(sjoin_polygon.geometry) geom
+                FROM sjoin_polygon
+                WHERE sjoin_polygon.id = 1 OR EXISTS (SELECT 1 FROM 
sjoin_point WHERE {on})
+                ORDER BY id1
+            """
+
+        # Verify the physical query plan contains a Mark join
+        query_plan = eng_sedonadb.execute_and_collect(f"EXPLAIN 
{sql}").to_pandas()
+        plan_text = "\n".join(query_plan.iloc[:, 1].astype(str).tolist())
+        assert any(
+            "SpatialJoinExec" in line and ("LeftMark" in line or "RightMark" 
in line)
+            for line in plan_text.splitlines()
+        ), plan_text
+
+        sedonadb_results = eng_sedonadb.execute_and_collect(sql).to_pandas()
+        assert len(sedonadb_results) > 0
+        eng_postgis.assert_query_result(sql, sedonadb_results)
+
+
 @pytest.mark.parametrize(
     "join_type", ["INNER JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN"]
 )
diff --git a/rust/sedona-spatial-join/src/exec.rs b/rust/sedona-spatial-join/src/exec.rs
index e3440cd1..5cdea16d 100644
--- a/rust/sedona-spatial-join/src/exec.rs
+++ b/rust/sedona-spatial-join/src/exec.rs
@@ -272,7 +272,7 @@ impl SpatialJoinExec {
             match knn.probe_side {
                 JoinSide::Left => left.output_partitioning().clone(),
                 JoinSide::Right => right.output_partitioning().clone(),
-                _ => asymmetric_join_output_partitioning(left, right, 
&join_type),
+                _ => asymmetric_join_output_partitioning(left, right, 
&join_type)?,
             }
         } else if converted_from_hash_join {
             // Replicate HashJoin's symmetric partitioning logic
@@ -290,10 +290,10 @@ impl SpatialJoinExec {
                     // For full outer join, we can't preserve partitioning
                     
Partitioning::UnknownPartitioning(left.output_partitioning().partition_count())
                 }
-                _ => asymmetric_join_output_partitioning(left, right, 
&join_type),
+                _ => asymmetric_join_output_partitioning(left, right, 
&join_type)?,
             }
         } else {
-            asymmetric_join_output_partitioning(left, right, &join_type)
+            asymmetric_join_output_partitioning(left, right, &join_type)?
         };
 
         if let Some(projection) = projection {
@@ -615,6 +615,7 @@ mod tests {
     };
     use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
     use datafusion_expr::ColumnarValue;
+    use datafusion_physical_plan::joins::NestedLoopJoinExec;
     use geo::{Distance, Euclidean};
     use geo_types::{Coord, Rect};
     use rstest::rstest;
@@ -996,7 +997,7 @@ mod tests {
     #[rstest]
     #[tokio::test]
     async fn test_left_joins(
-        #[values(JoinType::Left, /* JoinType::LeftSemi, JoinType::LeftAnti 
*/)] join_type: JoinType,
+        #[values(JoinType::Left, JoinType::LeftSemi, JoinType::LeftAnti)] 
join_type: JoinType,
     ) -> Result<()> {
         test_with_join_types(join_type).await?;
         Ok(())
@@ -1005,8 +1006,7 @@ mod tests {
     #[rstest]
     #[tokio::test]
     async fn test_right_joins(
-        #[values(JoinType::Right, /* JoinType::RightSemi, JoinType::RightAnti 
*/)]
-        join_type: JoinType,
+        #[values(JoinType::Right, JoinType::RightSemi, JoinType::RightAnti)] 
join_type: JoinType,
     ) -> Result<()> {
         test_with_join_types(join_type).await?;
         Ok(())
@@ -1018,6 +1018,82 @@ mod tests {
         Ok(())
     }
 
+    #[rstest]
+    #[tokio::test]
+    async fn test_mark_joins(
+        #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType,
+    ) -> Result<()> {
+        let options = SpatialJoinOptions::default();
+        test_mark_join(join_type, options, 10).await?;
+        Ok(())
+    }
+
+    #[tokio::test]
+    async fn test_mark_join_via_correlated_exists_sql() -> Result<()> {
+        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
+            create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
+
+        let mem_table_left: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
+            left_schema.clone(),
+            left_partitions.clone(),
+        )?);
+        let mem_table_right: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
+            right_schema.clone(),
+            right_partitions.clone(),
+        )?);
+
+        // DataFusion doesn't have explicit SQL syntax for MARK joins. 
Predicate subqueries embedded
+        // in a more complex boolean expression (e.g. OR) are planned using a 
MARK join.
+        //
+        // Using EXISTS here (rather than IN) keeps the join filter as the 
pulled-up correlated
+        // predicate (ST_Intersects), which is what SpatialJoinExec can 
optimize.
+        let sql = "SELECT L.id FROM L WHERE L.id = 1 OR EXISTS (SELECT 1 FROM 
R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY L.id";
+
+        let batch_size = 10;
+        let options = SpatialJoinOptions::default();
+
+        // Optimized plan should include a SpatialJoinExec with Mark join type.
+        let ctx = setup_context(Some(options), batch_size)?;
+        ctx.register_table("L", Arc::clone(&mem_table_left))?;
+        ctx.register_table("R", Arc::clone(&mem_table_right))?;
+        let df = ctx.sql(sql).await?;
+        let plan = df.clone().create_physical_plan().await?;
+        let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+        assert!(
+            spatial_join_execs
+                .iter()
+                .any(|exec| matches!(*exec.join_type(), JoinType::LeftMark | 
JoinType::RightMark)),
+            "expected correlated IN-subquery to plan using a MARK join when 
optimized"
+        );
+        let actual_schema = df.schema().as_arrow().clone();
+        let actual_batches = df.collect().await?;
+        let actual_batch =
+            arrow::compute::concat_batches(&Arc::new(actual_schema), 
&actual_batches)?;
+
+        // Unoptimized plan should still contain a Mark join, but implemented 
as NestedLoopJoinExec.
+        let ctx_no_opt = setup_context(None, batch_size)?;
+        ctx_no_opt.register_table("L", mem_table_left)?;
+        ctx_no_opt.register_table("R", mem_table_right)?;
+        let df_no_opt = ctx_no_opt.sql(sql).await?;
+        let plan_no_opt = df_no_opt.clone().create_physical_plan().await?;
+        let nlj_execs = collect_nested_loop_join_exec(&plan_no_opt)?;
+        assert!(
+            nlj_execs
+                .iter()
+                .any(|exec| matches!(*exec.join_type(), JoinType::LeftMark | 
JoinType::RightMark)),
+            "expected correlated IN-subquery to plan using a MARK join when 
not optimized"
+        );
+        let expected_schema = df_no_opt.schema().as_arrow().clone();
+        let expected_batches = df_no_opt.collect().await?;
+        let expected_batch =
+            arrow::compute::concat_batches(&Arc::new(expected_schema), 
&expected_batches)?;
+
+        assert!(expected_batch.num_rows() > 0);
+        assert_eq!(expected_batch, actual_batch);
+
+        Ok(())
+    }
+
     #[tokio::test]
     async fn test_geography_join_is_not_optimized() -> Result<()> {
         let options = SpatialJoinOptions::default();
@@ -1075,10 +1151,10 @@ mod tests {
             JoinType::Left => "SELECT L.id l_id, R.id r_id FROM L LEFT JOIN R 
ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
             JoinType::Right => "SELECT L.id l_id, R.id r_id FROM L RIGHT JOIN 
R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
             JoinType::Full => "SELECT L.id l_id, R.id r_id FROM L FULL OUTER 
JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
-            JoinType::LeftSemi => "SELECT L.id l_id FROM L WHERE EXISTS 
(SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY l_id",
-            JoinType::RightSemi => "SELECT R.id r_id FROM R WHERE EXISTS 
(SELECT 1 FROM L WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY r_id",
-            JoinType::LeftAnti => "SELECT L.id l_id FROM L WHERE NOT EXISTS 
(SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY l_id",
-            JoinType::RightAnti => "SELECT R.id r_id FROM R WHERE NOT EXISTS 
(SELECT 1 FROM L WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY r_id",
+            JoinType::LeftSemi => "SELECT L.id l_id FROM L LEFT SEMI JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
+            JoinType::RightSemi => "SELECT R.id r_id FROM L RIGHT SEMI JOIN R 
ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
+            JoinType::LeftAnti => "SELECT L.id l_id FROM L LEFT ANTI JOIN R ON 
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
+            JoinType::RightAnti => "SELECT R.id r_id FROM L RIGHT ANTI JOIN R 
ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
             JoinType::LeftMark => {
                 unreachable!("LeftMark is not directly supported in SQL, will 
be tested in other tests");
             }
@@ -1203,6 +1279,93 @@ mod tests {
         Ok(spatial_join_execs)
     }
 
+    fn collect_nested_loop_join_exec(
+        plan: &Arc<dyn ExecutionPlan>,
+    ) -> Result<Vec<&NestedLoopJoinExec>> {
+        let mut execs = Vec::new();
+        plan.apply(|node| {
+            if let Some(exec) = 
node.as_any().downcast_ref::<NestedLoopJoinExec>() {
+                execs.push(exec);
+            }
+            Ok(TreeNodeRecursion::Continue)
+        })?;
+        Ok(execs)
+    }
+
+    async fn test_mark_join(
+        join_type: JoinType,
+        options: SpatialJoinOptions,
+        batch_size: usize,
+    ) -> Result<()> {
+        let ((left_schema, left_partitions), (right_schema, right_partitions)) 
=
+            create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
+        let mem_table_left: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
+            left_schema.clone(),
+            left_partitions.clone(),
+        )?);
+        let mem_table_right: Arc<dyn TableProvider> = 
Arc::new(MemTable::try_new(
+            right_schema.clone(),
+            right_partitions.clone(),
+        )?);
+
+        // We use a Left Join as a template to create the plan, then modify it 
to Mark Join
+        let sql = "SELECT * FROM L LEFT JOIN R ON ST_Intersects(L.geometry, 
R.geometry)";
+
+        // Create SpatialJoinExec plan
+        let ctx = setup_context(Some(options), batch_size)?;
+        ctx.register_table("L", mem_table_left.clone())?;
+        ctx.register_table("R", mem_table_right.clone())?;
+        let df = ctx.sql(sql).await?;
+        let plan = df.create_physical_plan().await?;
+        let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+        assert_eq!(spatial_join_execs.len(), 1);
+        let original_exec = spatial_join_execs[0];
+        let mark_exec = SpatialJoinExec::try_new(
+            original_exec.left.clone(),
+            original_exec.right.clone(),
+            original_exec.on.clone(),
+            original_exec.filter.clone(),
+            &join_type,
+            None,
+        )?;
+
+        // Create NestedLoopJoinExec plan for comparison
+        let ctx_no_opt = setup_context(None, batch_size)?;
+        ctx_no_opt.register_table("L", mem_table_left)?;
+        ctx_no_opt.register_table("R", mem_table_right)?;
+        let df_no_opt = ctx_no_opt.sql(sql).await?;
+        let plan_no_opt = df_no_opt.create_physical_plan().await?;
+        let nlj_execs = collect_nested_loop_join_exec(&plan_no_opt)?;
+        assert_eq!(nlj_execs.len(), 1);
+        let original_nlj = nlj_execs[0];
+        let mark_nlj = NestedLoopJoinExec::try_new(
+            original_nlj.children()[0].clone(),
+            original_nlj.children()[1].clone(),
+            original_nlj.filter().cloned(),
+            &join_type,
+            None,
+        )?;
+
+        async fn run_and_sort(
+            plan: Arc<dyn ExecutionPlan>,
+            ctx: &SessionContext,
+        ) -> Result<RecordBatch> {
+            let results = datafusion_physical_plan::collect(plan, 
ctx.task_ctx()).await?;
+            let batch = arrow::compute::concat_batches(&results[0].schema(), 
&results)?;
+            let sort_col = batch.column(0);
+            let indices = arrow::compute::sort_to_indices(sort_col, None, 
None)?;
+            let sorted_batch = arrow::compute::take_record_batch(&batch, 
&indices)?;
+            Ok(sorted_batch)
+        }
+
+        // Run both Mark Join plans and compare results
+        let mark_batch = run_and_sort(Arc::new(mark_exec), &ctx).await?;
+        let mark_nlj_batch = run_and_sort(Arc::new(mark_nlj), 
&ctx_no_opt).await?;
+        assert_eq!(mark_batch, mark_nlj_batch);
+
+        Ok(())
+    }
+
     fn extract_geoms_and_ids(partitions: &[Vec<RecordBatch>]) -> Vec<(i32, 
geo::Geometry<f64>)> {
         let mut result = Vec::new();
         for partition in partitions {
diff --git a/rust/sedona-spatial-join/src/stream.rs b/rust/sedona-spatial-join/src/stream.rs
index 7fa42231..37a84523 100644
--- a/rust/sedona-spatial-join/src/stream.rs
+++ b/rust/sedona-spatial-join/src/stream.rs
@@ -698,6 +698,7 @@ impl SpatialJoinBatchIterator {
             &probe_indices,
             column_indices,
             build_side,
+            join_type,
         )?;
 
         // Update metrics with actual output
@@ -896,6 +897,7 @@ impl UnmatchedBuildBatchIterator {
                     &right_side,
                     column_indices,
                     build_side,
+                    join_type,
                 )?
             };
 
diff --git a/rust/sedona-spatial-join/src/utils/join_utils.rs b/rust/sedona-spatial-join/src/utils/join_utils.rs
index 83ec18f4..87aaa9ae 100644
--- a/rust/sedona-spatial-join/src/utils/join_utils.rs
+++ b/rust/sedona-spatial-join/src/utils/join_utils.rs
@@ -16,14 +16,15 @@
 // under the License.
 
 /// Most of the code in this module are copied from the 
`datafusion_physical_plan::joins::utils` module.
-/// 
https://github.com/apache/datafusion/blob/48.0.0/datafusion/physical-plan/src/joins/utils.rs
+/// 
https://github.com/apache/datafusion/blob/50.2.0/datafusion/physical-plan/src/joins/utils.rs
 use std::{ops::Range, sync::Arc};
 
 use arrow::array::{
     downcast_array, new_null_array, Array, BooleanBufferBuilder, RecordBatch, 
RecordBatchOptions,
     UInt32Builder, UInt64Builder,
 };
-use arrow::compute;
+use arrow::buffer::NullBuffer;
+use arrow::compute::{self, take};
 use arrow::datatypes::{ArrowNativeType, Schema, UInt32Type, UInt64Type};
 use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray, 
UInt32Array, UInt64Array};
 use datafusion_common::cast::as_boolean_array;
@@ -112,6 +113,7 @@ pub(crate) fn apply_join_filter_to_indices(
         &probe_indices,
         filter.column_indices(),
         build_side,
+        JoinType::Inner,
     )?;
     let filter_result = filter
         .expression()
@@ -129,6 +131,7 @@ pub(crate) fn apply_join_filter_to_indices(
 
 /// Returns a new [RecordBatch] by combining the `left` and `right` according 
to `indices`.
 /// The resulting batch has [Schema] `schema`.
+#[allow(clippy::too_many_arguments)]
 pub(crate) fn build_batch_from_indices(
     schema: &Schema,
     build_input_buffer: &RecordBatch,
@@ -137,6 +140,7 @@ pub(crate) fn build_batch_from_indices(
     probe_indices: &UInt32Array,
     column_indices: &[ColumnIndex],
     build_side: JoinSide,
+    join_type: JoinType,
 ) -> Result<RecordBatch> {
     if schema.fields().is_empty() {
         let options = RecordBatchOptions::new()
@@ -157,8 +161,12 @@ pub(crate) fn build_batch_from_indices(
 
     for column_index in column_indices {
         let array = if column_index.side == JoinSide::None {
-            // LeftMark join, the mark column is a true if the indices is not 
null, otherwise it will be false
-            Arc::new(compute::is_not_null(probe_indices)?)
+            // For mark joins, the mark column is a true if the indices is not 
null, otherwise it will be false
+            if join_type == JoinType::RightMark {
+                Arc::new(compute::is_not_null(build_indices)?)
+            } else {
+                Arc::new(compute::is_not_null(probe_indices)?)
+            }
         } else if column_index.side == build_side {
             let array = build_input_buffer.column(column_index.index);
             if array.is_empty() || build_indices.null_count() == 
build_indices.len() {
@@ -168,7 +176,7 @@ pub(crate) fn build_batch_from_indices(
                 assert_eq!(build_indices.null_count(), build_indices.len());
                 new_null_array(array.data_type(), build_indices.len())
             } else {
-                compute::take(array.as_ref(), build_indices, None)?
+                take(array.as_ref(), build_indices, None)?
             }
         } else {
             let array = probe_batch.column(column_index.index);
@@ -176,9 +184,10 @@ pub(crate) fn build_batch_from_indices(
                 assert_eq!(probe_indices.null_count(), probe_indices.len());
                 new_null_array(array.data_type(), probe_indices.len())
             } else {
-                compute::take(array.as_ref(), probe_indices, None)?
+                take(array.as_ref(), probe_indices, None)?
             }
         };
+
         columns.push(array);
     }
     Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
@@ -226,7 +235,12 @@ pub(crate) fn adjust_indices_by_join_type(
             // the left_indices will not be used later for the `right anti` 
join
             Ok((left_indices, right_indices))
         }
-        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark | 
JoinType::RightMark => {
+        JoinType::RightMark => {
+            let new_left_indices = get_mark_indices(&adjust_range, 
&right_indices);
+            let new_right_indices = adjust_range.map(|i| i as u32).collect();
+            Ok((new_left_indices, new_right_indices))
+        }
+        JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
             // matched or unmatched left row will be produced in the end of 
loop
             // When visit the right batch, we can output the matched left row 
and don't need to wait the end of loop
             Ok((
@@ -328,17 +342,7 @@ pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
 where
     NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
 {
-    let mut bitmap = BooleanBufferBuilder::new(range.len());
-    bitmap.append_n(range.len(), false);
-    input_indices
-        .iter()
-        .flatten()
-        .map(|v| v.as_usize())
-        .filter(|v| range.contains(v))
-        .for_each(|v| {
-            bitmap.set_bit(v - range.start, true);
-        });
-
+    let bitmap = build_range_bitmap(&range, input_indices);
     let offset = range.start;
 
     // get the anti index
@@ -355,25 +359,52 @@ pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
 where
     NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
 {
-    let mut bitmap = BooleanBufferBuilder::new(range.len());
-    bitmap.append_n(range.len(), false);
-    input_indices
-        .iter()
-        .flatten()
-        .map(|v| v.as_usize())
-        .filter(|v| range.contains(v))
-        .for_each(|v| {
-            bitmap.set_bit(v - range.start, true);
-        });
-
+    let bitmap = build_range_bitmap(&range, input_indices);
     let offset = range.start;
-
     // get the semi index
     (range)
         .filter_map(|idx| (bitmap.get_bit(idx - 
offset)).then_some(T::Native::from_usize(idx)))
         .collect()
 }
 
+/// Returns an array for mark joins consisting of default values (zeros) with 
null/non-null markers.
+///
+/// For each index in `range`:
+/// - If the index appears in `input_indices`, the value is non-null (0)
+/// - If the index does not appear in `input_indices`, the value is null
+///
+/// This is used in mark joins to indicate which rows had matches.
+pub(crate) fn get_mark_indices<T: ArrowPrimitiveType, R: ArrowPrimitiveType>(
+    range: &Range<usize>,
+    input_indices: &PrimitiveArray<T>,
+) -> PrimitiveArray<R>
+where
+    NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
+{
+    let mut bitmap = build_range_bitmap(range, input_indices);
+    PrimitiveArray::new(
+        vec![R::Native::default(); range.len()].into(),
+        Some(NullBuffer::new(bitmap.finish())),
+    )
+}
+
+fn build_range_bitmap<T: ArrowPrimitiveType>(
+    range: &Range<usize>,
+    input: &PrimitiveArray<T>,
+) -> BooleanBufferBuilder {
+    let mut builder = BooleanBufferBuilder::new(range.len());
+    builder.append_n(range.len(), false);
+
+    input.iter().flatten().for_each(|v| {
+        let idx = v.as_usize();
+        if range.contains(&idx) {
+            builder.set_bit(idx - range.start, true);
+        }
+    });
+
+    builder
+}
+
 /// Appends probe indices in order by considering the given build indices.
 ///
 /// This function constructs new build and probe indices by iterating through
@@ -432,23 +463,24 @@ pub(crate) fn asymmetric_join_output_partitioning(
     left: &Arc<dyn ExecutionPlan>,
     right: &Arc<dyn ExecutionPlan>,
     join_type: &JoinType,
-) -> Partitioning {
-    match join_type {
+) -> Result<Partitioning> {
+    let result = match join_type {
         JoinType::Inner | JoinType::Right => adjust_right_output_partitioning(
             right.output_partitioning(),
             left.schema().fields().len(),
-        )
-        .unwrap_or_else(|_| Partitioning::UnknownPartitioning(1)),
-        JoinType::RightSemi | JoinType::RightAnti => 
right.output_partitioning().clone(),
+        )?,
+        JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
+            right.output_partitioning().clone()
+        }
         JoinType::Left
         | JoinType::LeftSemi
         | JoinType::LeftAnti
         | JoinType::Full
-        | JoinType::LeftMark
-        | JoinType::RightMark => {
+        | JoinType::LeftMark => {
             
Partitioning::UnknownPartitioning(right.output_partitioning().partition_count())
         }
-    }
+    };
+    Ok(result)
 }
 
 /// This function is copied from

Reply via email to