This is an automated email from the ASF dual-hosted git repository.
mbutrovich pushed a commit to branch branch-52
in repository https://gitbox.apache.org/repos/asf/datafusion.git
The following commit(s) were added to refs/heads/branch-52 by this push:
new 19a0fcaa27 [branch-52] SortMergeJoin don't wait for all input before
emitting (#20699)
19a0fcaa27 is described below
commit 19a0fcaa276c86beda544c6e01c75f6e0639767e
Author: Matt Butrovich <[email protected]>
AuthorDate: Wed Mar 4 13:58:50 2026 -0500
[branch-52] SortMergeJoin don't wait for all input before emitting (#20699)
## Which issue does this PR close?
Backport of #20482 to branch-52.
## Rationale for this change
Cherry-pick fix and prerequisites so that SortMergeJoin emits output
incrementally instead of waiting for all input to
complete. This resolves OOM issues Comet is seeing with DataFusion 52.
## What changes are included in this PR?
Cherry-picks of the following commits from `main`:
1. #19614 — Extract sort-merge join filter logic into separate module
2. #20463 — Use zero-copy slice instead of take kernel in sort merge
join
3. #20482 — Fix SortMergeJoin to not wait for all input before emitting
## Are these changes tested?
Yes, covered by existing and new tests included in #20482.
## Are there any user-facing changes?
No.
---------
Co-authored-by: Liang-Chi Hsieh <[email protected]>
Co-authored-by: Claude Sonnet 4.5 <[email protected]>
Co-authored-by: Andy Grove <[email protected]>
Co-authored-by: Raz Luvaton <[email protected]>
---
.../src/joins/sort_merge_join/filter.rs | 595 ++++++++++++++++++++
.../physical-plan/src/joins/sort_merge_join/mod.rs | 1 +
.../src/joins/sort_merge_join/stream.rs | 607 +++++----------------
.../src/joins/sort_merge_join/tests.rs | 514 +++++++++++++++--
datafusion/physical-plan/src/test/exec.rs | 111 +++-
5 files changed, 1307 insertions(+), 521 deletions(-)
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs
b/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs
new file mode 100644
index 0000000000..d598442b65
--- /dev/null
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs
@@ -0,0 +1,595 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Filter handling for Sort-Merge Join
+//!
+//! This module encapsulates the complexity of join filter evaluation,
including:
+//! - Immediate filtering for INNER joins
+//! - Deferred filtering for outer/semi/anti/mark joins
+//! - Metadata tracking for grouping output rows by input row
+//! - Correcting filter masks to handle multiple matches per input row
+
+use std::sync::Arc;
+
+use arrow::array::{
+ Array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder, RecordBatch,
+ UInt64Array, UInt64Builder,
+};
+use arrow::compute::{self, concat_batches, filter_record_batch};
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{JoinSide, JoinType, Result};
+
+use crate::joins::utils::JoinFilter;
+
+/// Metadata for tracking filter results during deferred filtering
+///
+/// When a join filter is present and we need to ensure each input row produces
+/// at least one output (outer joins) or exactly one output (semi joins), we
can't
+/// filter immediately. Instead, we accumulate all joined rows with metadata,
+/// then post-process to determine which rows to output.
+#[derive(Debug)]
+pub struct FilterMetadata {
+ /// Did each output row pass the join filter?
+ /// Used to detect if an input row found ANY match
+ pub filter_mask: BooleanBuilder,
+
+ /// Which input row (within batch) produced each output row?
+ /// Used for grouping output rows by input row
+ pub row_indices: UInt64Builder,
+
+ /// Which input batch did each output row come from?
+ /// Used to disambiguate row_indices across multiple batches
+ pub batch_ids: Vec<usize>,
+}
+
+impl FilterMetadata {
+ /// Create new empty filter metadata
+ pub fn new() -> Self {
+ Self {
+ filter_mask: BooleanBuilder::new(),
+ row_indices: UInt64Builder::new(),
+ batch_ids: vec![],
+ }
+ }
+
+ /// Returns (row_indices, filter_mask, batch_ids_ref) and clears builders
+ pub fn finish_metadata(&mut self) -> (UInt64Array, BooleanArray, &[usize])
{
+ let row_indices = self.row_indices.finish();
+ let filter_mask = self.filter_mask.finish();
+ (row_indices, filter_mask, &self.batch_ids)
+ }
+
+ /// Add metadata for null-joined rows (no filter applied)
+ pub fn append_nulls(&mut self, num_rows: usize) {
+ self.filter_mask.append_nulls(num_rows);
+ self.row_indices.append_nulls(num_rows);
+ self.batch_ids.resize(
+ self.batch_ids.len() + num_rows,
+ 0, // batch_id = 0 for null-joined rows
+ );
+ }
+
+ /// Add metadata for filtered rows
+ pub fn append_filter_metadata(
+ &mut self,
+ row_indices: &UInt64Array,
+ filter_mask: &BooleanArray,
+ batch_id: usize,
+ ) {
+ debug_assert_eq!(
+ row_indices.len(),
+ filter_mask.len(),
+ "row_indices and filter_mask must have same length"
+ );
+
+ for i in 0..row_indices.len() {
+ if filter_mask.is_null(i) {
+ self.filter_mask.append_null();
+ } else if filter_mask.value(i) {
+ self.filter_mask.append_value(true);
+ } else {
+ self.filter_mask.append_value(false);
+ }
+
+ if row_indices.is_null(i) {
+ self.row_indices.append_null();
+ } else {
+ self.row_indices.append_value(row_indices.value(i));
+ }
+
+ self.batch_ids.push(batch_id);
+ }
+ }
+
+ /// Verify that metadata arrays are aligned (same length)
+ pub fn debug_assert_metadata_aligned(&self) {
+ if self.filter_mask.len() > 0 {
+ debug_assert_eq!(
+ self.filter_mask.len(),
+ self.row_indices.len(),
+ "filter_mask and row_indices must have same length when
metadata is used"
+ );
+ debug_assert_eq!(
+ self.filter_mask.len(),
+ self.batch_ids.len(),
+ "filter_mask and batch_ids must have same length when metadata
is used"
+ );
+ } else {
+ debug_assert_eq!(
+ self.filter_mask.len(),
+ 0,
+ "filter_mask should be empty when batches is empty"
+ );
+ }
+ }
+}
+
+impl Default for FilterMetadata {
+ fn default() -> Self {
+ Self::new()
+ }
+}
+
+/// Determines if a join type needs deferred filtering
+///
+/// Deferred filtering is required when:
+/// - A filter exists AND
+/// - The join type requires ensuring each input row produces at least one
output
+/// (or exactly one for semi joins)
+pub fn needs_deferred_filtering(
+ filter: &Option<JoinFilter>,
+ join_type: JoinType,
+) -> bool {
+ filter.is_some()
+ && matches!(
+ join_type,
+ JoinType::Left
+ | JoinType::LeftSemi
+ | JoinType::LeftMark
+ | JoinType::Right
+ | JoinType::RightSemi
+ | JoinType::RightMark
+ | JoinType::LeftAnti
+ | JoinType::RightAnti
+ | JoinType::Full
+ )
+}
+
+/// Gets the arrays which join filters are applied on
+///
+/// Extracts the columns needed for filter evaluation from left and right
batch columns
+pub fn get_filter_columns(
+ join_filter: &Option<JoinFilter>,
+ left_columns: &[ArrayRef],
+ right_columns: &[ArrayRef],
+) -> Vec<ArrayRef> {
+ let mut filter_columns = vec![];
+
+ if let Some(f) = join_filter {
+ let left_columns: Vec<ArrayRef> = f
+ .column_indices()
+ .iter()
+ .filter(|col_index| col_index.side == JoinSide::Left)
+ .map(|i| Arc::clone(&left_columns[i.index]))
+ .collect();
+ let right_columns: Vec<ArrayRef> = f
+ .column_indices()
+ .iter()
+ .filter(|col_index| col_index.side == JoinSide::Right)
+ .map(|i| Arc::clone(&right_columns[i.index]))
+ .collect();
+
+ filter_columns.extend(left_columns);
+ filter_columns.extend(right_columns);
+ }
+
+ filter_columns
+}
+
+/// Determines if current index is the last occurrence of a row
+///
+/// Used during filter mask correction to detect row boundaries when grouping
+/// output rows by input row.
+fn last_index_for_row(
+ row_index: usize,
+ indices: &UInt64Array,
+ batch_ids: &[usize],
+ indices_len: usize,
+) -> bool {
+ debug_assert_eq!(
+ indices.len(),
+ indices_len,
+ "indices.len() should match indices_len parameter"
+ );
+ debug_assert_eq!(
+ batch_ids.len(),
+ indices_len,
+ "batch_ids.len() should match indices_len"
+ );
+ debug_assert!(
+ row_index < indices_len,
+ "row_index {row_index} should be < indices_len {indices_len}",
+ );
+
+ // If this is the last index overall, it's definitely the last for this row
+ if row_index == indices_len - 1 {
+ return true;
+ }
+
+ // Check if next row has different (batch_id, index) pair
+ let current_batch_id = batch_ids[row_index];
+ let next_batch_id = batch_ids[row_index + 1];
+
+ if current_batch_id != next_batch_id {
+ return true;
+ }
+
+ // Same batch_id, check if row index is different
+ // Both current and next should be non-null (already joined rows)
+ if indices.is_null(row_index) || indices.is_null(row_index + 1) {
+ return true;
+ }
+
+ indices.value(row_index) != indices.value(row_index + 1)
+}
+
+/// Corrects the filter mask for joins with deferred filtering
+///
+/// When an input row joins with multiple buffered rows, we get multiple
output rows.
+/// This function groups them by input row and applies join-type-specific
logic:
+///
+/// - **Outer joins**: Keep first matching row, convert rest to nulls, add
null-joined for unmatched
+/// - **Semi joins**: Keep first matching row, discard rest
+/// - **Anti joins**: Keep row only if NO matches passed filter
+/// - **Mark joins**: Like semi but first match only
+///
+/// # Arguments
+/// * `join_type` - The type of join being performed
+/// * `row_indices` - Which input row produced each output row
+/// * `batch_ids` - Which batch each output row came from
+/// * `filter_mask` - Whether each output row passed the filter
+/// * `expected_size` - Total number of input rows (for adding unmatched)
+///
+/// # Returns
+/// Corrected mask indicating which rows to include in final output:
+/// - `true`: Include this row
+/// - `false`: Convert to null-joined row (outer joins) or include as
unmatched (anti joins)
+/// - `null`: Discard this row
+pub fn get_corrected_filter_mask(
+ join_type: JoinType,
+ row_indices: &UInt64Array,
+ batch_ids: &[usize],
+ filter_mask: &BooleanArray,
+ expected_size: usize,
+) -> Option<BooleanArray> {
+ let row_indices_length = row_indices.len();
+ let mut corrected_mask: BooleanBuilder =
+ BooleanBuilder::with_capacity(row_indices_length);
+ let mut seen_true = false;
+
+ match join_type {
+ JoinType::Left | JoinType::Right => {
+ // For outer joins: Keep first matching row per input row,
+ // convert rest to nulls, add null-joined rows for unmatched
+ for i in 0..row_indices_length {
+ let last_index =
+ last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
+ if filter_mask.value(i) {
+ seen_true = true;
+ corrected_mask.append_value(true);
+ } else if seen_true || !filter_mask.value(i) && !last_index {
+ corrected_mask.append_null(); // to be ignored and not set
to output
+ } else {
+ corrected_mask.append_value(false); // to be converted to
null joined row
+ }
+
+ if last_index {
+ seen_true = false;
+ }
+ }
+
+ // Generate null joined rows for records which have no matching
join key
+ corrected_mask.append_n(expected_size - corrected_mask.len(),
false);
+ Some(corrected_mask.finish())
+ }
+ JoinType::LeftMark | JoinType::RightMark => {
+ // For mark joins: Like outer but only keep first match, mark with
boolean
+ for i in 0..row_indices_length {
+ let last_index =
+ last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
+ if filter_mask.value(i) && !seen_true {
+ seen_true = true;
+ corrected_mask.append_value(true);
+ } else if seen_true || !filter_mask.value(i) && !last_index {
+ corrected_mask.append_null(); // to be ignored and not set
to output
+ } else {
+ corrected_mask.append_value(false); // to be converted to
null joined row
+ }
+
+ if last_index {
+ seen_true = false;
+ }
+ }
+
+ // Generate null joined rows for records which have no matching
join key
+ corrected_mask.append_n(expected_size - corrected_mask.len(),
false);
+ Some(corrected_mask.finish())
+ }
+ JoinType::LeftSemi | JoinType::RightSemi => {
+ // For semi joins: Keep only first matching row per input row,
discard rest
+ for i in 0..row_indices_length {
+ let last_index =
+ last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
+ if filter_mask.value(i) && !seen_true {
+ seen_true = true;
+ corrected_mask.append_value(true);
+ } else {
+ corrected_mask.append_null(); // to be ignored and not set
to output
+ }
+
+ if last_index {
+ seen_true = false;
+ }
+ }
+
+ Some(corrected_mask.finish())
+ }
+ JoinType::LeftAnti | JoinType::RightAnti => {
+ // For anti joins: Keep row only if NO matches passed the filter
+ for i in 0..row_indices_length {
+ let last_index =
+ last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
+
+ if filter_mask.value(i) {
+ seen_true = true;
+ }
+
+ if last_index {
+ if !seen_true {
+ corrected_mask.append_value(true);
+ } else {
+ corrected_mask.append_null();
+ }
+
+ seen_true = false;
+ } else {
+ corrected_mask.append_null();
+ }
+ }
+ // Generate null joined rows for records which have no matching
join key,
+ // for LeftAnti non-matched considered as true
+ corrected_mask.append_n(expected_size - corrected_mask.len(),
true);
+ Some(corrected_mask.finish())
+ }
+ JoinType::Full => {
+ // For full joins: Similar to outer but handle both sides
+ for i in 0..row_indices_length {
+ let last_index =
+ last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
+
+ if filter_mask.is_null(i) {
+ // null joined
+ corrected_mask.append_value(true);
+ } else if filter_mask.value(i) {
+ seen_true = true;
+ corrected_mask.append_value(true);
+ } else if seen_true || !filter_mask.value(i) && !last_index {
+ corrected_mask.append_null(); // to be ignored and not set
to output
+ } else {
+ corrected_mask.append_value(false); // to be converted to
null joined row
+ }
+
+ if last_index {
+ seen_true = false;
+ }
+ }
+ // Generate null joined rows for records which have no matching
join key
+ corrected_mask.append_n(expected_size - corrected_mask.len(),
false);
+ Some(corrected_mask.finish())
+ }
+ JoinType::Inner => {
+ // Inner joins don't need deferred filtering
+ None
+ }
+ }
+}
+
+/// Applies corrected filter mask to record batch based on join type
+///
+/// Different join types require different handling of filtered results:
+/// - Outer joins: Add null-joined rows for false mask values
+/// - Semi/Anti joins: May need projection to remove right columns
+/// - Full joins: Add null-joined rows for both sides
+pub fn filter_record_batch_by_join_type(
+ record_batch: &RecordBatch,
+ corrected_mask: &BooleanArray,
+ join_type: JoinType,
+ schema: &SchemaRef,
+ streamed_schema: &SchemaRef,
+ buffered_schema: &SchemaRef,
+) -> Result<RecordBatch> {
+ let filtered_record_batch = filter_record_batch(record_batch,
corrected_mask)?;
+
+ match join_type {
+ JoinType::Left | JoinType::LeftMark => {
+ // For left joins, add null-joined rows where mask is false
+ let null_mask = compute::not(corrected_mask)?;
+ let null_joined_batch = filter_record_batch(record_batch,
&null_mask)?;
+
+ if null_joined_batch.num_rows() == 0 {
+ return Ok(filtered_record_batch);
+ }
+
+ // Create null columns for right side
+ let null_joined_streamed_batch = create_null_joined_batch(
+ &null_joined_batch,
+ buffered_schema,
+ JoinSide::Left,
+ join_type,
+ schema,
+ )?;
+
+ Ok(concat_batches(
+ schema,
+ &[filtered_record_batch, null_joined_streamed_batch],
+ )?)
+ }
+ JoinType::LeftSemi
+ | JoinType::LeftAnti
+ | JoinType::RightSemi
+ | JoinType::RightAnti => {
+ // For semi/anti joins, project to only include the outer side
columns
+ // Both Left and Right semi/anti use streamed_schema.len() because:
+ // - For Left: columns are [left, right], so we take first
streamed_schema.len()
+ // - For Right: columns are [right, left], and streamed side is
right, so we take first streamed_schema.len()
+ let output_column_indices: Vec<usize> =
+ (0..streamed_schema.fields().len()).collect();
+ Ok(filtered_record_batch.project(&output_column_indices)?)
+ }
+ JoinType::Right | JoinType::RightMark => {
+ // For right joins, add null-joined rows where mask is false
+ let null_mask = compute::not(corrected_mask)?;
+ let null_joined_batch = filter_record_batch(record_batch,
&null_mask)?;
+
+ if null_joined_batch.num_rows() == 0 {
+ return Ok(filtered_record_batch);
+ }
+
+ // Create null columns for left side (buffered side for RIGHT join)
+ let null_joined_buffered_batch = create_null_joined_batch(
+ &null_joined_batch,
+ buffered_schema, // Pass buffered (left) schema to create
nulls for it
+ JoinSide::Right,
+ join_type,
+ schema,
+ )?;
+
+ Ok(concat_batches(
+ schema,
+ &[filtered_record_batch, null_joined_buffered_batch],
+ )?)
+ }
+ JoinType::Full => {
+ // For full joins, add null-joined rows for both sides
+ let joined_filter_not_matched_mask = compute::not(corrected_mask)?;
+ let joined_filter_not_matched_batch =
+ filter_record_batch(record_batch,
&joined_filter_not_matched_mask)?;
+
+ if joined_filter_not_matched_batch.num_rows() == 0 {
+ return Ok(filtered_record_batch);
+ }
+
+ // Create null-joined batches for both sides
+ let left_null_joined_batch = create_null_joined_batch(
+ &joined_filter_not_matched_batch,
+ buffered_schema,
+ JoinSide::Left,
+ join_type,
+ schema,
+ )?;
+
+ Ok(concat_batches(
+ schema,
+ &[filtered_record_batch, left_null_joined_batch],
+ )?)
+ }
+ JoinType::Inner => Ok(filtered_record_batch),
+ }
+}
+
+/// Creates a batch with null columns for the non-joined side
+///
+/// Note: The input `batch` is assumed to be a fully-joined batch that already
contains
+/// columns from both sides. We need to extract the data side columns and
replace the
+/// null side columns with actual nulls.
+fn create_null_joined_batch(
+ batch: &RecordBatch,
+ null_schema: &SchemaRef,
+ join_side: JoinSide,
+ join_type: JoinType,
+ output_schema: &SchemaRef,
+) -> Result<RecordBatch> {
+ let num_rows = batch.num_rows();
+
+ // The input batch is a fully-joined batch [left_cols..., right_cols...]
+ // We need to extract the appropriate side and replace the other with
nulls (or mark column)
+ let columns = match (join_side, join_type) {
+ (JoinSide::Left, JoinType::LeftMark) => {
+ // For LEFT mark: output is [left_cols..., mark_col]
+ // Batch is [left_cols..., right_cols...], extract left from
beginning
+ // Number of left columns = output columns - 1 (mark column)
+ let left_col_count = output_schema.fields().len() - 1;
+ let mut result: Vec<ArrayRef> =
batch.columns()[..left_col_count].to_vec();
+ result.push(Arc::new(BooleanArray::from(vec![false; num_rows])) as
ArrayRef);
+ result
+ }
+ (JoinSide::Right, JoinType::RightMark) => {
+ // For RIGHT mark: output is [right_cols..., mark_col]
+ // For RIGHT joins, batch is [right_cols..., left_cols...] (right
comes first!)
+ // Extract right columns from the beginning
+ let right_col_count = output_schema.fields().len() - 1; // -1 for
mark column
+ let mut result: Vec<ArrayRef> =
batch.columns()[..right_col_count].to_vec();
+ result.push(Arc::new(BooleanArray::from(vec![false; num_rows])) as
ArrayRef);
+ result
+ }
+ (JoinSide::Left, _) => {
+ // For LEFT join: output is [left_cols..., right_cols...]
+ // Extract left columns, then add null right columns
+ let null_columns: Vec<ArrayRef> = null_schema
+ .fields()
+ .iter()
+ .map(|field| arrow::array::new_null_array(field.data_type(),
num_rows))
+ .collect();
+ let left_col_count = output_schema.fields().len() -
null_columns.len();
+ let mut result: Vec<ArrayRef> =
batch.columns()[..left_col_count].to_vec();
+ result.extend(null_columns);
+ result
+ }
+ (JoinSide::Right, _) => {
+ // For RIGHT join: batch is [left_cols..., right_cols...] (same as
schema)
+ // We want: [null_left..., actual_right...]
+ // Extract left columns from beginning, replace with nulls, keep
right columns
+ let null_columns: Vec<ArrayRef> = null_schema
+ .fields()
+ .iter()
+ .map(|field| arrow::array::new_null_array(field.data_type(),
num_rows))
+ .collect();
+ let left_col_count = null_columns.len();
+ let mut result = null_columns;
+ // Extract right columns starting after left columns
+ result.extend_from_slice(&batch.columns()[left_col_count..]);
+ result
+ }
+ (JoinSide::None, _) => {
+ // This should not happen in normal join operations
+ unreachable!(
+ "JoinSide::None should not be used in null-joined batch
creation"
+ )
+ }
+ };
+
+ // Create the batch - don't validate nullability since outer joins can have
+ // null values in columns that were originally non-nullable
+ use arrow::array::RecordBatchOptions;
+ let mut options = RecordBatchOptions::new();
+ options = options.with_row_count(Some(num_rows));
+ Ok(RecordBatch::try_new_with_options(
+ Arc::clone(output_schema),
+ columns,
+ &options,
+ )?)
+}
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs
b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs
index 82f18e7414..06290ec4d0 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs
@@ -20,6 +20,7 @@
pub use exec::SortMergeJoinExec;
mod exec;
+mod filter;
mod metrics;
mod stream;
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
index b36992caf4..3a57dc6b41 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
@@ -33,6 +33,10 @@ use std::sync::atomic::AtomicUsize;
use std::sync::atomic::Ordering::Relaxed;
use std::task::{Context, Poll};
+use crate::joins::sort_merge_join::filter::{
+ FilterMetadata, filter_record_batch_by_join_type,
get_corrected_filter_mask,
+ get_filter_columns, needs_deferred_filtering,
+};
use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
use crate::joins::utils::{JoinFilter, compare_join_arrays};
use crate::metrics::RecordOutput;
@@ -42,15 +46,13 @@ use crate::{PhysicalExpr, RecordBatchStream,
SendableRecordBatchStream};
use arrow::array::{types::UInt64Type, *};
use arrow::compute::{
self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch,
is_not_null,
- take,
+ take, take_arrays,
};
use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
-use arrow::error::ArrowError;
use arrow::ipc::reader::StreamReader;
use datafusion_common::config::SpillCompression;
use datafusion_common::{
- DataFusionError, HashSet, JoinSide, JoinType, NullEquality, Result,
exec_err,
- internal_err, not_impl_err,
+ HashSet, JoinType, NullEquality, Result, exec_err, internal_err,
not_impl_err,
};
use datafusion_execution::disk_manager::RefCountedTempFile;
use datafusion_execution::memory_pool::MemoryReservation;
@@ -68,6 +70,8 @@ pub(super) enum SortMergeJoinState {
Polling,
/// Joining polled data and making output
JoinOutput,
+ /// Emit ready data if have any and then go back to [`Self::Init`] state
+ EmitReadyThenInit,
/// No more output
Exhausted,
}
@@ -370,12 +374,8 @@ pub(super) struct SortMergeJoinStream {
pub(super) struct JoinedRecordBatches {
/// Joined batches. Each batch is already joined columns from left and
right sources
pub(super) joined_batches: BatchCoalescer,
- /// Did each output row pass the join filter? (detect if input row found
any match)
- pub(super) filter_mask: BooleanBuilder,
- /// Which input row (within batch) produced each output row? (for grouping
by input row)
- pub(super) row_indices: UInt64Builder,
- /// Which input batch did each output row come from? (disambiguate
row_indices)
- pub(super) batch_ids: Vec<usize>,
+ /// Filter metadata for deferred filtering
+ pub(super) filter_metadata: FilterMetadata,
}
impl JoinedRecordBatches {
@@ -398,61 +398,28 @@ impl JoinedRecordBatches {
}
}
- /// Finishes and returns the metadata arrays, clearing the builders
- ///
- /// Returns (row_indices, filter_mask, batch_ids_ref)
- /// Note: batch_ids is returned as a reference since it's still needed in
the struct
- fn finish_metadata(&mut self) -> (UInt64Array, BooleanArray, &[usize]) {
- let row_indices = self.row_indices.finish();
- let filter_mask = self.filter_mask.finish();
- (row_indices, filter_mask, &self.batch_ids)
- }
-
/// Clears batches without touching metadata (for early return when no
filtering needed)
fn clear_batches(&mut self, schema: &SchemaRef, batch_size: usize) {
self.joined_batches = BatchCoalescer::new(Arc::clone(schema),
batch_size)
.with_biggest_coalesce_batch_size(Option::from(batch_size / 2));
}
- /// Asserts that internal metadata arrays are consistent with each other
- /// Only checks if metadata is actually being used (i.e., not all empty)
- #[inline]
- fn debug_assert_metadata_aligned(&self) {
- // Metadata arrays should be aligned IF they're being used
- // (For non-filtered joins, they may all be empty)
- if self.filter_mask.len() > 0
- || self.row_indices.len() > 0
- || !self.batch_ids.is_empty()
- {
- debug_assert_eq!(
- self.filter_mask.len(),
- self.row_indices.len(),
- "filter_mask and row_indices must have same length when
metadata is used"
- );
- debug_assert_eq!(
- self.filter_mask.len(),
- self.batch_ids.len(),
- "filter_mask and batch_ids must have same length when metadata
is used"
- );
- }
- }
-
/// Asserts that if batches is empty, metadata is also empty
#[inline]
fn debug_assert_empty_consistency(&self) {
if self.joined_batches.is_empty() {
debug_assert_eq!(
- self.filter_mask.len(),
+ self.filter_metadata.filter_mask.len(),
0,
"filter_mask should be empty when batches is empty"
);
debug_assert_eq!(
- self.row_indices.len(),
+ self.filter_metadata.row_indices.len(),
0,
"row_indices should be empty when batches is empty"
);
debug_assert_eq!(
- self.batch_ids.len(),
+ self.filter_metadata.batch_ids.len(),
0,
"batch_ids should be empty when batches is empty"
);
@@ -473,14 +440,9 @@ impl JoinedRecordBatches {
let num_rows = batch.num_rows();
- self.filter_mask.append_nulls(num_rows);
- self.row_indices.append_nulls(num_rows);
- self.batch_ids.resize(
- self.batch_ids.len() + num_rows,
- 0, // batch_id = 0 for null-joined rows
- );
+ self.filter_metadata.append_nulls(num_rows);
- self.debug_assert_metadata_aligned();
+ self.filter_metadata.debug_assert_metadata_aligned();
self.joined_batches
.push_batch(batch)
.expect("Failed to push batch to BatchCoalescer");
@@ -525,13 +487,13 @@ impl JoinedRecordBatches {
"row_indices and filter_mask must have same length"
);
- // For Full joins, we keep the pre_mask (with nulls), for others we
keep the cleaned mask
- self.filter_mask.extend(filter_mask);
- self.row_indices.extend(row_indices);
- self.batch_ids
- .resize(self.batch_ids.len() + row_indices.len(),
streamed_batch_id);
+ self.filter_metadata.append_filter_metadata(
+ row_indices,
+ filter_mask,
+ streamed_batch_id,
+ );
- self.debug_assert_metadata_aligned();
+ self.filter_metadata.debug_assert_metadata_aligned();
self.joined_batches
.push_batch(batch)
.expect("Failed to push batch to BatchCoalescer");
@@ -551,9 +513,7 @@ impl JoinedRecordBatches {
fn clear(&mut self, schema: &SchemaRef, batch_size: usize) {
self.joined_batches = BatchCoalescer::new(Arc::clone(schema),
batch_size)
.with_biggest_coalesce_batch_size(Option::from(batch_size / 2));
- self.batch_ids.clear();
- self.filter_mask = BooleanBuilder::new();
- self.row_indices = UInt64Builder::new();
+ self.filter_metadata = FilterMetadata::new();
self.debug_assert_empty_consistency();
}
}
@@ -563,199 +523,6 @@ impl RecordBatchStream for SortMergeJoinStream {
}
}
-/// True if next index refers to either:
-/// - another batch id
-/// - another row index within same batch id
-/// - end of row indices
-#[inline(always)]
-fn last_index_for_row(
- row_index: usize,
- indices: &UInt64Array,
- batch_ids: &[usize],
- indices_len: usize,
-) -> bool {
- debug_assert_eq!(
- indices.len(),
- indices_len,
- "indices.len() should match indices_len parameter"
- );
- debug_assert_eq!(
- batch_ids.len(),
- indices_len,
- "batch_ids.len() should match indices_len"
- );
- debug_assert!(
- row_index < indices_len,
- "row_index {row_index} should be < indices_len {indices_len}",
- );
-
- row_index == indices_len - 1
- || batch_ids[row_index] != batch_ids[row_index + 1]
- || indices.value(row_index) != indices.value(row_index + 1)
-}
-
-// Returns a corrected boolean bitmask for the given join type
-// Values in the corrected bitmask can be: true, false, null
-// `true` - the row found its match and sent to the output
-// `null` - the row ignored, no output
-// `false` - the row sent as NULL joined row
-pub(super) fn get_corrected_filter_mask(
- join_type: JoinType,
- row_indices: &UInt64Array,
- batch_ids: &[usize],
- filter_mask: &BooleanArray,
- expected_size: usize,
-) -> Option<BooleanArray> {
- let row_indices_length = row_indices.len();
- let mut corrected_mask: BooleanBuilder =
- BooleanBuilder::with_capacity(row_indices_length);
- let mut seen_true = false;
-
- match join_type {
- JoinType::Left | JoinType::Right => {
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
- if filter_mask.value(i) {
- seen_true = true;
- corrected_mask.append_value(true);
- } else if seen_true || !filter_mask.value(i) && !last_index {
- corrected_mask.append_null(); // to be ignored and not set
to output
- } else {
- corrected_mask.append_value(false); // to be converted to
null joined row
- }
-
- if last_index {
- seen_true = false;
- }
- }
-
- // Generate null joined rows for records which have no matching
join key
- corrected_mask.append_n(expected_size - corrected_mask.len(),
false);
- Some(corrected_mask.finish())
- }
- JoinType::LeftMark | JoinType::RightMark => {
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
- if filter_mask.value(i) && !seen_true {
- seen_true = true;
- corrected_mask.append_value(true);
- } else if seen_true || !filter_mask.value(i) && !last_index {
- corrected_mask.append_null(); // to be ignored and not set
to output
- } else {
- corrected_mask.append_value(false); // to be converted to
null joined row
- }
-
- if last_index {
- seen_true = false;
- }
- }
-
- // Generate null joined rows for records which have no matching
join key
- corrected_mask.append_n(expected_size - corrected_mask.len(),
false);
- Some(corrected_mask.finish())
- }
- JoinType::LeftSemi | JoinType::RightSemi => {
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
- if filter_mask.value(i) && !seen_true {
- seen_true = true;
- corrected_mask.append_value(true);
- } else {
- corrected_mask.append_null(); // to be ignored and not set
to output
- }
-
- if last_index {
- seen_true = false;
- }
- }
-
- Some(corrected_mask.finish())
- }
- JoinType::LeftAnti | JoinType::RightAnti => {
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
-
- if filter_mask.value(i) {
- seen_true = true;
- }
-
- if last_index {
- if !seen_true {
- corrected_mask.append_value(true);
- } else {
- corrected_mask.append_null();
- }
-
- seen_true = false;
- } else {
- corrected_mask.append_null();
- }
- }
- // Generate null joined rows for records which have no matching
join key,
- // for LeftAnti non-matched considered as true
- corrected_mask.append_n(expected_size - corrected_mask.len(),
true);
- Some(corrected_mask.finish())
- }
- JoinType::Full => {
- let mut mask: Vec<Option<bool>> = vec![Some(true);
row_indices_length];
- let mut last_true_idx = 0;
- let mut first_row_idx = 0;
- let mut seen_false = false;
-
- for i in 0..row_indices_length {
- let last_index =
- last_index_for_row(i, row_indices, batch_ids,
row_indices_length);
- let val = filter_mask.value(i);
- let is_null = filter_mask.is_null(i);
-
- if val {
- // memoize the first seen matched row
- if !seen_true {
- last_true_idx = i;
- }
- seen_true = true;
- }
-
- if is_null || val {
- mask[i] = Some(true);
- } else if !is_null && !val && (seen_true || seen_false) {
- mask[i] = None;
- } else {
- mask[i] = Some(false);
- }
-
- if !is_null && !val {
- seen_false = true;
- }
-
- if last_index {
- // If the left row seen as true its needed to output it
once
- // To do that we mark all other matches for same row as
null to avoid the output
- if seen_true {
- #[expect(clippy::needless_range_loop)]
- for j in first_row_idx..last_true_idx {
- mask[j] = None;
- }
- }
-
- seen_true = false;
- seen_false = false;
- last_true_idx = 0;
- first_row_idx = i + 1;
- }
- }
-
- Some(BooleanArray::from(mask))
- }
- // Only outer joins needs to keep track of processed rows and apply
corrected filter mask
- _ => None,
- }
-}
-
impl Stream for SortMergeJoinStream {
type Item = Result<RecordBatch>;
@@ -778,7 +545,10 @@ impl Stream for SortMergeJoinStream {
match self.current_ordering {
Ordering::Less | Ordering::Equal => {
if !streamed_exhausted {
- if self.needs_deferred_filtering() {
+ if needs_deferred_filtering(
+ &self.filter,
+ self.join_type,
+ ) {
match self.process_filtered_batches()?
{
Poll::Ready(Some(batch)) => {
return
Poll::Ready(Some(Ok(batch)));
@@ -830,22 +600,56 @@ impl Stream for SortMergeJoinStream {
self.current_ordering = self.compare_streamed_buffered()?;
self.state = SortMergeJoinState::JoinOutput;
}
+ SortMergeJoinState::EmitReadyThenInit => {
+                    // If we have data to emit, emit it; once there is no more, move to the next state
+
+ // Verify metadata alignment before checking if we have
batches to output
+ self.joined_record_batches
+ .filter_metadata
+ .debug_assert_metadata_aligned();
+
+ // For filtered joins, skip output and let Init state
handle it
+ if needs_deferred_filtering(&self.filter, self.join_type) {
+ self.state = SortMergeJoinState::Init;
+ continue;
+ }
+
+ // For non-filtered joins, only output if we have a
completed batch
+ // (opportunistic output when target batch size is reached)
+ if self
+ .joined_record_batches
+ .joined_batches
+ .has_completed_batch()
+ {
+ let record_batch = self
+ .joined_record_batches
+ .joined_batches
+ .next_completed_batch()
+ .expect("has_completed_batch was true");
+ (&record_batch)
+
.record_output(&self.join_metrics.baseline_metrics());
+ return Poll::Ready(Some(Ok(record_batch)));
+ }
+ self.state = SortMergeJoinState::Init;
+ }
SortMergeJoinState::JoinOutput => {
self.join_partial()?;
if self.num_unfrozen_pairs() < self.batch_size {
if self.buffered_data.scanning_finished() {
self.buffered_data.scanning_reset();
- self.state = SortMergeJoinState::Init;
+ self.state = SortMergeJoinState::EmitReadyThenInit;
}
} else {
self.freeze_all()?;
// Verify metadata alignment before checking if we
have batches to output
-
self.joined_record_batches.debug_assert_metadata_aligned();
+ self.joined_record_batches
+ .filter_metadata
+ .debug_assert_metadata_aligned();
// For filtered joins, skip output and let Init state
handle it
- if self.needs_deferred_filtering() {
+ if needs_deferred_filtering(&self.filter,
self.join_type) {
continue;
}
@@ -872,10 +676,12 @@ impl Stream for SortMergeJoinStream {
self.freeze_all()?;
// Verify metadata alignment before final output
- self.joined_record_batches.debug_assert_metadata_aligned();
+ self.joined_record_batches
+ .filter_metadata
+ .debug_assert_metadata_aligned();
// For filtered joins, must concat and filter ALL data at
once
- if self.needs_deferred_filtering()
+ if needs_deferred_filtering(&self.filter, self.join_type)
&&
!self.joined_record_batches.joined_batches.is_empty()
{
let record_batch = self.filter_joined_batch()?;
@@ -975,9 +781,7 @@ impl SortMergeJoinStream {
joined_record_batches: JoinedRecordBatches {
joined_batches: BatchCoalescer::new(Arc::clone(&schema),
batch_size)
.with_biggest_coalesce_batch_size(Option::from(batch_size
/ 2)),
- filter_mask: BooleanBuilder::new(),
- row_indices: UInt64Builder::new(),
- batch_ids: vec![],
+ filter_metadata: FilterMetadata::new(),
},
output: BatchCoalescer::new(schema, batch_size)
.with_biggest_coalesce_batch_size(Option::from(batch_size /
2)),
@@ -996,26 +800,6 @@ impl SortMergeJoinStream {
self.streamed_batch.num_output_rows()
}
- /// Returns true if this join needs deferred filtering
- ///
- /// Deferred filtering is needed when a filter exists and the join type
requires
- /// ensuring each input row produces at least one output row (or exactly
one for semi).
- fn needs_deferred_filtering(&self) -> bool {
- self.filter.is_some()
- && matches!(
- self.join_type,
- JoinType::Left
- | JoinType::LeftSemi
- | JoinType::LeftMark
- | JoinType::Right
- | JoinType::RightSemi
- | JoinType::RightMark
- | JoinType::LeftAnti
- | JoinType::RightAnti
- | JoinType::Full
- )
- }
-
/// Process accumulated batches for filtered joins
///
/// Freezes unfrozen pairs, applies deferred filtering, and outputs if
ready.
@@ -1023,7 +807,9 @@ impl SortMergeJoinStream {
fn process_filtered_batches(&mut self) ->
Poll<Option<Result<RecordBatch>>> {
self.freeze_all()?;
- self.joined_record_batches.debug_assert_metadata_aligned();
+ self.joined_record_batches
+ .filter_metadata
+ .debug_assert_metadata_aligned();
if !self.joined_record_batches.joined_batches.is_empty() {
let out_filtered_batch = self.filter_joined_batch()?;
@@ -1399,7 +1185,9 @@ impl SortMergeJoinStream {
self.freeze_streamed()?;
// After freezing, metadata should be aligned
- self.joined_record_batches.debug_assert_metadata_aligned();
+ self.joined_record_batches
+ .filter_metadata
+ .debug_assert_metadata_aligned();
Ok(())
}
@@ -1414,7 +1202,9 @@ impl SortMergeJoinStream {
self.freeze_buffered(1)?;
// After freezing, metadata should be aligned
- self.joined_record_batches.debug_assert_metadata_aligned();
+ self.joined_record_batches
+ .filter_metadata
+ .debug_assert_metadata_aligned();
Ok(())
}
@@ -1490,13 +1280,19 @@ impl SortMergeJoinStream {
continue;
}
- let mut left_columns = self
- .streamed_batch
- .batch
- .columns()
- .iter()
- .map(|column| take(column, &left_indices, None))
- .collect::<Result<Vec<_>, ArrowError>>()?;
+ let mut left_columns = if let Some(range) =
is_contiguous_range(&left_indices)
+ {
+ // When indices form a contiguous range (common for the
streamed
+ // side which advances sequentially), use zero-copy slice
instead
+ // of the O(n) take kernel.
+ self.streamed_batch
+ .batch
+ .slice(range.start, range.len())
+ .columns()
+ .to_vec()
+ } else {
+ take_arrays(self.streamed_batch.batch.columns(),
&left_indices, None)?
+ };
// The row indices of joined buffered batch
let right_indices: UInt64Array = chunk.buffered_indices.finish();
@@ -1541,7 +1337,7 @@ impl SortMergeJoinStream {
&right_indices,
)?;
- get_filter_column(&self.filter, &left_columns,
&right_cols)
+ get_filter_columns(&self.filter, &left_columns,
&right_cols)
} else if matches!(
self.join_type,
JoinType::RightAnti | JoinType::RightSemi |
JoinType::RightMark
@@ -1552,12 +1348,12 @@ impl SortMergeJoinStream {
&right_indices,
)?;
- get_filter_column(&self.filter, &right_cols,
&left_columns)
+ get_filter_columns(&self.filter, &right_cols,
&left_columns)
} else {
- get_filter_column(&self.filter, &left_columns,
&right_columns)
+ get_filter_columns(&self.filter, &left_columns,
&right_columns)
}
} else {
- get_filter_column(&self.filter, &right_columns,
&left_columns)
+ get_filter_columns(&self.filter, &right_columns,
&left_columns)
}
} else {
// This chunk is totally for null joined rows (outer join), we
don't need to apply join filter.
@@ -1679,11 +1475,13 @@ impl SortMergeJoinStream {
fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
// Metadata should be aligned before processing
- self.joined_record_batches.debug_assert_metadata_aligned();
+ self.joined_record_batches
+ .filter_metadata
+ .debug_assert_metadata_aligned();
let record_batch =
self.joined_record_batches.concat_batches(&self.schema)?;
let (mut out_indices, mut out_mask, mut batch_ids) =
- self.joined_record_batches.finish_metadata();
+ self.joined_record_batches.filter_metadata.finish_metadata();
let default_batch_ids = vec![0; record_batch.num_rows()];
// If only nulls come in and indices sizes doesn't match with expected
record batch count
@@ -1754,139 +1552,14 @@ impl SortMergeJoinStream {
record_batch: &RecordBatch,
corrected_mask: &BooleanArray,
) -> Result<RecordBatch> {
- // Corrected mask should have length matching or exceeding
record_batch rows
- // (for outer joins it may be longer to include null-joined rows)
- debug_assert!(
- corrected_mask.len() >= record_batch.num_rows(),
- "corrected_mask length ({}) should be >= record_batch rows ({})",
- corrected_mask.len(),
- record_batch.num_rows()
- );
-
- let mut filtered_record_batch =
- filter_record_batch(record_batch, corrected_mask)?;
- let left_columns_length = self.streamed_schema.fields.len();
- let right_columns_length = self.buffered_schema.fields.len();
-
- if matches!(
+ let filtered_record_batch = filter_record_batch_by_join_type(
+ record_batch,
+ corrected_mask,
self.join_type,
- JoinType::Left | JoinType::LeftMark | JoinType::Right |
JoinType::RightMark
- ) {
- let null_mask = compute::not(corrected_mask)?;
- let null_joined_batch = filter_record_batch(record_batch,
&null_mask)?;
-
- let mut right_columns = create_unmatched_columns(
- self.join_type,
- &self.buffered_schema,
- null_joined_batch.num_rows(),
- );
-
- let columns = match self.join_type {
- JoinType::Right => {
- // The first columns are the right columns.
- let left_columns = null_joined_batch
- .columns()
- .iter()
- .skip(right_columns_length)
- .cloned()
- .collect::<Vec<_>>();
-
- right_columns.extend(left_columns);
- right_columns
- }
- JoinType::Left | JoinType::LeftMark | JoinType::RightMark => {
- // The first columns are the left columns.
- let mut left_columns = null_joined_batch
- .columns()
- .iter()
- .take(left_columns_length)
- .cloned()
- .collect::<Vec<_>>();
-
- left_columns.extend(right_columns);
- left_columns
- }
- _ => exec_err!("Did not expect join type {}", self.join_type)?,
- };
-
- // Push the streamed/buffered batch joined nulls to the output
- let null_joined_streamed_batch =
- RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
-
- filtered_record_batch = concat_batches(
- &self.schema,
- &[filtered_record_batch, null_joined_streamed_batch],
- )?;
- } else if matches!(
- self.join_type,
- JoinType::LeftSemi
- | JoinType::LeftAnti
- | JoinType::RightAnti
- | JoinType::RightSemi
- ) {
- let output_column_indices =
(0..left_columns_length).collect::<Vec<_>>();
- filtered_record_batch =
- filtered_record_batch.project(&output_column_indices)?;
- } else if matches!(self.join_type, JoinType::Full)
- && corrected_mask.false_count() > 0
- {
- // Find rows which joined by key but Filter predicate evaluated as
false
- let joined_filter_not_matched_mask = compute::not(corrected_mask)?;
- let joined_filter_not_matched_batch =
- filter_record_batch(record_batch,
&joined_filter_not_matched_mask)?;
-
- // Add left unmatched rows adding the right side as nulls
- let right_null_columns = self
- .buffered_schema
- .fields()
- .iter()
- .map(|f| {
- new_null_array(
- f.data_type(),
- joined_filter_not_matched_batch.num_rows(),
- )
- })
- .collect::<Vec<_>>();
-
- let mut result_joined = joined_filter_not_matched_batch
- .columns()
- .iter()
- .take(left_columns_length)
- .cloned()
- .collect::<Vec<_>>();
-
- result_joined.extend(right_null_columns);
-
- let left_null_joined_batch =
- RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?;
-
- // Add right unmatched rows adding the left side as nulls
- let mut result_joined = self
- .streamed_schema
- .fields()
- .iter()
- .map(|f| {
- new_null_array(
- f.data_type(),
- joined_filter_not_matched_batch.num_rows(),
- )
- })
- .collect::<Vec<_>>();
-
- let right_data = joined_filter_not_matched_batch
- .columns()
- .iter()
- .skip(left_columns_length)
- .cloned()
- .collect::<Vec<_>>();
-
- result_joined.extend(right_data);
-
- filtered_record_batch = concat_batches(
- &self.schema,
- &[filtered_record_batch, left_null_joined_batch],
- )?;
- }
+ &self.schema,
+ &self.streamed_schema,
+ &self.buffered_schema,
+ )?;
self.joined_record_batches
.clear(&self.schema, self.batch_size);
@@ -1911,36 +1584,6 @@ fn create_unmatched_columns(
}
}
-/// Gets the arrays which join filters are applied on.
-fn get_filter_column(
- join_filter: &Option<JoinFilter>,
- streamed_columns: &[ArrayRef],
- buffered_columns: &[ArrayRef],
-) -> Vec<ArrayRef> {
- let mut filter_columns = vec![];
-
- if let Some(f) = join_filter {
- let left_columns = f
- .column_indices()
- .iter()
- .filter(|col_index| col_index.side == JoinSide::Left)
- .map(|i| Arc::clone(&streamed_columns[i.index]))
- .collect::<Vec<_>>();
-
- let right_columns = f
- .column_indices()
- .iter()
- .filter(|col_index| col_index.side == JoinSide::Right)
- .map(|i| Arc::clone(&buffered_columns[i.index]))
- .collect::<Vec<_>>();
-
- filter_columns.extend(left_columns);
- filter_columns.extend(right_columns);
- }
-
- filter_columns
-}
-
fn produce_buffered_null_batch(
schema: &SchemaRef,
streamed_schema: &SchemaRef,
@@ -1970,6 +1613,30 @@ fn produce_buffered_null_batch(
)?))
}
+/// Checks if a `UInt64Array` contains a contiguous ascending range (e.g.
\[3,4,5,6\]).
+/// Returns `Some(start..start+len)` if so, `None` otherwise.
+/// This allows replacing an O(n) `take` with an O(1) `slice`.
+#[inline]
+fn is_contiguous_range(indices: &UInt64Array) -> Option<Range<usize>> {
+ if indices.is_empty() || indices.null_count() > 0 {
+ return None;
+ }
+ let values = indices.values();
+ let start = values[0];
+ let len = values.len() as u64;
+ // Quick rejection: if last element doesn't match expected, not contiguous
+ if values[values.len() - 1] != start + len - 1 {
+ return None;
+ }
+ // Verify every element is sequential (handles duplicates and gaps)
+ for i in 1..values.len() {
+ if values[i] != start + i as u64 {
+ return None;
+ }
+ }
+ Some(start as usize..(start + len) as usize)
+}
+
/// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by
specific column indices
#[inline(always)]
fn fetch_right_columns_by_idxs(
@@ -1990,12 +1657,16 @@ fn fetch_right_columns_from_batch_by_idxs(
) -> Result<Vec<ArrayRef>> {
match &buffered_batch.batch {
// In memory batch
- BufferedBatchState::InMemory(batch) => Ok(batch
- .columns()
- .iter()
- .map(|column| take(column, &buffered_indices, None))
- .collect::<Result<Vec<_>, ArrowError>>()
- .map_err(Into::<DataFusionError>::into)?),
+ // In memory batch
+ BufferedBatchState::InMemory(batch) => {
+ // When indices form a contiguous range (common in SMJ since the
+ // buffered side is scanned sequentially), use zero-copy slice.
+ if let Some(range) = is_contiguous_range(buffered_indices) {
+ Ok(batch.slice(range.start, range.len()).columns().to_vec())
+ } else {
+ Ok(take_arrays(batch.columns(), buffered_indices, None)?)
+ }
+ }
// If the batch was spilled to disk, less likely
BufferedBatchState::Spilled(spill_file) => {
let mut buffered_cols: Vec<ArrayRef> =
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
index d0bcc79636..4329abdd52 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
@@ -24,42 +24,44 @@
//!
//! Add relevant tests under the specified sections.
-use std::sync::Arc;
-
+use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
+use crate::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
+use crate::test::TestMemoryExec;
+use crate::test::exec::BarrierExec;
+use crate::test::{build_table_i32, build_table_i32_two_cols};
+use crate::{ExecutionPlan, common};
+use crate::{
+ expressions::Column,
joins::sort_merge_join::filter::get_corrected_filter_mask,
+ joins::sort_merge_join::stream::JoinedRecordBatches,
+};
use arrow::array::{
BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray,
Int32Array, RecordBatch, UInt64Array,
- builder::{BooleanBuilder, UInt64Builder},
};
use arrow::compute::{BatchCoalescer, SortOptions, filter_record_batch};
use arrow::datatypes::{DataType, Field, Schema};
-
+use arrow_ord::sort::SortColumn;
+use arrow_schema::SchemaRef;
use datafusion_common::JoinType::*;
use datafusion_common::{
- JoinSide,
+ JoinSide, internal_err,
test_util::{batches_to_sort_string, batches_to_string},
};
use datafusion_common::{
JoinType, NullEquality, Result, assert_batches_eq, assert_contains,
};
-use datafusion_execution::TaskContext;
+use datafusion_common_runtime::JoinSet;
use datafusion_execution::config::SessionConfig;
use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
use datafusion_expr::Operator;
use datafusion_physical_expr::expressions::BinaryExpr;
+use futures::StreamExt;
use insta::{allow_duplicates, assert_snapshot};
-
-use crate::{
- expressions::Column,
- joins::sort_merge_join::stream::{JoinedRecordBatches,
get_corrected_filter_mask},
-};
-
-use crate::joins::SortMergeJoinExec;
-use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
-use crate::test::TestMemoryExec;
-use crate::test::{build_table_i32, build_table_i32_two_cols};
-use crate::{ExecutionPlan, common};
+use itertools::Itertools;
+use std::sync::Arc;
+use std::task::Poll;
fn build_table(
a: (&str, &Vec<i32>),
@@ -2375,9 +2377,7 @@ fn build_joined_record_batches() ->
Result<JoinedRecordBatches> {
let mut batches = JoinedRecordBatches {
joined_batches: BatchCoalescer::new(Arc::clone(&schema), 8192),
- filter_mask: BooleanBuilder::new(),
- row_indices: UInt64Builder::new(),
- batch_ids: vec![],
+ filter_metadata:
crate::joins::sort_merge_join::filter::FilterMetadata::new(),
};
// Insert already prejoined non-filtered rows
@@ -2432,44 +2432,73 @@ fn build_joined_record_batches() ->
Result<JoinedRecordBatches> {
)?)?;
let streamed_indices = vec![0, 0];
- batches.batch_ids.extend(vec![0; streamed_indices.len()]);
batches
+ .filter_metadata
+ .batch_ids
+ .extend(vec![0; streamed_indices.len()]);
+ batches
+ .filter_metadata
.row_indices
.extend(&UInt64Array::from(streamed_indices));
let streamed_indices = vec![1];
- batches.batch_ids.extend(vec![0; streamed_indices.len()]);
batches
+ .filter_metadata
+ .batch_ids
+ .extend(vec![0; streamed_indices.len()]);
+ batches
+ .filter_metadata
.row_indices
.extend(&UInt64Array::from(streamed_indices));
let streamed_indices = vec![0, 0];
- batches.batch_ids.extend(vec![1; streamed_indices.len()]);
batches
+ .filter_metadata
+ .batch_ids
+ .extend(vec![1; streamed_indices.len()]);
+ batches
+ .filter_metadata
.row_indices
.extend(&UInt64Array::from(streamed_indices));
let streamed_indices = vec![0];
- batches.batch_ids.extend(vec![2; streamed_indices.len()]);
batches
+ .filter_metadata
+ .batch_ids
+ .extend(vec![2; streamed_indices.len()]);
+ batches
+ .filter_metadata
.row_indices
.extend(&UInt64Array::from(streamed_indices));
let streamed_indices = vec![0, 0];
- batches.batch_ids.extend(vec![3; streamed_indices.len()]);
batches
+ .filter_metadata
+ .batch_ids
+ .extend(vec![3; streamed_indices.len()]);
+ batches
+ .filter_metadata
.row_indices
.extend(&UInt64Array::from(streamed_indices));
batches
+ .filter_metadata
.filter_mask
.extend(&BooleanArray::from(vec![true, false]));
- batches.filter_mask.extend(&BooleanArray::from(vec![true]));
batches
+ .filter_metadata
+ .filter_mask
+ .extend(&BooleanArray::from(vec![true]));
+ batches
+ .filter_metadata
.filter_mask
.extend(&BooleanArray::from(vec![false, true]));
- batches.filter_mask.extend(&BooleanArray::from(vec![false]));
batches
+ .filter_metadata
+ .filter_mask
+ .extend(&BooleanArray::from(vec![false]));
+ batches
+ .filter_metadata
.filter_mask
.extend(&BooleanArray::from(vec![false, false]));
@@ -2482,8 +2511,8 @@ async fn test_left_outer_join_filtered_mask() ->
Result<()> {
let schema = joined_batches.joined_batches.schema();
let output = joined_batches.concat_batches(&schema)?;
- let out_mask = joined_batches.filter_mask.finish();
- let out_indices = joined_batches.row_indices.finish();
+ let out_mask = joined_batches.filter_metadata.filter_mask.finish();
+ let out_indices = joined_batches.filter_metadata.row_indices.finish();
assert_eq!(
get_corrected_filter_mask(
@@ -2620,7 +2649,7 @@ async fn test_left_outer_join_filtered_mask() ->
Result<()> {
let corrected_mask = get_corrected_filter_mask(
Left,
&out_indices,
- &joined_batches.batch_ids,
+ &joined_batches.filter_metadata.batch_ids,
&out_mask,
output.num_rows(),
)
@@ -2689,8 +2718,8 @@ async fn test_semi_join_filtered_mask() -> Result<()> {
let schema = joined_batches.joined_batches.schema();
let output = joined_batches.concat_batches(&schema)?;
- let out_mask = joined_batches.filter_mask.finish();
- let out_indices = joined_batches.row_indices.finish();
+ let out_mask = joined_batches.filter_metadata.filter_mask.finish();
+ let out_indices = joined_batches.filter_metadata.row_indices.finish();
assert_eq!(
get_corrected_filter_mask(
@@ -2791,7 +2820,7 @@ async fn test_semi_join_filtered_mask() -> Result<()> {
let corrected_mask = get_corrected_filter_mask(
join_type,
&out_indices,
- &joined_batches.batch_ids,
+ &joined_batches.filter_metadata.batch_ids,
&out_mask,
output.num_rows(),
)
@@ -2864,8 +2893,8 @@ async fn test_anti_join_filtered_mask() -> Result<()> {
let schema = joined_batches.joined_batches.schema();
let output = joined_batches.concat_batches(&schema)?;
- let out_mask = joined_batches.filter_mask.finish();
- let out_indices = joined_batches.row_indices.finish();
+ let out_mask = joined_batches.filter_metadata.filter_mask.finish();
+ let out_indices = joined_batches.filter_metadata.row_indices.finish();
assert_eq!(
get_corrected_filter_mask(
@@ -2966,7 +2995,7 @@ async fn test_anti_join_filtered_mask() -> Result<()> {
let corrected_mask = get_corrected_filter_mask(
join_type,
&out_indices,
- &joined_batches.batch_ids,
+ &joined_batches.filter_metadata.batch_ids,
&out_mask,
output.num_rows(),
)
@@ -3104,6 +3133,419 @@ fn test_partition_statistics() -> Result<()> {
Ok(())
}
+fn build_batches(
+ a: (&str, &[Vec<bool>]),
+ b: (&str, &[Vec<i32>]),
+ c: (&str, &[Vec<i32>]),
+) -> (Vec<RecordBatch>, SchemaRef) {
+ assert_eq!(a.1.len(), b.1.len());
+ let mut batches = vec![];
+
+ let schema = Arc::new(Schema::new(vec![
+ Field::new(a.0, DataType::Boolean, false),
+ Field::new(b.0, DataType::Int32, false),
+ Field::new(c.0, DataType::Int32, false),
+ ]));
+
+ for i in 0..a.1.len() {
+ batches.push(
+ RecordBatch::try_new(
+ Arc::clone(&schema),
+ vec![
+ Arc::new(BooleanArray::from(a.1[i].clone())),
+ Arc::new(Int32Array::from(b.1[i].clone())),
+ Arc::new(Int32Array::from(c.1[i].clone())),
+ ],
+ )
+ .unwrap(),
+ );
+ }
+ let schema = batches[0].schema();
+ (batches, schema)
+}
+
+fn build_batched_finish_barrier_table(
+ a: (&str, &[Vec<bool>]),
+ b: (&str, &[Vec<i32>]),
+ c: (&str, &[Vec<i32>]),
+) -> (Arc<BarrierExec>, Arc<TestMemoryExec>) {
+ let (batches, schema) = build_batches(a, b, c);
+
+ let memory_exec = TestMemoryExec::try_new_exec(
+ std::slice::from_ref(&batches),
+ Arc::clone(&schema),
+ None,
+ )
+ .unwrap();
+
+ let barrier_exec = Arc::new(
+ BarrierExec::new(vec![batches], schema)
+ .with_log(false)
+ .without_start_barrier()
+ .with_finish_barrier(),
+ );
+
+ (barrier_exec, memory_exec)
+}
+
+/// Concat and sort batches by all the columns to make sure we can compare
them with different join
+fn prepare_record_batches_for_cmp(output: Vec<RecordBatch>) -> RecordBatch {
+ let output_batch = arrow::compute::concat_batches(output[0].schema_ref(),
&output)
+ .expect("failed to concat batches");
+
+ // Sort on all columns to make sure we have a deterministic order for the
assertion
+ let sort_columns = output_batch
+ .columns()
+ .iter()
+ .map(|c| SortColumn {
+ values: Arc::clone(c),
+ options: None,
+ })
+ .collect::<Vec<_>>();
+
+ let sorted_columns =
+ arrow::compute::lexsort(&sort_columns, None).expect("failed to sort");
+
+ RecordBatch::try_new(output_batch.schema(), sorted_columns)
+ .expect("failed to create batch")
+}
+
+#[expect(clippy::too_many_arguments)]
+async fn join_get_stream_and_get_expected(
+ left: Arc<dyn ExecutionPlan>,
+ right: Arc<dyn ExecutionPlan>,
+ oracle_left: Arc<dyn ExecutionPlan>,
+ oracle_right: Arc<dyn ExecutionPlan>,
+ on: JoinOn,
+ join_type: JoinType,
+ filter: Option<JoinFilter>,
+ batch_size: usize,
+) -> Result<(SendableRecordBatchStream, RecordBatch)> {
+ let sort_options = vec![SortOptions::default(); on.len()];
+ let null_equality = NullEquality::NullEqualsNothing;
+ let task_ctx = Arc::new(
+ TaskContext::default()
+
.with_session_config(SessionConfig::default().with_batch_size(batch_size)),
+ );
+
+ let expected_output = {
+ let oracle = HashJoinExec::try_new(
+ oracle_left,
+ oracle_right,
+ on.clone(),
+ filter.clone(),
+ &join_type,
+ None,
+ PartitionMode::Partitioned,
+ null_equality,
+ )?;
+
+ let stream = oracle.execute(0, Arc::clone(&task_ctx))?;
+
+ let batches = common::collect(stream).await?;
+
+ prepare_record_batches_for_cmp(batches)
+ };
+
+ let join = SortMergeJoinExec::try_new(
+ left,
+ right,
+ on,
+ filter,
+ join_type,
+ sort_options,
+ null_equality,
+ )?;
+
+ let stream = join.execute(0, task_ctx)?;
+
+ Ok((stream, expected_output))
+}
+
+fn generate_data_for_emit_early_test(
+ batch_size: usize,
+ number_of_batches: usize,
+ join_type: JoinType,
+) -> (
+ Arc<BarrierExec>,
+ Arc<BarrierExec>,
+ Arc<TestMemoryExec>,
+ Arc<TestMemoryExec>,
+) {
+ let number_of_rows_per_batch = number_of_batches * batch_size;
+ // Prepare data
+ let left_a1 = (0..number_of_rows_per_batch as i32)
+ .chunks(batch_size)
+ .into_iter()
+ .map(|chunk| chunk.collect::<Vec<_>>())
+ .collect::<Vec<_>>();
+ let left_b1 = (0..1000000)
+ .filter(|item| {
+ match join_type {
+ LeftAnti | RightAnti => {
+ let remainder = item % (batch_size as i32);
+
+                    // Make sure to have one that matches and one that doesn't
+ remainder == 0 || remainder == 1
+ }
+ // Have at least 1 that is not matching
+ _ => item % batch_size as i32 != 0,
+ }
+ })
+ .take(number_of_rows_per_batch)
+ .chunks(batch_size)
+ .into_iter()
+ .map(|chunk| chunk.collect::<Vec<_>>())
+ .collect::<Vec<_>>();
+
+ let left_bool_col1 = left_a1
+ .clone()
+ .into_iter()
+ .map(|b| {
+ b.into_iter()
+                // Mostly true, but with some false values that do not overlap with the right column
+ .map(|a| a % (batch_size as i32) != (batch_size as i32) - 2)
+ .collect::<Vec<_>>()
+ })
+ .collect::<Vec<_>>();
+
+ let (left, left_memory) = build_batched_finish_barrier_table(
+ ("bool_col1", left_bool_col1.as_slice()),
+ ("b1", left_b1.as_slice()),
+ ("a1", left_a1.as_slice()),
+ );
+
+ let right_a2 = (0..number_of_rows_per_batch as i32)
+ .map(|item| item * 11)
+ .chunks(batch_size)
+ .into_iter()
+ .map(|chunk| chunk.collect::<Vec<_>>())
+ .collect::<Vec<_>>();
+ let right_b1 = (0..1000000)
+ .filter(|item| {
+ match join_type {
+ LeftAnti | RightAnti => {
+ let remainder = item % (batch_size as i32);
+
+                    // Make sure to have one that matches and one that doesn't
+ remainder == 1 || remainder == 2
+ }
+ // Have at least 1 that is not matching
+ _ => item % batch_size as i32 != 1,
+ }
+ })
+ .take(number_of_rows_per_batch)
+ .chunks(batch_size)
+ .into_iter()
+ .map(|chunk| chunk.collect::<Vec<_>>())
+ .collect::<Vec<_>>();
+ let right_bool_col2 = right_a2
+ .clone()
+ .into_iter()
+ .map(|b| {
+ b.into_iter()
+                // Mostly true, but with some false values that do not overlap with the left column
+ .map(|a| a % (batch_size as i32) != (batch_size as i32) - 1)
+ .collect::<Vec<_>>()
+ })
+ .collect::<Vec<_>>();
+
+ let (right, right_memory) = build_batched_finish_barrier_table(
+ ("bool_col2", right_bool_col2.as_slice()),
+ ("b1", right_b1.as_slice()),
+ ("a2", right_a2.as_slice()),
+ );
+
+ (left, right, left_memory, right_memory)
+}
+
+#[tokio::test]
+async fn test_should_emit_early_when_have_enough_data_to_emit() -> Result<()> {
+ for with_filtering in [false, true] {
+ let join_types = vec![
+ Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark,
RightMark,
+ ];
+ const BATCH_SIZE: usize = 10;
+ for join_type in join_types {
+ for output_batch_size in [
+ BATCH_SIZE / 3,
+ BATCH_SIZE / 2,
+ BATCH_SIZE,
+ BATCH_SIZE * 2,
+ BATCH_SIZE * 3,
+ ] {
+                // Make sure the number of batches is enough for all join types to emit some output
+ let number_of_batches = if output_batch_size <= BATCH_SIZE {
+ 100
+ } else {
+ // Have enough batches
+ (output_batch_size * 100) / BATCH_SIZE
+ };
+
+ let (left, right, left_memory, right_memory) =
+ generate_data_for_emit_early_test(
+ BATCH_SIZE,
+ number_of_batches,
+ join_type,
+ );
+
+ let on = vec![(
+ Arc::new(Column::new_with_schema("b1", &left.schema())?)
as _,
+ Arc::new(Column::new_with_schema("b1", &right.schema())?)
as _,
+ )];
+
+ let join_filter = if with_filtering {
+ let filter = JoinFilter::new(
+ Arc::new(BinaryExpr::new(
+ Arc::new(Column::new("bool_col1", 0)),
+ Operator::And,
+ Arc::new(Column::new("bool_col2", 1)),
+ )),
+ vec![
+ ColumnIndex {
+ index: 0,
+ side: JoinSide::Left,
+ },
+ ColumnIndex {
+ index: 0,
+ side: JoinSide::Right,
+ },
+ ],
+ Arc::new(Schema::new(vec![
+ Field::new("bool_col1", DataType::Boolean, true),
+ Field::new("bool_col2", DataType::Boolean, true),
+ ])),
+ );
+ Some(filter)
+ } else {
+ None
+ };
+
+ // select *
+ // from t1
+ // right join t2 on t1.b1 = t2.b1 and t1.bool_col1 AND
t2.bool_col2
+ let (mut output_stream, expected) =
join_get_stream_and_get_expected(
+ Arc::clone(&left) as Arc<dyn ExecutionPlan>,
+ Arc::clone(&right) as Arc<dyn ExecutionPlan>,
+ left_memory as Arc<dyn ExecutionPlan>,
+ right_memory as Arc<dyn ExecutionPlan>,
+ on,
+ join_type,
+ join_filter,
+ output_batch_size,
+ )
+ .await?;
+
+ let (output_batched, output_batches_after_finish) =
+ consume_stream_until_finish_barrier_reached(left, right,
&mut output_stream).await.unwrap_or_else(|e| panic!("Failed to consume stream
for join type: '{join_type}' and with filtering '{with_filtering}': {e:?}"));
+
+                // It should emit more than that, but we are being generous
+                // to make sure the test passes for all configurations
+ const MINIMUM_OUTPUT_BATCHES: usize = 5;
+ assert!(
+ MINIMUM_OUTPUT_BATCHES <= number_of_batches / 5,
+ "Make sure that the minimum output batches is realistic"
+ );
+ // Test to make sure that we are not waiting for input to be
fully consumed to emit some output
+ assert!(
+ output_batched.len() >= MINIMUM_OUTPUT_BATCHES,
+ "[Sort Merge Join {join_type}] Stream must have at least
emit {} batches, but only got {} batches",
+ MINIMUM_OUTPUT_BATCHES,
+ output_batched.len()
+ );
+
+                // Just a sanity test to make sure we are still producing valid output
+ {
+ let output = [output_batched,
output_batches_after_finish].concat();
+ let actual_prepared =
prepare_record_batches_for_cmp(output);
+
+ assert_eq!(actual_prepared.columns(), expected.columns());
+ }
+ }
+ }
+ }
+ Ok(())
+}
+
+/// Polls the stream until both barriers are reached,
+/// collecting the emitted batches along the way.
+///
+/// If the stream is pending for too long (5s) without emitting any batches,
+/// it panics to avoid hanging the test indefinitely.
+///
+/// Note: The left and right BarrierExec might be the input of the output
stream
+async fn consume_stream_until_finish_barrier_reached(
+ left: Arc<BarrierExec>,
+ right: Arc<BarrierExec>,
+ output_stream: &mut SendableRecordBatchStream,
+) -> Result<(Vec<RecordBatch>, Vec<RecordBatch>)> {
+ let mut switch_to_finish_barrier = false;
+ let mut output_batched = vec![];
+ let mut after_finish_barrier_reached = vec![];
+ let mut background_task = JoinSet::new();
+
+ let mut start_time_since_last_ready =
datafusion_common::instant::Instant::now();
+ loop {
+ let next_item = output_stream.next();
+
+ // Manual polling
+ let poll_output = futures::poll!(next_item);
+
+ // Wake up the stream to make sure it makes progress
+ tokio::task::yield_now().await;
+
+ match poll_output {
+ Poll::Ready(Some(Ok(batch))) => {
+ if batch.num_rows() == 0 {
+ return internal_err!("join stream should not emit empty
batch");
+ }
+ if switch_to_finish_barrier {
+ after_finish_barrier_reached.push(batch);
+ } else {
+ output_batched.push(batch);
+ }
+ start_time_since_last_ready =
datafusion_common::instant::Instant::now();
+ }
+ Poll::Ready(Some(Err(e))) => return Err(e),
+ Poll::Ready(None) if !switch_to_finish_barrier => {
+ unreachable!("Stream should not end before manually finishing
it")
+ }
+ Poll::Ready(None) => {
+ break;
+ }
+ Poll::Pending => {
+ if right.is_finish_barrier_reached()
+ && left.is_finish_barrier_reached()
+ && !switch_to_finish_barrier
+ {
+ switch_to_finish_barrier = true;
+
+ let right = Arc::clone(&right);
+ background_task.spawn(async move {
+ right.wait_finish().await;
+ });
+ let left = Arc::clone(&left);
+ background_task.spawn(async move {
+ left.wait_finish().await;
+ });
+ }
+
+ // Make sure the test doesn't run forever
+ if start_time_since_last_ready.elapsed()
+ > std::time::Duration::from_secs(5)
+ {
+ return internal_err!(
+ "Stream should have emitted data by now, but it's
still pending. Output batches so far: {}",
+ output_batched.len()
+ );
+ }
+ }
+ }
+ }
+
+ Ok((output_batched, after_finish_barrier_reached))
+}
+
/// Returns the column names on the schema
fn columns(schema: &Schema) -> Vec<String> {
schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/physical-plan/src/test/exec.rs
b/datafusion/physical-plan/src/test/exec.rs
index 4507cccba0..c229318415 100644
--- a/datafusion/physical-plan/src/test/exec.rs
+++ b/datafusion/physical-plan/src/test/exec.rs
@@ -17,13 +17,6 @@
//! Simple iterator over batches for use in testing
-use std::{
- any::Any,
- pin::Pin,
- sync::{Arc, Weak},
- task::{Context, Poll},
-};
-
use crate::{
DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
RecordBatchStream, SendableRecordBatchStream, Statistics, common,
@@ -33,6 +26,13 @@ use crate::{
execution_plan::EmissionType,
stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter},
};
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::{
+ any::Any,
+ pin::Pin,
+ sync::{Arc, Weak},
+ task::{Context, Poll},
+};
use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
use arrow::record_batch::RecordBatch;
@@ -298,29 +298,91 @@ pub struct BarrierExec {
schema: SchemaRef,
/// all streams wait on this barrier to produce
- barrier: Arc<Barrier>,
+ start_data_barrier: Option<Arc<Barrier>>,
+
+ /// the stream wait for this to return Poll::Ready(None)
+ finish_barrier: Option<Arc<(Barrier, AtomicUsize)>>,
+
cache: PlanProperties,
+
+ log: bool,
}
impl BarrierExec {
/// Create a new exec with some number of partitions.
///
/// By default the exec has a start barrier (all partitions block until
/// [`Self::wait`] is called), no finish barrier, and logging enabled; use
/// the builder methods to change these.
pub fn new(data: Vec<Vec<RecordBatch>>, schema: SchemaRef) -> Self {
    // wait for all streams and the input
    let barrier = Some(Arc::new(Barrier::new(data.len() + 1)));
    let cache = Self::compute_properties(Arc::clone(&schema), &data);
    Self {
        data,
        schema,
        start_data_barrier: barrier,
        cache,
        finish_barrier: None,
        log: true,
    }
}
/// Set whether the partitions and barrier waits print progress messages
/// to stdout (enabled by default).
pub fn with_log(mut self, log: bool) -> Self {
    self.log = log;
    self
}
+
/// Remove the start barrier, so partitions begin sending their batches
/// immediately instead of blocking until [`Self::wait`] is called.
pub fn without_start_barrier(mut self) -> Self {
    self.start_data_barrier = None;
    self
}
+
/// Add a finish barrier: each partition blocks on it after sending all of
/// its batches, so the stream only ends once [`Self::wait_finish`] releases
/// the barrier. The paired counter tracks how many partitions have reached
/// the barrier (see [`Self::is_finish_barrier_reached`]).
pub fn with_finish_barrier(mut self) -> Self {
    let barrier = Arc::new((
        // wait for all streams and the input
        Barrier::new(self.data.len() + 1),
        AtomicUsize::new(0),
    ));

    self.finish_barrier = Some(barrier);
    self
}
+
/// wait until all the input streams and this function is ready
///
/// # Panics
/// Panics if the exec was built with [`Self::without_start_barrier`].
pub async fn wait(&self) {
    let barrier = &self
        .start_data_barrier
        .as_ref()
        .expect("Must only be called when having a start barrier");
    if self.log {
        println!("BarrierExec::wait waiting on barrier");
    }
    barrier.wait().await;
    if self.log {
        println!("BarrierExec::wait done waiting");
    }
}
+
/// Release the finish barrier, letting all partitions (which block on it
/// after sending their data) complete and end their streams.
///
/// # Panics
/// Panics if the exec was not built with [`Self::with_finish_barrier`].
pub async fn wait_finish(&self) {
    let (barrier, _) = &self
        .finish_barrier
        .as_deref()
        .expect("Must only be called when having a finish barrier");

    if self.log {
        println!("BarrierExec::wait_finish waiting on barrier");
    }
    barrier.wait().await;
    if self.log {
        println!("BarrierExec::wait_finish done waiting");
    }
}
+
/// Return true if the finish barrier has been reached in all partitions
///
/// Each partition increments the counter after it has sent all of its
/// batches, just before parking on the barrier — so `true` means every
/// partition has finished producing data and is waiting to be released.
/// NOTE(review): `Relaxed` loads may observe the count slightly late on a
/// concurrent reader; callers here only poll this in a loop, so that is
/// presumably acceptable — confirm if reused elsewhere.
///
/// # Panics
/// Panics if the exec was not built with [`Self::with_finish_barrier`].
pub fn is_finish_barrier_reached(&self) -> bool {
    let (_, reached_finish) = self
        .finish_barrier
        .as_deref()
        .expect("Must only be called when having finish barrier");

    reached_finish.load(Ordering::Relaxed) == self.data.len()
}
/// This function creates the cache object that stores the plan properties
such as schema, equivalence properties, ordering, partitioning, etc.
@@ -391,17 +453,32 @@ impl ExecutionPlan for BarrierExec {
// task simply sends data in order after barrier is reached
let data = self.data[partition].clone();
- let b = Arc::clone(&self.barrier);
+ let start_barrier = self.start_data_barrier.as_ref().map(Arc::clone);
+ let finish_barrier = self.finish_barrier.as_ref().map(Arc::clone);
+ let log = self.log;
let tx = builder.tx();
builder.spawn(async move {
- println!("Partition {partition} waiting on barrier");
- b.wait().await;
+ if let Some(barrier) = start_barrier {
+ if log {
+ println!("Partition {partition} waiting on barrier");
+ }
+ barrier.wait().await;
+ }
for batch in data {
- println!("Partition {partition} sending batch");
+ if log {
+ println!("Partition {partition} sending batch");
+ }
if let Err(e) = tx.send(Ok(batch)).await {
println!("ERROR batch via barrier stream stream: {e}");
}
}
+ if let Some((barrier, reached_finish)) = finish_barrier.as_deref()
{
+ if log {
+ println!("Partition {partition} waiting on finish
barrier");
+ }
+ reached_finish.fetch_add(1, Ordering::Relaxed);
+ barrier.wait().await;
+ }
Ok(())
});
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]