This is an automated email from the ASF dual-hosted git repository.
kontinuation pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/sedona-db.git
The following commit(s) were added to refs/heads/main by this push:
new 931558e3 chore(rust/sedona-spatial-join): Support right mark joins and
enable tests for all kinds of spatial joins (#514)
931558e3 is described below
commit 931558e3f710330f35103b7ef61a6bd580d435ba
Author: Kristin Cowalcijk <[email protected]>
AuthorDate: Thu Jan 15 23:56:58 2026 +0800
chore(rust/sedona-spatial-join): Support right mark joins and enable tests
for all kinds of spatial joins (#514)
This patch updates join_utils.rs to support RightMark joins and enables
spatial join tests for all join types.
---
python/sedonadb/tests/test_sjoin.py | 167 +++++++++++++++++++++
rust/sedona-spatial-join/src/exec.rs | 183 +++++++++++++++++++++--
rust/sedona-spatial-join/src/stream.rs | 2 +
rust/sedona-spatial-join/src/utils/join_utils.rs | 108 ++++++++-----
4 files changed, 412 insertions(+), 48 deletions(-)
diff --git a/python/sedonadb/tests/test_sjoin.py
b/python/sedonadb/tests/test_sjoin.py
index beb412ce..9169fbfb 100644
--- a/python/sedonadb/tests/test_sjoin.py
+++ b/python/sedonadb/tests/test_sjoin.py
@@ -79,6 +79,173 @@ def test_spatial_join(join_type, on):
eng_postgis.assert_query_result(sql, sedonadb_results)
+@pytest.mark.parametrize(
+ "join_type",
+ [
+ "LEFT SEMI JOIN",
+ "LEFT ANTI JOIN",
+ "RIGHT SEMI JOIN",
+ "RIGHT ANTI JOIN",
+ ],
+)
+@pytest.mark.parametrize(
+ "on",
+ [
+ "ST_Intersects(sjoin_point.geometry, sjoin_polygon.geometry)",
+ "ST_Within(sjoin_point.geometry, sjoin_polygon.geometry)",
+ "ST_Contains(sjoin_polygon.geometry, sjoin_point.geometry)",
+ "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, 1.0)",
+ "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry,
sjoin_point.dist / 100)",
+ "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry,
sjoin_polygon.dist / 100)",
+ ],
+)
+def test_spatial_join_semi_anti(join_type, on):
+ with (
+ SedonaDB.create_or_skip() as eng_sedonadb,
+ PostGIS.create_or_skip() as eng_postgis,
+ ):
+ options = json.dumps(
+ {
+ "geom_type": "Point",
+ "polygon_hole_rate": 0.5,
+ "num_parts_range": [2, 10],
+ "vertices_per_linestring_range": [2, 10],
+ "seed": 42,
+ }
+ )
+ df_point = eng_sedonadb.execute_and_collect(
+ f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100"
+ )
+ options = json.dumps(
+ {
+ "geom_type": "Polygon",
+ "polygon_hole_rate": 0.5,
+ "num_parts_range": [2, 10],
+ "vertices_per_linestring_range": [2, 10],
+ "seed": 43,
+ }
+ )
+ df_polygon = eng_sedonadb.execute_and_collect(
+ f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100"
+ )
+ eng_sedonadb.create_table_arrow("sjoin_point", df_point)
+ eng_sedonadb.create_table_arrow("sjoin_polygon", df_polygon)
+ eng_postgis.create_table_arrow("sjoin_point", df_point)
+ eng_postgis.create_table_arrow("sjoin_polygon", df_polygon)
+
+ is_left = join_type.startswith("LEFT")
+ is_semi = "SEMI" in join_type
+
+ if is_left:
+ sedona_sql = f"""
+ SELECT sjoin_point.id id0
+ FROM sjoin_point {join_type} sjoin_polygon
+ ON {on}
+ ORDER BY id0
+ """
+ exists = f"EXISTS (SELECT 1 FROM sjoin_polygon WHERE {on})"
+ where = exists if is_semi else f"NOT {exists}"
+ postgis_sql = f"""
+ SELECT sjoin_point.id id0
+ FROM sjoin_point
+ WHERE {where}
+ ORDER BY id0
+ """
+ else:
+ sedona_sql = f"""
+ SELECT sjoin_polygon.id id1
+ FROM sjoin_point {join_type} sjoin_polygon
+ ON {on}
+ ORDER BY id1
+ """
+ exists = f"EXISTS (SELECT 1 FROM sjoin_point WHERE {on})"
+ where = exists if is_semi else f"NOT {exists}"
+ postgis_sql = f"""
+ SELECT sjoin_polygon.id id1
+ FROM sjoin_polygon
+ WHERE {where}
+ ORDER BY id1
+ """
+
+ sedonadb_results =
eng_sedonadb.execute_and_collect(sedona_sql).to_pandas()
+ assert len(sedonadb_results) > 0
+ eng_postgis.assert_query_result(postgis_sql, sedonadb_results)
+
+
+@pytest.mark.parametrize(
+ "outer",
+ ["point", "polygon"],
+)
+@pytest.mark.parametrize(
+ "on",
+ [
+ "ST_Intersects(sjoin_point.geometry, sjoin_polygon.geometry)",
+ "ST_Within(sjoin_point.geometry, sjoin_polygon.geometry)",
+ "ST_DWithin(sjoin_point.geometry, sjoin_polygon.geometry, 1.0)",
+ ],
+)
+def test_spatial_mark_join_via_correlated_exists(outer, on):
+ with (
+ SedonaDB.create_or_skip() as eng_sedonadb,
+ PostGIS.create_or_skip() as eng_postgis,
+ ):
+ options = json.dumps(
+ {
+ "geom_type": "Point",
+ "polygon_hole_rate": 0.5,
+ "num_parts_range": [2, 10],
+ "vertices_per_linestring_range": [2, 10],
+ "seed": 42,
+ }
+ )
+ df_point = eng_sedonadb.execute_and_collect(
+ f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100"
+ )
+ options = json.dumps(
+ {
+ "geom_type": "Polygon",
+ "polygon_hole_rate": 0.5,
+ "num_parts_range": [2, 10],
+ "vertices_per_linestring_range": [2, 10],
+ "seed": 43,
+ }
+ )
+ df_polygon = eng_sedonadb.execute_and_collect(
+ f"SELECT * FROM sd_random_geometry('{options}') LIMIT 100"
+ )
+ eng_sedonadb.create_table_arrow("sjoin_point", df_point)
+ eng_sedonadb.create_table_arrow("sjoin_polygon", df_polygon)
+ eng_postgis.create_table_arrow("sjoin_point", df_point)
+ eng_postgis.create_table_arrow("sjoin_polygon", df_polygon)
+
+ if outer == "point":
+ sql = f"""
+ SELECT sjoin_point.id id0
+ FROM sjoin_point
+ WHERE sjoin_point.id = 1 OR EXISTS (SELECT 1 FROM
sjoin_polygon WHERE {on})
+ ORDER BY id0
+ """
+ else:
+ sql = f"""
+ SELECT sjoin_polygon.id id1,
ST_AsBinary(sjoin_polygon.geometry) geom
+ FROM sjoin_polygon
+ WHERE sjoin_polygon.id = 1 OR EXISTS (SELECT 1 FROM
sjoin_point WHERE {on})
+ ORDER BY id1
+ """
+
+ # Verify the physical query plan contains a Mark join
+ query_plan = eng_sedonadb.execute_and_collect(f"EXPLAIN
{sql}").to_pandas()
+ plan_text = "\n".join(query_plan.iloc[:, 1].astype(str).tolist())
+ assert any(
+ "SpatialJoinExec" in line and ("LeftMark" in line or "RightMark"
in line)
+ for line in plan_text.splitlines()
+ ), plan_text
+
+ sedonadb_results = eng_sedonadb.execute_and_collect(sql).to_pandas()
+ assert len(sedonadb_results) > 0
+ eng_postgis.assert_query_result(sql, sedonadb_results)
+
+
@pytest.mark.parametrize(
"join_type", ["INNER JOIN", "LEFT OUTER JOIN", "RIGHT OUTER JOIN"]
)
diff --git a/rust/sedona-spatial-join/src/exec.rs
b/rust/sedona-spatial-join/src/exec.rs
index e3440cd1..5cdea16d 100644
--- a/rust/sedona-spatial-join/src/exec.rs
+++ b/rust/sedona-spatial-join/src/exec.rs
@@ -272,7 +272,7 @@ impl SpatialJoinExec {
match knn.probe_side {
JoinSide::Left => left.output_partitioning().clone(),
JoinSide::Right => right.output_partitioning().clone(),
- _ => asymmetric_join_output_partitioning(left, right,
&join_type),
+ _ => asymmetric_join_output_partitioning(left, right,
&join_type)?,
}
} else if converted_from_hash_join {
// Replicate HashJoin's symmetric partitioning logic
@@ -290,10 +290,10 @@ impl SpatialJoinExec {
// For full outer join, we can't preserve partitioning
Partitioning::UnknownPartitioning(left.output_partitioning().partition_count())
}
- _ => asymmetric_join_output_partitioning(left, right,
&join_type),
+ _ => asymmetric_join_output_partitioning(left, right,
&join_type)?,
}
} else {
- asymmetric_join_output_partitioning(left, right, &join_type)
+ asymmetric_join_output_partitioning(left, right, &join_type)?
};
if let Some(projection) = projection {
@@ -615,6 +615,7 @@ mod tests {
};
use datafusion_common::tree_node::{TreeNode, TreeNodeRecursion};
use datafusion_expr::ColumnarValue;
+ use datafusion_physical_plan::joins::NestedLoopJoinExec;
use geo::{Distance, Euclidean};
use geo_types::{Coord, Rect};
use rstest::rstest;
@@ -996,7 +997,7 @@ mod tests {
#[rstest]
#[tokio::test]
async fn test_left_joins(
- #[values(JoinType::Left, /* JoinType::LeftSemi, JoinType::LeftAnti
*/)] join_type: JoinType,
+ #[values(JoinType::Left, JoinType::LeftSemi, JoinType::LeftAnti)]
join_type: JoinType,
) -> Result<()> {
test_with_join_types(join_type).await?;
Ok(())
@@ -1005,8 +1006,7 @@ mod tests {
#[rstest]
#[tokio::test]
async fn test_right_joins(
- #[values(JoinType::Right, /* JoinType::RightSemi, JoinType::RightAnti
*/)]
- join_type: JoinType,
+ #[values(JoinType::Right, JoinType::RightSemi, JoinType::RightAnti)]
join_type: JoinType,
) -> Result<()> {
test_with_join_types(join_type).await?;
Ok(())
@@ -1018,6 +1018,82 @@ mod tests {
Ok(())
}
+ #[rstest]
+ #[tokio::test]
+ async fn test_mark_joins(
+ #[values(JoinType::LeftMark, JoinType::RightMark)] join_type: JoinType,
+ ) -> Result<()> {
+ let options = SpatialJoinOptions::default();
+ test_mark_join(join_type, options, 10).await?;
+ Ok(())
+ }
+
+ #[tokio::test]
+ async fn test_mark_join_via_correlated_exists_sql() -> Result<()> {
+ let ((left_schema, left_partitions), (right_schema, right_partitions))
=
+ create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
+
+ let mem_table_left: Arc<dyn TableProvider> =
Arc::new(MemTable::try_new(
+ left_schema.clone(),
+ left_partitions.clone(),
+ )?);
+ let mem_table_right: Arc<dyn TableProvider> =
Arc::new(MemTable::try_new(
+ right_schema.clone(),
+ right_partitions.clone(),
+ )?);
+
+ // DataFusion doesn't have explicit SQL syntax for MARK joins.
Predicate subqueries embedded
+ // in a more complex boolean expression (e.g. OR) are planned using a
MARK join.
+ //
+ // Using EXISTS here (rather than IN) keeps the join filter as the
pulled-up correlated
+ // predicate (ST_Intersects), which is what SpatialJoinExec can
optimize.
+ let sql = "SELECT L.id FROM L WHERE L.id = 1 OR EXISTS (SELECT 1 FROM
R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY L.id";
+
+ let batch_size = 10;
+ let options = SpatialJoinOptions::default();
+
+ // Optimized plan should include a SpatialJoinExec with Mark join type.
+ let ctx = setup_context(Some(options), batch_size)?;
+ ctx.register_table("L", Arc::clone(&mem_table_left))?;
+ ctx.register_table("R", Arc::clone(&mem_table_right))?;
+ let df = ctx.sql(sql).await?;
+ let plan = df.clone().create_physical_plan().await?;
+ let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+ assert!(
+ spatial_join_execs
+ .iter()
+ .any(|exec| matches!(*exec.join_type(), JoinType::LeftMark |
JoinType::RightMark)),
+ "expected correlated IN-subquery to plan using a MARK join when
optimized"
+ );
+ let actual_schema = df.schema().as_arrow().clone();
+ let actual_batches = df.collect().await?;
+ let actual_batch =
+ arrow::compute::concat_batches(&Arc::new(actual_schema),
&actual_batches)?;
+
+ // Unoptimized plan should still contain a Mark join, but implemented
as NestedLoopJoinExec.
+ let ctx_no_opt = setup_context(None, batch_size)?;
+ ctx_no_opt.register_table("L", mem_table_left)?;
+ ctx_no_opt.register_table("R", mem_table_right)?;
+ let df_no_opt = ctx_no_opt.sql(sql).await?;
+ let plan_no_opt = df_no_opt.clone().create_physical_plan().await?;
+ let nlj_execs = collect_nested_loop_join_exec(&plan_no_opt)?;
+ assert!(
+ nlj_execs
+ .iter()
+ .any(|exec| matches!(*exec.join_type(), JoinType::LeftMark |
JoinType::RightMark)),
+ "expected correlated IN-subquery to plan using a MARK join when
not optimized"
+ );
+ let expected_schema = df_no_opt.schema().as_arrow().clone();
+ let expected_batches = df_no_opt.collect().await?;
+ let expected_batch =
+ arrow::compute::concat_batches(&Arc::new(expected_schema),
&expected_batches)?;
+
+ assert!(expected_batch.num_rows() > 0);
+ assert_eq!(expected_batch, actual_batch);
+
+ Ok(())
+ }
+
#[tokio::test]
async fn test_geography_join_is_not_optimized() -> Result<()> {
let options = SpatialJoinOptions::default();
@@ -1075,10 +1151,10 @@ mod tests {
JoinType::Left => "SELECT L.id l_id, R.id r_id FROM L LEFT JOIN R
ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
JoinType::Right => "SELECT L.id l_id, R.id r_id FROM L RIGHT JOIN
R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
JoinType::Full => "SELECT L.id l_id, R.id r_id FROM L FULL OUTER
JOIN R ON ST_Intersects(L.geometry, R.geometry) ORDER BY l_id, r_id",
- JoinType::LeftSemi => "SELECT L.id l_id FROM L WHERE EXISTS
(SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY l_id",
- JoinType::RightSemi => "SELECT R.id r_id FROM R WHERE EXISTS
(SELECT 1 FROM L WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY r_id",
- JoinType::LeftAnti => "SELECT L.id l_id FROM L WHERE NOT EXISTS
(SELECT 1 FROM R WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY l_id",
- JoinType::RightAnti => "SELECT R.id r_id FROM R WHERE NOT EXISTS
(SELECT 1 FROM L WHERE ST_Intersects(L.geometry, R.geometry)) ORDER BY r_id",
+ JoinType::LeftSemi => "SELECT L.id l_id FROM L LEFT SEMI JOIN R ON
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
+ JoinType::RightSemi => "SELECT R.id r_id FROM L RIGHT SEMI JOIN R
ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
+ JoinType::LeftAnti => "SELECT L.id l_id FROM L LEFT ANTI JOIN R ON
ST_Intersects(L.geometry, R.geometry) ORDER BY l_id",
+ JoinType::RightAnti => "SELECT R.id r_id FROM L RIGHT ANTI JOIN R
ON ST_Intersects(L.geometry, R.geometry) ORDER BY r_id",
JoinType::LeftMark => {
unreachable!("LeftMark is not directly supported in SQL, will
be tested in other tests");
}
@@ -1203,6 +1279,93 @@ mod tests {
Ok(spatial_join_execs)
}
+ fn collect_nested_loop_join_exec(
+ plan: &Arc<dyn ExecutionPlan>,
+ ) -> Result<Vec<&NestedLoopJoinExec>> {
+ let mut execs = Vec::new();
+ plan.apply(|node| {
+ if let Some(exec) =
node.as_any().downcast_ref::<NestedLoopJoinExec>() {
+ execs.push(exec);
+ }
+ Ok(TreeNodeRecursion::Continue)
+ })?;
+ Ok(execs)
+ }
+
+ async fn test_mark_join(
+ join_type: JoinType,
+ options: SpatialJoinOptions,
+ batch_size: usize,
+ ) -> Result<()> {
+ let ((left_schema, left_partitions), (right_schema, right_partitions))
=
+ create_test_data_with_size_range((0.1, 10.0), WKB_GEOMETRY)?;
+ let mem_table_left: Arc<dyn TableProvider> =
Arc::new(MemTable::try_new(
+ left_schema.clone(),
+ left_partitions.clone(),
+ )?);
+ let mem_table_right: Arc<dyn TableProvider> =
Arc::new(MemTable::try_new(
+ right_schema.clone(),
+ right_partitions.clone(),
+ )?);
+
+ // We use a Left Join as a template to create the plan, then modify it
to Mark Join
+ let sql = "SELECT * FROM L LEFT JOIN R ON ST_Intersects(L.geometry,
R.geometry)";
+
+ // Create SpatialJoinExec plan
+ let ctx = setup_context(Some(options), batch_size)?;
+ ctx.register_table("L", mem_table_left.clone())?;
+ ctx.register_table("R", mem_table_right.clone())?;
+ let df = ctx.sql(sql).await?;
+ let plan = df.create_physical_plan().await?;
+ let spatial_join_execs = collect_spatial_join_exec(&plan)?;
+ assert_eq!(spatial_join_execs.len(), 1);
+ let original_exec = spatial_join_execs[0];
+ let mark_exec = SpatialJoinExec::try_new(
+ original_exec.left.clone(),
+ original_exec.right.clone(),
+ original_exec.on.clone(),
+ original_exec.filter.clone(),
+ &join_type,
+ None,
+ )?;
+
+ // Create NestedLoopJoinExec plan for comparison
+ let ctx_no_opt = setup_context(None, batch_size)?;
+ ctx_no_opt.register_table("L", mem_table_left)?;
+ ctx_no_opt.register_table("R", mem_table_right)?;
+ let df_no_opt = ctx_no_opt.sql(sql).await?;
+ let plan_no_opt = df_no_opt.create_physical_plan().await?;
+ let nlj_execs = collect_nested_loop_join_exec(&plan_no_opt)?;
+ assert_eq!(nlj_execs.len(), 1);
+ let original_nlj = nlj_execs[0];
+ let mark_nlj = NestedLoopJoinExec::try_new(
+ original_nlj.children()[0].clone(),
+ original_nlj.children()[1].clone(),
+ original_nlj.filter().cloned(),
+ &join_type,
+ None,
+ )?;
+
+ async fn run_and_sort(
+ plan: Arc<dyn ExecutionPlan>,
+ ctx: &SessionContext,
+ ) -> Result<RecordBatch> {
+ let results = datafusion_physical_plan::collect(plan,
ctx.task_ctx()).await?;
+ let batch = arrow::compute::concat_batches(&results[0].schema(),
&results)?;
+ let sort_col = batch.column(0);
+ let indices = arrow::compute::sort_to_indices(sort_col, None,
None)?;
+ let sorted_batch = arrow::compute::take_record_batch(&batch,
&indices)?;
+ Ok(sorted_batch)
+ }
+
+ // Run both Mark Join plans and compare results
+ let mark_batch = run_and_sort(Arc::new(mark_exec), &ctx).await?;
+ let mark_nlj_batch = run_and_sort(Arc::new(mark_nlj),
&ctx_no_opt).await?;
+ assert_eq!(mark_batch, mark_nlj_batch);
+
+ Ok(())
+ }
+
fn extract_geoms_and_ids(partitions: &[Vec<RecordBatch>]) -> Vec<(i32,
geo::Geometry<f64>)> {
let mut result = Vec::new();
for partition in partitions {
diff --git a/rust/sedona-spatial-join/src/stream.rs
b/rust/sedona-spatial-join/src/stream.rs
index 7fa42231..37a84523 100644
--- a/rust/sedona-spatial-join/src/stream.rs
+++ b/rust/sedona-spatial-join/src/stream.rs
@@ -698,6 +698,7 @@ impl SpatialJoinBatchIterator {
&probe_indices,
column_indices,
build_side,
+ join_type,
)?;
// Update metrics with actual output
@@ -896,6 +897,7 @@ impl UnmatchedBuildBatchIterator {
&right_side,
column_indices,
build_side,
+ join_type,
)?
};
diff --git a/rust/sedona-spatial-join/src/utils/join_utils.rs
b/rust/sedona-spatial-join/src/utils/join_utils.rs
index 83ec18f4..87aaa9ae 100644
--- a/rust/sedona-spatial-join/src/utils/join_utils.rs
+++ b/rust/sedona-spatial-join/src/utils/join_utils.rs
@@ -16,14 +16,15 @@
// under the License.
/// Most of the code in this module are copied from the
`datafusion_physical_plan::joins::utils` module.
-///
https://github.com/apache/datafusion/blob/48.0.0/datafusion/physical-plan/src/joins/utils.rs
+///
https://github.com/apache/datafusion/blob/50.2.0/datafusion/physical-plan/src/joins/utils.rs
use std::{ops::Range, sync::Arc};
use arrow::array::{
downcast_array, new_null_array, Array, BooleanBufferBuilder, RecordBatch,
RecordBatchOptions,
UInt32Builder, UInt64Builder,
};
-use arrow::compute;
+use arrow::buffer::NullBuffer;
+use arrow::compute::{self, take};
use arrow::datatypes::{ArrowNativeType, Schema, UInt32Type, UInt64Type};
use arrow_array::{ArrowPrimitiveType, NativeAdapter, PrimitiveArray,
UInt32Array, UInt64Array};
use datafusion_common::cast::as_boolean_array;
@@ -112,6 +113,7 @@ pub(crate) fn apply_join_filter_to_indices(
&probe_indices,
filter.column_indices(),
build_side,
+ JoinType::Inner,
)?;
let filter_result = filter
.expression()
@@ -129,6 +131,7 @@ pub(crate) fn apply_join_filter_to_indices(
/// Returns a new [RecordBatch] by combining the `left` and `right` according
to `indices`.
/// The resulting batch has [Schema] `schema`.
+#[allow(clippy::too_many_arguments)]
pub(crate) fn build_batch_from_indices(
schema: &Schema,
build_input_buffer: &RecordBatch,
@@ -137,6 +140,7 @@ pub(crate) fn build_batch_from_indices(
probe_indices: &UInt32Array,
column_indices: &[ColumnIndex],
build_side: JoinSide,
+ join_type: JoinType,
) -> Result<RecordBatch> {
if schema.fields().is_empty() {
let options = RecordBatchOptions::new()
@@ -157,8 +161,12 @@ pub(crate) fn build_batch_from_indices(
for column_index in column_indices {
let array = if column_index.side == JoinSide::None {
- // LeftMark join, the mark column is a true if the indices is not
null, otherwise it will be false
- Arc::new(compute::is_not_null(probe_indices)?)
+ // For mark joins, the mark column is true if the index is not null, otherwise it is false
+ if join_type == JoinType::RightMark {
+ Arc::new(compute::is_not_null(build_indices)?)
+ } else {
+ Arc::new(compute::is_not_null(probe_indices)?)
+ }
} else if column_index.side == build_side {
let array = build_input_buffer.column(column_index.index);
if array.is_empty() || build_indices.null_count() ==
build_indices.len() {
@@ -168,7 +176,7 @@ pub(crate) fn build_batch_from_indices(
assert_eq!(build_indices.null_count(), build_indices.len());
new_null_array(array.data_type(), build_indices.len())
} else {
- compute::take(array.as_ref(), build_indices, None)?
+ take(array.as_ref(), build_indices, None)?
}
} else {
let array = probe_batch.column(column_index.index);
@@ -176,9 +184,10 @@ pub(crate) fn build_batch_from_indices(
assert_eq!(probe_indices.null_count(), probe_indices.len());
new_null_array(array.data_type(), probe_indices.len())
} else {
- compute::take(array.as_ref(), probe_indices, None)?
+ take(array.as_ref(), probe_indices, None)?
}
};
+
columns.push(array);
}
Ok(RecordBatch::try_new(Arc::new(schema.clone()), columns)?)
@@ -226,7 +235,12 @@ pub(crate) fn adjust_indices_by_join_type(
// the left_indices will not be used later for the `right anti`
join
Ok((left_indices, right_indices))
}
- JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark |
JoinType::RightMark => {
+ JoinType::RightMark => {
+ let new_left_indices = get_mark_indices(&adjust_range,
&right_indices);
+ let new_right_indices = adjust_range.map(|i| i as u32).collect();
+ Ok((new_left_indices, new_right_indices))
+ }
+ JoinType::LeftSemi | JoinType::LeftAnti | JoinType::LeftMark => {
// matched or unmatched left row will be produced in the end of
loop
// When visit the right batch, we can output the matched left row
and don't need to wait the end of loop
Ok((
@@ -328,17 +342,7 @@ pub(crate) fn get_anti_indices<T: ArrowPrimitiveType>(
where
NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
- let mut bitmap = BooleanBufferBuilder::new(range.len());
- bitmap.append_n(range.len(), false);
- input_indices
- .iter()
- .flatten()
- .map(|v| v.as_usize())
- .filter(|v| range.contains(v))
- .for_each(|v| {
- bitmap.set_bit(v - range.start, true);
- });
-
+ let bitmap = build_range_bitmap(&range, input_indices);
let offset = range.start;
// get the anti index
@@ -355,25 +359,52 @@ pub(crate) fn get_semi_indices<T: ArrowPrimitiveType>(
where
NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
{
- let mut bitmap = BooleanBufferBuilder::new(range.len());
- bitmap.append_n(range.len(), false);
- input_indices
- .iter()
- .flatten()
- .map(|v| v.as_usize())
- .filter(|v| range.contains(v))
- .for_each(|v| {
- bitmap.set_bit(v - range.start, true);
- });
-
+ let bitmap = build_range_bitmap(&range, input_indices);
let offset = range.start;
-
// get the semi index
(range)
.filter_map(|idx| (bitmap.get_bit(idx -
offset)).then_some(T::Native::from_usize(idx)))
.collect()
}
+/// Returns an array for mark joins consisting of default values (zeros) with
null/non-null markers.
+///
+/// For each index in `range`:
+/// - If the index appears in `input_indices`, the value is non-null (0)
+/// - If the index does not appear in `input_indices`, the value is null
+///
+/// This is used in mark joins to indicate which rows had matches.
+pub(crate) fn get_mark_indices<T: ArrowPrimitiveType, R: ArrowPrimitiveType>(
+ range: &Range<usize>,
+ input_indices: &PrimitiveArray<T>,
+) -> PrimitiveArray<R>
+where
+ NativeAdapter<T>: From<<T as ArrowPrimitiveType>::Native>,
+{
+ let mut bitmap = build_range_bitmap(range, input_indices);
+ PrimitiveArray::new(
+ vec![R::Native::default(); range.len()].into(),
+ Some(NullBuffer::new(bitmap.finish())),
+ )
+}
+
+fn build_range_bitmap<T: ArrowPrimitiveType>(
+ range: &Range<usize>,
+ input: &PrimitiveArray<T>,
+) -> BooleanBufferBuilder {
+ let mut builder = BooleanBufferBuilder::new(range.len());
+ builder.append_n(range.len(), false);
+
+ input.iter().flatten().for_each(|v| {
+ let idx = v.as_usize();
+ if range.contains(&idx) {
+ builder.set_bit(idx - range.start, true);
+ }
+ });
+
+ builder
+}
+
/// Appends probe indices in order by considering the given build indices.
///
/// This function constructs new build and probe indices by iterating through
@@ -432,23 +463,24 @@ pub(crate) fn asymmetric_join_output_partitioning(
left: &Arc<dyn ExecutionPlan>,
right: &Arc<dyn ExecutionPlan>,
join_type: &JoinType,
-) -> Partitioning {
- match join_type {
+) -> Result<Partitioning> {
+ let result = match join_type {
JoinType::Inner | JoinType::Right => adjust_right_output_partitioning(
right.output_partitioning(),
left.schema().fields().len(),
- )
- .unwrap_or_else(|_| Partitioning::UnknownPartitioning(1)),
- JoinType::RightSemi | JoinType::RightAnti =>
right.output_partitioning().clone(),
+ )?,
+ JoinType::RightSemi | JoinType::RightAnti | JoinType::RightMark => {
+ right.output_partitioning().clone()
+ }
JoinType::Left
| JoinType::LeftSemi
| JoinType::LeftAnti
| JoinType::Full
- | JoinType::LeftMark
- | JoinType::RightMark => {
+ | JoinType::LeftMark => {
Partitioning::UnknownPartitioning(right.output_partitioning().partition_count())
}
- }
+ };
+ Ok(result)
}
/// This function is copied from