This is an automated email from the ASF dual-hosted git repository.

mbutrovich pushed a commit to branch branch-52
in repository https://gitbox.apache.org/repos/asf/datafusion.git


The following commit(s) were added to refs/heads/branch-52 by this push:
     new 19a0fcaa27 [branch-52] SortMergeJoin don't wait for all input before 
emitting (#20699)
19a0fcaa27 is described below

commit 19a0fcaa276c86beda544c6e01c75f6e0639767e
Author: Matt Butrovich <[email protected]>
AuthorDate: Wed Mar 4 13:58:50 2026 -0500

    [branch-52] SortMergeJoin don't wait for all input before emitting (#20699)
    
    ## Which issue does this PR close?
    
    Backport of #20482 to branch-52.
    
    ## Rationale for this change
    
    Cherry-pick the fix and its prerequisites so that SortMergeJoin emits
    output incrementally instead of waiting for all input to complete. This
    resolves OOM issues Comet is seeing with DataFusion 52.
    
    ## What changes are included in this PR?
    
    Cherry-picks of the following commits from `main`:
    
    1. #19614 — Extract sort-merge join filter logic into separate module
    2. #20463 — Use zero-copy slice instead of take kernel in sort merge
    join
    3. #20482 — Fix SortMergeJoin to not wait for all input before emitting
    
    ## Are these changes tested?
    
    Yes, covered by existing and new tests included in #20482.
    
    ## Are there any user-facing changes?
    
    No.
    
    ---------
    
    Co-authored-by: Liang-Chi Hsieh <[email protected]>
    Co-authored-by: Claude Sonnet 4.5 <[email protected]>
    Co-authored-by: Andy Grove <[email protected]>
    Co-authored-by: Raz Luvaton <[email protected]>
---
 .../src/joins/sort_merge_join/filter.rs            | 595 ++++++++++++++++++++
 .../physical-plan/src/joins/sort_merge_join/mod.rs |   1 +
 .../src/joins/sort_merge_join/stream.rs            | 607 +++++----------------
 .../src/joins/sort_merge_join/tests.rs             | 514 +++++++++++++++--
 datafusion/physical-plan/src/test/exec.rs          | 111 +++-
 5 files changed, 1307 insertions(+), 521 deletions(-)

diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs 
b/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs
new file mode 100644
index 0000000000..d598442b65
--- /dev/null
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/filter.rs
@@ -0,0 +1,595 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! Filter handling for Sort-Merge Join
+//!
+//! This module encapsulates the complexity of join filter evaluation, including:
+//! - Immediate filtering for INNER joins
+//! - Deferred filtering for outer/semi/anti/mark joins
+//! - Metadata tracking for grouping output rows by input row
+//! - Correcting filter masks to handle multiple matches per input row
+
+use std::sync::Arc;
+
+use arrow::array::{
+    Array, ArrayBuilder, ArrayRef, BooleanArray, BooleanBuilder, RecordBatch,
+    UInt64Array, UInt64Builder,
+};
+use arrow::compute::{self, concat_batches, filter_record_batch};
+use arrow::datatypes::SchemaRef;
+use datafusion_common::{JoinSide, JoinType, Result};
+
+use crate::joins::utils::JoinFilter;
+
+/// Metadata for tracking filter results during deferred filtering
+///
+/// When a join filter is present and we need to ensure each input row produces
+/// at least one output (outer joins) or exactly one output (semi joins), we can't
+/// filter immediately. Instead, we accumulate all joined rows with metadata,
+/// then post-process to determine which rows to output.
+///
+/// The three fields are parallel: the append methods on this type add the same
+/// number of entries to each, so entry `i` of every field describes the same
+/// joined output row.
+#[derive(Debug)]
+pub struct FilterMetadata {
+    /// Did each output row pass the join filter?
+    /// Used to detect if an input row found ANY match.
+    /// Null entries are recorded for null-joined rows (no filter applied).
+    pub filter_mask: BooleanBuilder,
+
+    /// Which input row (within batch) produced each output row?
+    /// Used for grouping output rows by input row.
+    /// Null entries are recorded for null-joined rows.
+    pub row_indices: UInt64Builder,
+
+    /// Which input batch did each output row come from?
+    /// Used to disambiguate row_indices across multiple batches.
+    /// Null-joined rows are recorded with batch id 0.
+    pub batch_ids: Vec<usize>,
+}
+
+impl FilterMetadata {
+    /// Creates empty filter metadata with no rows recorded
+    pub fn new() -> Self {
+        Self {
+            filter_mask: BooleanBuilder::new(),
+            row_indices: UInt64Builder::new(),
+            batch_ids: Vec::new(),
+        }
+    }
+
+    /// Finishes the `row_indices` and `filter_mask` builders and returns them
+    /// together with a borrow of `batch_ids`
+    ///
+    /// Returns `(row_indices, filter_mask, batch_ids)`. Finishing resets the two
+    /// builders, but `batch_ids` is only borrowed and is NOT cleared here; the
+    /// caller is responsible for discarding or replacing the metadata afterwards.
+    pub fn finish_metadata(&mut self) -> (UInt64Array, BooleanArray, &[usize]) {
+        let row_indices = self.row_indices.finish();
+        let filter_mask = self.filter_mask.finish();
+        (row_indices, filter_mask, &self.batch_ids)
+    }
+
+    /// Adds metadata for `num_rows` null-joined rows (no filter applied)
+    ///
+    /// Appends `num_rows` nulls to both `filter_mask` and `row_indices`, and
+    /// `num_rows` zeros to `batch_ids`.
+    pub fn append_nulls(&mut self, num_rows: usize) {
+        self.filter_mask.append_nulls(num_rows);
+        self.row_indices.append_nulls(num_rows);
+        // batch_id = 0 for null-joined rows
+        self.batch_ids.resize(self.batch_ids.len() + num_rows, 0);
+    }
+
+    /// Adds metadata for rows on which the join filter was evaluated
+    ///
+    /// `row_indices` and `filter_mask` must have the same length (checked in
+    /// debug builds). Every appended row is tagged with `batch_id`.
+    pub fn append_filter_metadata(
+        &mut self,
+        row_indices: &UInt64Array,
+        filter_mask: &BooleanArray,
+        batch_id: usize,
+    ) {
+        debug_assert_eq!(
+            row_indices.len(),
+            filter_mask.len(),
+            "row_indices and filter_mask must have same length"
+        );
+
+        // Extending from the array iterators copies null slots as nulls and
+        // valid slots as their values, without a hand-written per-row loop.
+        self.filter_mask.extend(filter_mask.iter());
+        self.row_indices.extend(row_indices.iter());
+        self.batch_ids
+            .resize(self.batch_ids.len() + row_indices.len(), batch_id);
+    }
+
+    /// Verifies (in debug builds) that the metadata arrays are aligned
+    ///
+    /// The check is unconditional: when no metadata has been recorded all three
+    /// lengths are zero and trivially equal, and a mismatch is caught even when
+    /// `filter_mask` happens to be the empty one.
+    pub fn debug_assert_metadata_aligned(&self) {
+        debug_assert_eq!(
+            self.filter_mask.len(),
+            self.row_indices.len(),
+            "filter_mask and row_indices must have same length when metadata is used"
+        );
+        debug_assert_eq!(
+            self.filter_mask.len(),
+            self.batch_ids.len(),
+            "filter_mask and batch_ids must have same length when metadata is used"
+        );
+    }
+}
+
+/// `Default` produces the same empty metadata as `FilterMetadata::new`.
+impl Default for FilterMetadata {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+/// Reports whether filter evaluation has to be deferred for this join
+///
+/// Deferral is needed only when a join filter is present AND the join type must
+/// guarantee an output per input row (at least one for outer joins, exactly one
+/// for semi joins). Inner joins never defer.
+pub fn needs_deferred_filtering(
+    filter: &Option<JoinFilter>,
+    join_type: JoinType,
+) -> bool {
+    // Without a filter there is nothing to defer
+    if filter.is_none() {
+        return false;
+    }
+
+    match join_type {
+        JoinType::Inner => false,
+        JoinType::Left
+        | JoinType::Right
+        | JoinType::Full
+        | JoinType::LeftSemi
+        | JoinType::RightSemi
+        | JoinType::LeftAnti
+        | JoinType::RightAnti
+        | JoinType::LeftMark
+        | JoinType::RightMark => true,
+    }
+}
+
+/// Collects the arrays the join filter is evaluated on
+///
+/// For every column the filter references, the matching array is taken from
+/// `left_columns` or `right_columns` according to the column's side. All
+/// left-side columns come first, followed by all right-side columns, each group
+/// in the order reported by the filter's column indices. Returns an empty vector
+/// when there is no filter.
+pub fn get_filter_columns(
+    join_filter: &Option<JoinFilter>,
+    left_columns: &[ArrayRef],
+    right_columns: &[ArrayRef],
+) -> Vec<ArrayRef> {
+    let Some(filter) = join_filter else {
+        return Vec::new();
+    };
+    let referenced = filter.column_indices();
+
+    let from_left = referenced
+        .iter()
+        .filter(|column| column.side == JoinSide::Left)
+        .map(|column| Arc::clone(&left_columns[column.index]));
+    let from_right = referenced
+        .iter()
+        .filter(|column| column.side == JoinSide::Right)
+        .map(|column| Arc::clone(&right_columns[column.index]));
+
+    // `chain` drains the left-side iterator completely before the right-side one
+    from_left.chain(from_right).collect()
+}
+
+/// Tells whether position `row_index` is the last entry of its input-row group
+///
+/// Used during filter mask correction to find group boundaries. A position ends
+/// its group when any of the following holds:
+/// - it is the final position overall
+/// - the next position has a different batch id
+/// - the row index at this position or the next one is null
+/// - the next position has a different row index
+fn last_index_for_row(
+    row_index: usize,
+    indices: &UInt64Array,
+    batch_ids: &[usize],
+    indices_len: usize,
+) -> bool {
+    debug_assert_eq!(
+        indices.len(),
+        indices_len,
+        "indices.len() should match indices_len parameter"
+    );
+    debug_assert_eq!(
+        batch_ids.len(),
+        indices_len,
+        "batch_ids.len() should match indices_len"
+    );
+    debug_assert!(
+        row_index < indices_len,
+        "row_index {row_index} should be < indices_len {indices_len}",
+    );
+
+    // Nothing follows the final position, so it always closes its group
+    let is_final_entry = row_index == indices_len - 1;
+    if is_final_entry {
+        return true;
+    }
+
+    // Short-circuit order matters: batch ids first, then nulls, then the values
+    let next = row_index + 1;
+    batch_ids[row_index] != batch_ids[next]
+        || indices.is_null(row_index)
+        || indices.is_null(next)
+        || indices.value(row_index) != indices.value(next)
+}
+
+/// Corrects the filter mask for joins with deferred filtering
+///
+/// When an input row joins with multiple buffered rows, we get multiple output rows.
+/// This function walks the output rows in order, detects where each input row's group
+/// ends via `last_index_for_row`, and applies join-type-specific logic:
+///
+/// - **Left / Right**: every row that passed the filter is kept (`true`). A failing row
+///   is discarded (`null`) unless it is the last row of a group in which no row passed;
+///   that row is marked `false` so it can be turned into a null-joined row.
+/// - **Mark joins**: same as Left / Right, except only the FIRST passing row of a group
+///   is kept; later passing rows of the same group are discarded.
+/// - **Semi joins**: only the first passing row of a group is kept; everything else is
+///   discarded.
+/// - **Anti joins**: only the last row of a group can be kept, and only if no row of the
+///   group passed the filter.
+/// - **Full joins**: same as Left / Right, except rows whose `filter_mask` entry is null
+///   (null-joined rows) are always kept.
+/// - **Inner joins**: no correction is performed.
+///
+/// For every join type except semi and inner, the mask is finally padded up to
+/// `expected_size` entries: with `true` for anti joins and `false` otherwise.
+///
+/// # Arguments
+/// * `join_type` - The type of join being performed
+/// * `row_indices` - Which input row produced each output row
+/// * `batch_ids` - Which batch each output row came from
+/// * `filter_mask` - Whether each output row passed the filter
+/// * `expected_size` - Length the returned mask is padded to. It must not be smaller
+///   than `row_indices.len()`: the padding count is a plain `usize` subtraction.
+///
+/// # Returns
+/// `None` for inner joins. Otherwise the corrected mask, where each entry means:
+/// - `true`: Include this row
+/// - `false`: Row did not match; to be converted to a null-joined row
+/// - `null`: Discard this row
+pub fn get_corrected_filter_mask(
+    join_type: JoinType,
+    row_indices: &UInt64Array,
+    batch_ids: &[usize],
+    filter_mask: &BooleanArray,
+    expected_size: usize,
+) -> Option<BooleanArray> {
+    let row_indices_length = row_indices.len();
+    let mut corrected_mask: BooleanBuilder =
+        BooleanBuilder::with_capacity(row_indices_length);
+    // Whether a row of the current group has already passed the filter;
+    // reset every time a group ends.
+    let mut seen_true = false;
+
+    // NOTE(review): outside the Full branch `filter_mask.value(i)` is read without a
+    // null check; presumably null slots hold `false` there -- confirm against callers.
+    match join_type {
+        JoinType::Left | JoinType::Right => {
+            // For outer joins: keep every passing row, drop failing rows, and mark the
+            // last row of a group without any passing row as `false`
+            for i in 0..row_indices_length {
+                let last_index =
+                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
+                if filter_mask.value(i) {
+                    seen_true = true;
+                    corrected_mask.append_value(true);
+                } else if seen_true || !filter_mask.value(i) && !last_index {
+                    corrected_mask.append_null(); // to be ignored and not set to output
+                } else {
+                    corrected_mask.append_value(false); // to be converted to null joined row
+                }
+
+                if last_index {
+                    seen_true = false;
+                }
+            }
+
+            // Generate null joined rows for records which have no matching join key
+            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
+            Some(corrected_mask.finish())
+        }
+        JoinType::LeftMark | JoinType::RightMark => {
+            // For mark joins: like outer, but only the first passing row of a group is kept
+            for i in 0..row_indices_length {
+                let last_index =
+                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
+                if filter_mask.value(i) && !seen_true {
+                    seen_true = true;
+                    corrected_mask.append_value(true);
+                } else if seen_true || !filter_mask.value(i) && !last_index {
+                    corrected_mask.append_null(); // to be ignored and not set to output
+                } else {
+                    corrected_mask.append_value(false); // to be converted to null joined row
+                }
+
+                if last_index {
+                    seen_true = false;
+                }
+            }
+
+            // Generate null joined rows for records which have no matching join key
+            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
+            Some(corrected_mask.finish())
+        }
+        JoinType::LeftSemi | JoinType::RightSemi => {
+            // For semi joins: Keep only first matching row per input row, discard rest
+            for i in 0..row_indices_length {
+                let last_index =
+                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
+                if filter_mask.value(i) && !seen_true {
+                    seen_true = true;
+                    corrected_mask.append_value(true);
+                } else {
+                    corrected_mask.append_null(); // to be ignored and not set to output
+                }
+
+                if last_index {
+                    seen_true = false;
+                }
+            }
+
+            // No padding: semi joins never emit rows for unmatched input
+            Some(corrected_mask.finish())
+        }
+        JoinType::LeftAnti | JoinType::RightAnti => {
+            // For anti joins: Keep row only if NO matches passed the filter
+            for i in 0..row_indices_length {
+                let last_index =
+                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
+
+                if filter_mask.value(i) {
+                    seen_true = true;
+                }
+
+                // The decision is only made on the last row of the group, once every
+                // row of that group has been inspected
+                if last_index {
+                    if !seen_true {
+                        corrected_mask.append_value(true);
+                    } else {
+                        corrected_mask.append_null();
+                    }
+
+                    seen_true = false;
+                } else {
+                    corrected_mask.append_null();
+                }
+            }
+            // Generate null joined rows for records which have no matching join key,
+            // for LeftAnti non-matched considered as true
+            corrected_mask.append_n(expected_size - corrected_mask.len(), true);
+            Some(corrected_mask.finish())
+        }
+        JoinType::Full => {
+            // For full joins: like outer, but null mask entries are kept as-is
+            for i in 0..row_indices_length {
+                let last_index =
+                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
+
+                if filter_mask.is_null(i) {
+                    // null joined
+                    corrected_mask.append_value(true);
+                } else if filter_mask.value(i) {
+                    seen_true = true;
+                    corrected_mask.append_value(true);
+                } else if seen_true || !filter_mask.value(i) && !last_index {
+                    corrected_mask.append_null(); // to be ignored and not set to output
+                } else {
+                    corrected_mask.append_value(false); // to be converted to null joined row
+                }
+
+                if last_index {
+                    seen_true = false;
+                }
+            }
+            // Generate null joined rows for records which have no matching join key
+            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
+            Some(corrected_mask.finish())
+        }
+        JoinType::Inner => {
+            // Inner joins don't need deferred filtering
+            None
+        }
+    }
+}
+
+/// Applies the corrected filter mask to a joined batch according to the join type
+///
+/// Rows whose mask entry is `true` are always kept. What happens next depends on
+/// `join_type`:
+/// - Inner: the kept rows are returned unchanged.
+/// - Semi / Anti: the kept rows are projected to the first
+///   `streamed_schema.fields().len()` columns.
+/// - Left / LeftMark / Full: rows selected by the negated mask are rebuilt by
+///   `create_null_joined_batch` with `JoinSide::Left` and appended after the kept rows.
+/// - Right / RightMark: same, but with `JoinSide::Right`.
+///
+/// In the last two cases, when the negated mask selects no rows, the kept rows are
+/// returned as they are.
+pub fn filter_record_batch_by_join_type(
+    record_batch: &RecordBatch,
+    corrected_mask: &BooleanArray,
+    join_type: JoinType,
+    schema: &SchemaRef,
+    streamed_schema: &SchemaRef,
+    buffered_schema: &SchemaRef,
+) -> Result<RecordBatch> {
+    let kept_rows = filter_record_batch(record_batch, corrected_mask)?;
+
+    // Pick the side handed to `create_null_joined_batch`; join types that never
+    // produce null-joined rows return right here.
+    let data_side = match join_type {
+        JoinType::Inner => return Ok(kept_rows),
+        JoinType::LeftSemi
+        | JoinType::LeftAnti
+        | JoinType::RightSemi
+        | JoinType::RightAnti => {
+            // The streamed side's columns come first in the joined batch for both
+            // the Left and the Right variants, so a prefix projection is enough
+            let streamed_width = streamed_schema.fields().len();
+            let output_column_indices: Vec<usize> = (0..streamed_width).collect();
+            return Ok(kept_rows.project(&output_column_indices)?);
+        }
+        JoinType::Left | JoinType::LeftMark | JoinType::Full => JoinSide::Left,
+        JoinType::Right | JoinType::RightMark => JoinSide::Right,
+    };
+
+    // Rows marked `false` in the corrected mask become null-joined rows
+    let rejected_mask = compute::not(corrected_mask)?;
+    let rejected_rows = filter_record_batch(record_batch, &rejected_mask)?;
+    if rejected_rows.num_rows() == 0 {
+        return Ok(kept_rows);
+    }
+
+    let null_joined_rows = create_null_joined_batch(
+        &rejected_rows,
+        buffered_schema,
+        data_side,
+        join_type,
+        schema,
+    )?;
+
+    Ok(concat_batches(schema, &[kept_rows, null_joined_rows])?)
+}
+
+/// Creates a batch with null columns for the non-joined side
+///
+/// Note: The input `batch` is assumed to be a fully-joined batch that already 
contains
+/// columns from both sides. We need to extract the data side columns and 
replace the
+/// null side columns with actual nulls.
+fn create_null_joined_batch(
+    batch: &RecordBatch,
+    null_schema: &SchemaRef,
+    join_side: JoinSide,
+    join_type: JoinType,
+    output_schema: &SchemaRef,
+) -> Result<RecordBatch> {
+    let num_rows = batch.num_rows();
+
+    // The input batch is a fully-joined batch [left_cols..., right_cols...]
+    // We need to extract the appropriate side and replace the other with 
nulls (or mark column)
+    let columns = match (join_side, join_type) {
+        (JoinSide::Left, JoinType::LeftMark) => {
+            // For LEFT mark: output is [left_cols..., mark_col]
+            // Batch is [left_cols..., right_cols...], extract left from 
beginning
+            // Number of left columns = output columns - 1 (mark column)
+            let left_col_count = output_schema.fields().len() - 1;
+            let mut result: Vec<ArrayRef> = 
batch.columns()[..left_col_count].to_vec();
+            result.push(Arc::new(BooleanArray::from(vec![false; num_rows])) as 
ArrayRef);
+            result
+        }
+        (JoinSide::Right, JoinType::RightMark) => {
+            // For RIGHT mark: output is [right_cols..., mark_col]
+            // For RIGHT joins, batch is [right_cols..., left_cols...] (right 
comes first!)
+            // Extract right columns from the beginning
+            let right_col_count = output_schema.fields().len() - 1; // -1 for 
mark column
+            let mut result: Vec<ArrayRef> = 
batch.columns()[..right_col_count].to_vec();
+            result.push(Arc::new(BooleanArray::from(vec![false; num_rows])) as 
ArrayRef);
+            result
+        }
+        (JoinSide::Left, _) => {
+            // For LEFT join: output is [left_cols..., right_cols...]
+            // Extract left columns, then add null right columns
+            let null_columns: Vec<ArrayRef> = null_schema
+                .fields()
+                .iter()
+                .map(|field| arrow::array::new_null_array(field.data_type(), 
num_rows))
+                .collect();
+            let left_col_count = output_schema.fields().len() - 
null_columns.len();
+            let mut result: Vec<ArrayRef> = 
batch.columns()[..left_col_count].to_vec();
+            result.extend(null_columns);
+            result
+        }
+        (JoinSide::Right, _) => {
+            // For RIGHT join: batch is [left_cols..., right_cols...] (same as 
schema)
+            // We want: [null_left..., actual_right...]
+            // Extract left columns from beginning, replace with nulls, keep 
right columns
+            let null_columns: Vec<ArrayRef> = null_schema
+                .fields()
+                .iter()
+                .map(|field| arrow::array::new_null_array(field.data_type(), 
num_rows))
+                .collect();
+            let left_col_count = null_columns.len();
+            let mut result = null_columns;
+            // Extract right columns starting after left columns
+            result.extend_from_slice(&batch.columns()[left_col_count..]);
+            result
+        }
+        (JoinSide::None, _) => {
+            // This should not happen in normal join operations
+            unreachable!(
+                "JoinSide::None should not be used in null-joined batch 
creation"
+            )
+        }
+    };
+
+    // Create the batch - don't validate nullability since outer joins can have
+    // null values in columns that were originally non-nullable
+    use arrow::array::RecordBatchOptions;
+    let mut options = RecordBatchOptions::new();
+    options = options.with_row_count(Some(num_rows));
+    Ok(RecordBatch::try_new_with_options(
+        Arc::clone(output_schema),
+        columns,
+        &options,
+    )?)
+}
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs 
b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs
index 82f18e7414..06290ec4d0 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/mod.rs
@@ -20,6 +20,7 @@
 pub use exec::SortMergeJoinExec;
 
 mod exec;
+mod filter;
 mod metrics;
 mod stream;
 
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs 
b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
index b36992caf4..3a57dc6b41 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/stream.rs
@@ -33,6 +33,10 @@ use std::sync::atomic::AtomicUsize;
 use std::sync::atomic::Ordering::Relaxed;
 use std::task::{Context, Poll};
 
+use crate::joins::sort_merge_join::filter::{
+    FilterMetadata, filter_record_batch_by_join_type, 
get_corrected_filter_mask,
+    get_filter_columns, needs_deferred_filtering,
+};
 use crate::joins::sort_merge_join::metrics::SortMergeJoinMetrics;
 use crate::joins::utils::{JoinFilter, compare_join_arrays};
 use crate::metrics::RecordOutput;
@@ -42,15 +46,13 @@ use crate::{PhysicalExpr, RecordBatchStream, 
SendableRecordBatchStream};
 use arrow::array::{types::UInt64Type, *};
 use arrow::compute::{
     self, BatchCoalescer, SortOptions, concat_batches, filter_record_batch, 
is_not_null,
-    take,
+    take, take_arrays,
 };
 use arrow::datatypes::{DataType, SchemaRef, TimeUnit};
-use arrow::error::ArrowError;
 use arrow::ipc::reader::StreamReader;
 use datafusion_common::config::SpillCompression;
 use datafusion_common::{
-    DataFusionError, HashSet, JoinSide, JoinType, NullEquality, Result, 
exec_err,
-    internal_err, not_impl_err,
+    HashSet, JoinType, NullEquality, Result, exec_err, internal_err, 
not_impl_err,
 };
 use datafusion_execution::disk_manager::RefCountedTempFile;
 use datafusion_execution::memory_pool::MemoryReservation;
@@ -68,6 +70,8 @@ pub(super) enum SortMergeJoinState {
     Polling,
     /// Joining polled data and making output
     JoinOutput,
+    /// Emit ready data if have any and then go back to [`Self::Init`] state
+    EmitReadyThenInit,
     /// No more output
     Exhausted,
 }
@@ -370,12 +374,8 @@ pub(super) struct SortMergeJoinStream {
 pub(super) struct JoinedRecordBatches {
     /// Joined batches. Each batch is already joined columns from left and 
right sources
     pub(super) joined_batches: BatchCoalescer,
-    /// Did each output row pass the join filter? (detect if input row found 
any match)
-    pub(super) filter_mask: BooleanBuilder,
-    /// Which input row (within batch) produced each output row? (for grouping 
by input row)
-    pub(super) row_indices: UInt64Builder,
-    /// Which input batch did each output row come from? (disambiguate 
row_indices)
-    pub(super) batch_ids: Vec<usize>,
+    /// Filter metadata for deferred filtering
+    pub(super) filter_metadata: FilterMetadata,
 }
 
 impl JoinedRecordBatches {
@@ -398,61 +398,28 @@ impl JoinedRecordBatches {
         }
     }
 
-    /// Finishes and returns the metadata arrays, clearing the builders
-    ///
-    /// Returns (row_indices, filter_mask, batch_ids_ref)
-    /// Note: batch_ids is returned as a reference since it's still needed in 
the struct
-    fn finish_metadata(&mut self) -> (UInt64Array, BooleanArray, &[usize]) {
-        let row_indices = self.row_indices.finish();
-        let filter_mask = self.filter_mask.finish();
-        (row_indices, filter_mask, &self.batch_ids)
-    }
-
     /// Clears batches without touching metadata (for early return when no 
filtering needed)
     fn clear_batches(&mut self, schema: &SchemaRef, batch_size: usize) {
         self.joined_batches = BatchCoalescer::new(Arc::clone(schema), 
batch_size)
             .with_biggest_coalesce_batch_size(Option::from(batch_size / 2));
     }
 
-    /// Asserts that internal metadata arrays are consistent with each other
-    /// Only checks if metadata is actually being used (i.e., not all empty)
-    #[inline]
-    fn debug_assert_metadata_aligned(&self) {
-        // Metadata arrays should be aligned IF they're being used
-        // (For non-filtered joins, they may all be empty)
-        if self.filter_mask.len() > 0
-            || self.row_indices.len() > 0
-            || !self.batch_ids.is_empty()
-        {
-            debug_assert_eq!(
-                self.filter_mask.len(),
-                self.row_indices.len(),
-                "filter_mask and row_indices must have same length when 
metadata is used"
-            );
-            debug_assert_eq!(
-                self.filter_mask.len(),
-                self.batch_ids.len(),
-                "filter_mask and batch_ids must have same length when metadata 
is used"
-            );
-        }
-    }
-
     /// Asserts that if batches is empty, metadata is also empty
     #[inline]
     fn debug_assert_empty_consistency(&self) {
         if self.joined_batches.is_empty() {
             debug_assert_eq!(
-                self.filter_mask.len(),
+                self.filter_metadata.filter_mask.len(),
                 0,
                 "filter_mask should be empty when batches is empty"
             );
             debug_assert_eq!(
-                self.row_indices.len(),
+                self.filter_metadata.row_indices.len(),
                 0,
                 "row_indices should be empty when batches is empty"
             );
             debug_assert_eq!(
-                self.batch_ids.len(),
+                self.filter_metadata.batch_ids.len(),
                 0,
                 "batch_ids should be empty when batches is empty"
             );
@@ -473,14 +440,9 @@ impl JoinedRecordBatches {
 
         let num_rows = batch.num_rows();
 
-        self.filter_mask.append_nulls(num_rows);
-        self.row_indices.append_nulls(num_rows);
-        self.batch_ids.resize(
-            self.batch_ids.len() + num_rows,
-            0, // batch_id = 0 for null-joined rows
-        );
+        self.filter_metadata.append_nulls(num_rows);
 
-        self.debug_assert_metadata_aligned();
+        self.filter_metadata.debug_assert_metadata_aligned();
         self.joined_batches
             .push_batch(batch)
             .expect("Failed to push batch to BatchCoalescer");
@@ -525,13 +487,13 @@ impl JoinedRecordBatches {
             "row_indices and filter_mask must have same length"
         );
 
-        // For Full joins, we keep the pre_mask (with nulls), for others we keep the cleaned mask
-        self.filter_mask.extend(filter_mask);
-        self.row_indices.extend(row_indices);
-        self.batch_ids
-            .resize(self.batch_ids.len() + row_indices.len(), streamed_batch_id);
+        self.filter_metadata.append_filter_metadata(
+            row_indices,
+            filter_mask,
+            streamed_batch_id,
+        );
 
-        self.debug_assert_metadata_aligned();
+        self.filter_metadata.debug_assert_metadata_aligned();
         self.joined_batches
             .push_batch(batch)
             .expect("Failed to push batch to BatchCoalescer");
@@ -551,9 +513,7 @@ impl JoinedRecordBatches {
     fn clear(&mut self, schema: &SchemaRef, batch_size: usize) {
         self.joined_batches = BatchCoalescer::new(Arc::clone(schema), batch_size)
             .with_biggest_coalesce_batch_size(Option::from(batch_size / 2));
-        self.batch_ids.clear();
-        self.filter_mask = BooleanBuilder::new();
-        self.row_indices = UInt64Builder::new();
+        self.filter_metadata = FilterMetadata::new();
         self.debug_assert_empty_consistency();
     }
 }
@@ -563,199 +523,6 @@ impl RecordBatchStream for SortMergeJoinStream {
     }
 }
 
-/// True if next index refers to either:
-/// - another batch id
-/// - another row index within same batch id
-/// - end of row indices
-#[inline(always)]
-fn last_index_for_row(
-    row_index: usize,
-    indices: &UInt64Array,
-    batch_ids: &[usize],
-    indices_len: usize,
-) -> bool {
-    debug_assert_eq!(
-        indices.len(),
-        indices_len,
-        "indices.len() should match indices_len parameter"
-    );
-    debug_assert_eq!(
-        batch_ids.len(),
-        indices_len,
-        "batch_ids.len() should match indices_len"
-    );
-    debug_assert!(
-        row_index < indices_len,
-        "row_index {row_index} should be < indices_len {indices_len}",
-    );
-
-    row_index == indices_len - 1
-        || batch_ids[row_index] != batch_ids[row_index + 1]
-        || indices.value(row_index) != indices.value(row_index + 1)
-}
-
-// Returns a corrected boolean bitmask for the given join type
-// Values in the corrected bitmask can be: true, false, null
-// `true` - the row found its match and sent to the output
-// `null` - the row ignored, no output
-// `false` - the row sent as NULL joined row
-pub(super) fn get_corrected_filter_mask(
-    join_type: JoinType,
-    row_indices: &UInt64Array,
-    batch_ids: &[usize],
-    filter_mask: &BooleanArray,
-    expected_size: usize,
-) -> Option<BooleanArray> {
-    let row_indices_length = row_indices.len();
-    let mut corrected_mask: BooleanBuilder =
-        BooleanBuilder::with_capacity(row_indices_length);
-    let mut seen_true = false;
-
-    match join_type {
-        JoinType::Left | JoinType::Right => {
-            for i in 0..row_indices_length {
-                let last_index =
-                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
-                if filter_mask.value(i) {
-                    seen_true = true;
-                    corrected_mask.append_value(true);
-                } else if seen_true || !filter_mask.value(i) && !last_index {
-                    corrected_mask.append_null(); // to be ignored and not set to output
-                } else {
-                    corrected_mask.append_value(false); // to be converted to null joined row
-                }
-
-                if last_index {
-                    seen_true = false;
-                }
-            }
-
-            // Generate null joined rows for records which have no matching join key
-            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
-            Some(corrected_mask.finish())
-        }
-        JoinType::LeftMark | JoinType::RightMark => {
-            for i in 0..row_indices_length {
-                let last_index =
-                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
-                if filter_mask.value(i) && !seen_true {
-                    seen_true = true;
-                    corrected_mask.append_value(true);
-                } else if seen_true || !filter_mask.value(i) && !last_index {
-                    corrected_mask.append_null(); // to be ignored and not set to output
-                } else {
-                    corrected_mask.append_value(false); // to be converted to null joined row
-                }
-
-                if last_index {
-                    seen_true = false;
-                }
-            }
-
-            // Generate null joined rows for records which have no matching join key
-            corrected_mask.append_n(expected_size - corrected_mask.len(), false);
-            Some(corrected_mask.finish())
-        }
-        JoinType::LeftSemi | JoinType::RightSemi => {
-            for i in 0..row_indices_length {
-                let last_index =
-                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
-                if filter_mask.value(i) && !seen_true {
-                    seen_true = true;
-                    corrected_mask.append_value(true);
-                } else {
-                    corrected_mask.append_null(); // to be ignored and not set to output
-                }
-
-                if last_index {
-                    seen_true = false;
-                }
-            }
-
-            Some(corrected_mask.finish())
-        }
-        JoinType::LeftAnti | JoinType::RightAnti => {
-            for i in 0..row_indices_length {
-                let last_index =
-                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
-
-                if filter_mask.value(i) {
-                    seen_true = true;
-                }
-
-                if last_index {
-                    if !seen_true {
-                        corrected_mask.append_value(true);
-                    } else {
-                        corrected_mask.append_null();
-                    }
-
-                    seen_true = false;
-                } else {
-                    corrected_mask.append_null();
-                }
-            }
-            // Generate null joined rows for records which have no matching join key,
-            // for LeftAnti non-matched considered as true
-            corrected_mask.append_n(expected_size - corrected_mask.len(), true);
-            Some(corrected_mask.finish())
-        }
-        JoinType::Full => {
-            let mut mask: Vec<Option<bool>> = vec![Some(true); row_indices_length];
-            let mut last_true_idx = 0;
-            let mut first_row_idx = 0;
-            let mut seen_false = false;
-
-            for i in 0..row_indices_length {
-                let last_index =
-                    last_index_for_row(i, row_indices, batch_ids, row_indices_length);
-                let val = filter_mask.value(i);
-                let is_null = filter_mask.is_null(i);
-
-                if val {
-                    // memoize the first seen matched row
-                    if !seen_true {
-                        last_true_idx = i;
-                    }
-                    seen_true = true;
-                }
-
-                if is_null || val {
-                    mask[i] = Some(true);
-                } else if !is_null && !val && (seen_true || seen_false) {
-                    mask[i] = None;
-                } else {
-                    mask[i] = Some(false);
-                }
-
-                if !is_null && !val {
-                    seen_false = true;
-                }
-
-                if last_index {
-                    // If the left row seen as true its needed to output it once
-                    // To do that we mark all other matches for same row as null to avoid the output
-                    if seen_true {
-                        #[expect(clippy::needless_range_loop)]
-                        for j in first_row_idx..last_true_idx {
-                            mask[j] = None;
-                        }
-                    }
-
-                    seen_true = false;
-                    seen_false = false;
-                    last_true_idx = 0;
-                    first_row_idx = i + 1;
-                }
-            }
-
-            Some(BooleanArray::from(mask))
-        }
-        // Only outer joins needs to keep track of processed rows and apply corrected filter mask
-        _ => None,
-    }
-}
-
 impl Stream for SortMergeJoinStream {
     type Item = Result<RecordBatch>;
 
@@ -778,7 +545,10 @@ impl Stream for SortMergeJoinStream {
                         match self.current_ordering {
                             Ordering::Less | Ordering::Equal => {
                                 if !streamed_exhausted {
-                                    if self.needs_deferred_filtering() {
+                                    if needs_deferred_filtering(
+                                        &self.filter,
+                                        self.join_type,
+                                    ) {
                                         match self.process_filtered_batches()? {
                                             Poll::Ready(Some(batch)) => {
                                                 return Poll::Ready(Some(Ok(batch)));
@@ -830,22 +600,56 @@ impl Stream for SortMergeJoinStream {
                     self.current_ordering = self.compare_streamed_buffered()?;
                     self.state = SortMergeJoinState::JoinOutput;
                 }
+                SortMergeJoinState::EmitReadyThenInit => {
+                    // If have data to emit, emit it and if no more, change to next
+
+                    // Verify metadata alignment before checking if we have batches to output
+                    self.joined_record_batches
+                        .filter_metadata
+                        .debug_assert_metadata_aligned();
+
+                    // For filtered joins, skip output and let Init state handle it
+                    if needs_deferred_filtering(&self.filter, self.join_type) {
+                        self.state = SortMergeJoinState::Init;
+                        continue;
+                    }
+
+                    // For non-filtered joins, only output if we have a completed batch
+                    // (opportunistic output when target batch size is reached)
+                    if self
+                        .joined_record_batches
+                        .joined_batches
+                        .has_completed_batch()
+                    {
+                        let record_batch = self
+                            .joined_record_batches
+                            .joined_batches
+                            .next_completed_batch()
+                            .expect("has_completed_batch was true");
+                        (&record_batch)
+                            .record_output(&self.join_metrics.baseline_metrics());
+                        return Poll::Ready(Some(Ok(record_batch)));
+                    }
+                    self.state = SortMergeJoinState::Init;
+                }
                 SortMergeJoinState::JoinOutput => {
                     self.join_partial()?;
 
                     if self.num_unfrozen_pairs() < self.batch_size {
                         if self.buffered_data.scanning_finished() {
                             self.buffered_data.scanning_reset();
-                            self.state = SortMergeJoinState::Init;
+                            self.state = SortMergeJoinState::EmitReadyThenInit;
                         }
                     } else {
                         self.freeze_all()?;
 
                         // Verify metadata alignment before checking if we have batches to output
-                        self.joined_record_batches.debug_assert_metadata_aligned();
+                        self.joined_record_batches
+                            .filter_metadata
+                            .debug_assert_metadata_aligned();
 
                         // For filtered joins, skip output and let Init state handle it
-                        if self.needs_deferred_filtering() {
+                        if needs_deferred_filtering(&self.filter, self.join_type) {
                             continue;
                         }
 
@@ -872,10 +676,12 @@ impl Stream for SortMergeJoinStream {
                     self.freeze_all()?;
 
                     // Verify metadata alignment before final output
-                    self.joined_record_batches.debug_assert_metadata_aligned();
+                    self.joined_record_batches
+                        .filter_metadata
+                        .debug_assert_metadata_aligned();
 
                     // For filtered joins, must concat and filter ALL data at once
-                    if self.needs_deferred_filtering()
+                    if needs_deferred_filtering(&self.filter, self.join_type)
                         && !self.joined_record_batches.joined_batches.is_empty()
                     {
                         let record_batch = self.filter_joined_batch()?;
@@ -975,9 +781,7 @@ impl SortMergeJoinStream {
             joined_record_batches: JoinedRecordBatches {
                 joined_batches: BatchCoalescer::new(Arc::clone(&schema), batch_size)
                     .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)),
-                filter_mask: BooleanBuilder::new(),
-                row_indices: UInt64Builder::new(),
-                batch_ids: vec![],
+                filter_metadata: FilterMetadata::new(),
             },
             output: BatchCoalescer::new(schema, batch_size)
                 .with_biggest_coalesce_batch_size(Option::from(batch_size / 2)),
@@ -996,26 +800,6 @@ impl SortMergeJoinStream {
         self.streamed_batch.num_output_rows()
     }
 
-    /// Returns true if this join needs deferred filtering
-    ///
-    /// Deferred filtering is needed when a filter exists and the join type requires
-    /// ensuring each input row produces at least one output row (or exactly one for semi).
-    fn needs_deferred_filtering(&self) -> bool {
-        self.filter.is_some()
-            && matches!(
-                self.join_type,
-                JoinType::Left
-                    | JoinType::LeftSemi
-                    | JoinType::LeftMark
-                    | JoinType::Right
-                    | JoinType::RightSemi
-                    | JoinType::RightMark
-                    | JoinType::LeftAnti
-                    | JoinType::RightAnti
-                    | JoinType::Full
-            )
-    }
-
     /// Process accumulated batches for filtered joins
     ///
     /// Freezes unfrozen pairs, applies deferred filtering, and outputs if ready.
@@ -1023,7 +807,9 @@ impl SortMergeJoinStream {
     fn process_filtered_batches(&mut self) -> Poll<Option<Result<RecordBatch>>> {
         self.freeze_all()?;
 
-        self.joined_record_batches.debug_assert_metadata_aligned();
+        self.joined_record_batches
+            .filter_metadata
+            .debug_assert_metadata_aligned();
 
         if !self.joined_record_batches.joined_batches.is_empty() {
             let out_filtered_batch = self.filter_joined_batch()?;
@@ -1399,7 +1185,9 @@ impl SortMergeJoinStream {
         self.freeze_streamed()?;
 
         // After freezing, metadata should be aligned
-        self.joined_record_batches.debug_assert_metadata_aligned();
+        self.joined_record_batches
+            .filter_metadata
+            .debug_assert_metadata_aligned();
 
         Ok(())
     }
@@ -1414,7 +1202,9 @@ impl SortMergeJoinStream {
         self.freeze_buffered(1)?;
 
         // After freezing, metadata should be aligned
-        self.joined_record_batches.debug_assert_metadata_aligned();
+        self.joined_record_batches
+            .filter_metadata
+            .debug_assert_metadata_aligned();
 
         Ok(())
     }
@@ -1490,13 +1280,19 @@ impl SortMergeJoinStream {
                 continue;
             }
 
-            let mut left_columns = self
-                .streamed_batch
-                .batch
-                .columns()
-                .iter()
-                .map(|column| take(column, &left_indices, None))
-                .collect::<Result<Vec<_>, ArrowError>>()?;
+            let mut left_columns = if let Some(range) = is_contiguous_range(&left_indices)
+            {
+                // When indices form a contiguous range (common for the streamed
+                // side which advances sequentially), use zero-copy slice instead
+                // of the O(n) take kernel.
+                self.streamed_batch
+                    .batch
+                    .slice(range.start, range.len())
+                    .columns()
+                    .to_vec()
+            } else {
+                take_arrays(self.streamed_batch.batch.columns(), &left_indices, None)?
+            };
 
             // The row indices of joined buffered batch
             let right_indices: UInt64Array = chunk.buffered_indices.finish();
@@ -1541,7 +1337,7 @@ impl SortMergeJoinStream {
                             &right_indices,
                         )?;
 
-                        get_filter_column(&self.filter, &left_columns, &right_cols)
+                        get_filter_columns(&self.filter, &left_columns, &right_cols)
                     } else if matches!(
                         self.join_type,
                         JoinType::RightAnti | JoinType::RightSemi | JoinType::RightMark
@@ -1552,12 +1348,12 @@ impl SortMergeJoinStream {
                             &right_indices,
                         )?;
 
-                        get_filter_column(&self.filter, &right_cols, &left_columns)
+                        get_filter_columns(&self.filter, &right_cols, &left_columns)
                     } else {
-                        get_filter_column(&self.filter, &left_columns, &right_columns)
+                        get_filter_columns(&self.filter, &left_columns, &right_columns)
                     }
                 } else {
-                    get_filter_column(&self.filter, &right_columns, &left_columns)
+                    get_filter_columns(&self.filter, &right_columns, &left_columns)
                 }
             } else {
                 // This chunk is totally for null joined rows (outer join), we don't need to apply join filter.
@@ -1679,11 +1475,13 @@ impl SortMergeJoinStream {
 
     fn filter_joined_batch(&mut self) -> Result<RecordBatch> {
         // Metadata should be aligned before processing
-        self.joined_record_batches.debug_assert_metadata_aligned();
+        self.joined_record_batches
+            .filter_metadata
+            .debug_assert_metadata_aligned();
 
         let record_batch = self.joined_record_batches.concat_batches(&self.schema)?;
         let (mut out_indices, mut out_mask, mut batch_ids) =
-            self.joined_record_batches.finish_metadata();
+            self.joined_record_batches.filter_metadata.finish_metadata();
         let default_batch_ids = vec![0; record_batch.num_rows()];
 
         // If only nulls come in and indices sizes doesn't match with expected record batch count
@@ -1754,139 +1552,14 @@ impl SortMergeJoinStream {
         record_batch: &RecordBatch,
         corrected_mask: &BooleanArray,
     ) -> Result<RecordBatch> {
-        // Corrected mask should have length matching or exceeding record_batch rows
-        // (for outer joins it may be longer to include null-joined rows)
-        debug_assert!(
-            corrected_mask.len() >= record_batch.num_rows(),
-            "corrected_mask length ({}) should be >= record_batch rows ({})",
-            corrected_mask.len(),
-            record_batch.num_rows()
-        );
-
-        let mut filtered_record_batch =
-            filter_record_batch(record_batch, corrected_mask)?;
-        let left_columns_length = self.streamed_schema.fields.len();
-        let right_columns_length = self.buffered_schema.fields.len();
-
-        if matches!(
+        let filtered_record_batch = filter_record_batch_by_join_type(
+            record_batch,
+            corrected_mask,
             self.join_type,
-            JoinType::Left | JoinType::LeftMark | JoinType::Right | JoinType::RightMark
-        ) {
-            let null_mask = compute::not(corrected_mask)?;
-            let null_joined_batch = filter_record_batch(record_batch, &null_mask)?;
-
-            let mut right_columns = create_unmatched_columns(
-                self.join_type,
-                &self.buffered_schema,
-                null_joined_batch.num_rows(),
-            );
-
-            let columns = match self.join_type {
-                JoinType::Right => {
-                    // The first columns are the right columns.
-                    let left_columns = null_joined_batch
-                        .columns()
-                        .iter()
-                        .skip(right_columns_length)
-                        .cloned()
-                        .collect::<Vec<_>>();
-
-                    right_columns.extend(left_columns);
-                    right_columns
-                }
-                JoinType::Left | JoinType::LeftMark | JoinType::RightMark => {
-                    // The first columns are the left columns.
-                    let mut left_columns = null_joined_batch
-                        .columns()
-                        .iter()
-                        .take(left_columns_length)
-                        .cloned()
-                        .collect::<Vec<_>>();
-
-                    left_columns.extend(right_columns);
-                    left_columns
-                }
-                _ => exec_err!("Did not expect join type {}", self.join_type)?,
-            };
-
-            // Push the streamed/buffered batch joined nulls to the output
-            let null_joined_streamed_batch =
-                RecordBatch::try_new(Arc::clone(&self.schema), columns)?;
-
-            filtered_record_batch = concat_batches(
-                &self.schema,
-                &[filtered_record_batch, null_joined_streamed_batch],
-            )?;
-        } else if matches!(
-            self.join_type,
-            JoinType::LeftSemi
-                | JoinType::LeftAnti
-                | JoinType::RightAnti
-                | JoinType::RightSemi
-        ) {
-            let output_column_indices = (0..left_columns_length).collect::<Vec<_>>();
-            filtered_record_batch =
-                filtered_record_batch.project(&output_column_indices)?;
-        } else if matches!(self.join_type, JoinType::Full)
-            && corrected_mask.false_count() > 0
-        {
-            // Find rows which joined by key but Filter predicate evaluated as false
-            let joined_filter_not_matched_mask = compute::not(corrected_mask)?;
-            let joined_filter_not_matched_batch =
-                filter_record_batch(record_batch, &joined_filter_not_matched_mask)?;
-
-            // Add left unmatched rows adding the right side as nulls
-            let right_null_columns = self
-                .buffered_schema
-                .fields()
-                .iter()
-                .map(|f| {
-                    new_null_array(
-                        f.data_type(),
-                        joined_filter_not_matched_batch.num_rows(),
-                    )
-                })
-                .collect::<Vec<_>>();
-
-            let mut result_joined = joined_filter_not_matched_batch
-                .columns()
-                .iter()
-                .take(left_columns_length)
-                .cloned()
-                .collect::<Vec<_>>();
-
-            result_joined.extend(right_null_columns);
-
-            let left_null_joined_batch =
-                RecordBatch::try_new(Arc::clone(&self.schema), result_joined)?;
-
-            // Add right unmatched rows adding the left side as nulls
-            let mut result_joined = self
-                .streamed_schema
-                .fields()
-                .iter()
-                .map(|f| {
-                    new_null_array(
-                        f.data_type(),
-                        joined_filter_not_matched_batch.num_rows(),
-                    )
-                })
-                .collect::<Vec<_>>();
-
-            let right_data = joined_filter_not_matched_batch
-                .columns()
-                .iter()
-                .skip(left_columns_length)
-                .cloned()
-                .collect::<Vec<_>>();
-
-            result_joined.extend(right_data);
-
-            filtered_record_batch = concat_batches(
-                &self.schema,
-                &[filtered_record_batch, left_null_joined_batch],
-            )?;
-        }
+            &self.schema,
+            &self.streamed_schema,
+            &self.buffered_schema,
+        )?;
 
         self.joined_record_batches
             .clear(&self.schema, self.batch_size);
@@ -1911,36 +1584,6 @@ fn create_unmatched_columns(
     }
 }
 
-/// Gets the arrays which join filters are applied on.
-fn get_filter_column(
-    join_filter: &Option<JoinFilter>,
-    streamed_columns: &[ArrayRef],
-    buffered_columns: &[ArrayRef],
-) -> Vec<ArrayRef> {
-    let mut filter_columns = vec![];
-
-    if let Some(f) = join_filter {
-        let left_columns = f
-            .column_indices()
-            .iter()
-            .filter(|col_index| col_index.side == JoinSide::Left)
-            .map(|i| Arc::clone(&streamed_columns[i.index]))
-            .collect::<Vec<_>>();
-
-        let right_columns = f
-            .column_indices()
-            .iter()
-            .filter(|col_index| col_index.side == JoinSide::Right)
-            .map(|i| Arc::clone(&buffered_columns[i.index]))
-            .collect::<Vec<_>>();
-
-        filter_columns.extend(left_columns);
-        filter_columns.extend(right_columns);
-    }
-
-    filter_columns
-}
-
 fn produce_buffered_null_batch(
     schema: &SchemaRef,
     streamed_schema: &SchemaRef,
@@ -1970,6 +1613,30 @@ fn produce_buffered_null_batch(
     )?))
 }
 
+/// Checks if a `UInt64Array` contains a contiguous ascending range (e.g. \[3,4,5,6\]).
+/// Returns `Some(start..start+len)` if so, `None` otherwise.
+/// This allows replacing an O(n) `take` with an O(1) `slice`.
+#[inline]
+fn is_contiguous_range(indices: &UInt64Array) -> Option<Range<usize>> {
+    if indices.is_empty() || indices.null_count() > 0 {
+        return None;
+    }
+    let values = indices.values();
+    let start = values[0];
+    let len = values.len() as u64;
+    // Quick rejection: if last element doesn't match expected, not contiguous
+    if values[values.len() - 1] != start + len - 1 {
+        return None;
+    }
+    // Verify every element is sequential (handles duplicates and gaps)
+    for i in 1..values.len() {
+        if values[i] != start + i as u64 {
+            return None;
+        }
+    }
+    Some(start as usize..(start + len) as usize)
+}
+
 /// Get `buffered_indices` rows for `buffered_data[buffered_batch_idx]` by specific column indices
 #[inline(always)]
 fn fetch_right_columns_by_idxs(
@@ -1990,12 +1657,16 @@ fn fetch_right_columns_from_batch_by_idxs(
 ) -> Result<Vec<ArrayRef>> {
     match &buffered_batch.batch {
         // In memory batch
-        BufferedBatchState::InMemory(batch) => Ok(batch
-            .columns()
-            .iter()
-            .map(|column| take(column, &buffered_indices, None))
-            .collect::<Result<Vec<_>, ArrowError>>()
-            .map_err(Into::<DataFusionError>::into)?),
+        // In memory batch
+        BufferedBatchState::InMemory(batch) => {
+            // When indices form a contiguous range (common in SMJ since the
+            // buffered side is scanned sequentially), use zero-copy slice.
+            if let Some(range) = is_contiguous_range(buffered_indices) {
+                Ok(batch.slice(range.start, range.len()).columns().to_vec())
+            } else {
+                Ok(take_arrays(batch.columns(), buffered_indices, None)?)
+            }
+        }
         // If the batch was spilled to disk, less likely
         BufferedBatchState::Spilled(spill_file) => {
             let mut buffered_cols: Vec<ArrayRef> =
diff --git a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
index d0bcc79636..4329abdd52 100644
--- a/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
+++ b/datafusion/physical-plan/src/joins/sort_merge_join/tests.rs
@@ -24,42 +24,44 @@
 //!
 //! Add relevant tests under the specified sections.
 
-use std::sync::Arc;
-
+use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
+use crate::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec};
+use crate::test::TestMemoryExec;
+use crate::test::exec::BarrierExec;
+use crate::test::{build_table_i32, build_table_i32_two_cols};
+use crate::{ExecutionPlan, common};
+use crate::{
+    expressions::Column, joins::sort_merge_join::filter::get_corrected_filter_mask,
+    joins::sort_merge_join::stream::JoinedRecordBatches,
+};
 use arrow::array::{
     BinaryArray, BooleanArray, Date32Array, Date64Array, FixedSizeBinaryArray,
     Int32Array, RecordBatch, UInt64Array,
-    builder::{BooleanBuilder, UInt64Builder},
 };
 use arrow::compute::{BatchCoalescer, SortOptions, filter_record_batch};
 use arrow::datatypes::{DataType, Field, Schema};
-
+use arrow_ord::sort::SortColumn;
+use arrow_schema::SchemaRef;
 use datafusion_common::JoinType::*;
 use datafusion_common::{
-    JoinSide,
+    JoinSide, internal_err,
     test_util::{batches_to_sort_string, batches_to_string},
 };
 use datafusion_common::{
     JoinType, NullEquality, Result, assert_batches_eq, assert_contains,
 };
-use datafusion_execution::TaskContext;
+use datafusion_common_runtime::JoinSet;
 use datafusion_execution::config::SessionConfig;
 use datafusion_execution::disk_manager::{DiskManagerBuilder, DiskManagerMode};
 use datafusion_execution::runtime_env::RuntimeEnvBuilder;
+use datafusion_execution::{SendableRecordBatchStream, TaskContext};
 use datafusion_expr::Operator;
 use datafusion_physical_expr::expressions::BinaryExpr;
+use futures::StreamExt;
 use insta::{allow_duplicates, assert_snapshot};
-
-use crate::{
-    expressions::Column,
-    joins::sort_merge_join::stream::{JoinedRecordBatches, get_corrected_filter_mask},
-};
-
-use crate::joins::SortMergeJoinExec;
-use crate::joins::utils::{ColumnIndex, JoinFilter, JoinOn};
-use crate::test::TestMemoryExec;
-use crate::test::{build_table_i32, build_table_i32_two_cols};
-use crate::{ExecutionPlan, common};
+use itertools::Itertools;
+use std::sync::Arc;
+use std::task::Poll;
 
 fn build_table(
     a: (&str, &Vec<i32>),
@@ -2375,9 +2377,7 @@ fn build_joined_record_batches() -> Result<JoinedRecordBatches> {
 
     let mut batches = JoinedRecordBatches {
         joined_batches: BatchCoalescer::new(Arc::clone(&schema), 8192),
-        filter_mask: BooleanBuilder::new(),
-        row_indices: UInt64Builder::new(),
-        batch_ids: vec![],
+        filter_metadata: crate::joins::sort_merge_join::filter::FilterMetadata::new(),
     };
 
     // Insert already prejoined non-filtered rows
@@ -2432,44 +2432,73 @@ fn build_joined_record_batches() -> 
Result<JoinedRecordBatches> {
     )?)?;
 
     let streamed_indices = vec![0, 0];
-    batches.batch_ids.extend(vec![0; streamed_indices.len()]);
     batches
+        .filter_metadata
+        .batch_ids
+        .extend(vec![0; streamed_indices.len()]);
+    batches
+        .filter_metadata
         .row_indices
         .extend(&UInt64Array::from(streamed_indices));
 
     let streamed_indices = vec![1];
-    batches.batch_ids.extend(vec![0; streamed_indices.len()]);
     batches
+        .filter_metadata
+        .batch_ids
+        .extend(vec![0; streamed_indices.len()]);
+    batches
+        .filter_metadata
         .row_indices
         .extend(&UInt64Array::from(streamed_indices));
 
     let streamed_indices = vec![0, 0];
-    batches.batch_ids.extend(vec![1; streamed_indices.len()]);
     batches
+        .filter_metadata
+        .batch_ids
+        .extend(vec![1; streamed_indices.len()]);
+    batches
+        .filter_metadata
         .row_indices
         .extend(&UInt64Array::from(streamed_indices));
 
     let streamed_indices = vec![0];
-    batches.batch_ids.extend(vec![2; streamed_indices.len()]);
     batches
+        .filter_metadata
+        .batch_ids
+        .extend(vec![2; streamed_indices.len()]);
+    batches
+        .filter_metadata
         .row_indices
         .extend(&UInt64Array::from(streamed_indices));
 
     let streamed_indices = vec![0, 0];
-    batches.batch_ids.extend(vec![3; streamed_indices.len()]);
     batches
+        .filter_metadata
+        .batch_ids
+        .extend(vec![3; streamed_indices.len()]);
+    batches
+        .filter_metadata
         .row_indices
         .extend(&UInt64Array::from(streamed_indices));
 
     batches
+        .filter_metadata
         .filter_mask
         .extend(&BooleanArray::from(vec![true, false]));
-    batches.filter_mask.extend(&BooleanArray::from(vec![true]));
     batches
+        .filter_metadata
+        .filter_mask
+        .extend(&BooleanArray::from(vec![true]));
+    batches
+        .filter_metadata
         .filter_mask
         .extend(&BooleanArray::from(vec![false, true]));
-    batches.filter_mask.extend(&BooleanArray::from(vec![false]));
     batches
+        .filter_metadata
+        .filter_mask
+        .extend(&BooleanArray::from(vec![false]));
+    batches
+        .filter_metadata
         .filter_mask
         .extend(&BooleanArray::from(vec![false, false]));
 
@@ -2482,8 +2511,8 @@ async fn test_left_outer_join_filtered_mask() -> 
Result<()> {
     let schema = joined_batches.joined_batches.schema();
 
     let output = joined_batches.concat_batches(&schema)?;
-    let out_mask = joined_batches.filter_mask.finish();
-    let out_indices = joined_batches.row_indices.finish();
+    let out_mask = joined_batches.filter_metadata.filter_mask.finish();
+    let out_indices = joined_batches.filter_metadata.row_indices.finish();
 
     assert_eq!(
         get_corrected_filter_mask(
@@ -2620,7 +2649,7 @@ async fn test_left_outer_join_filtered_mask() -> 
Result<()> {
     let corrected_mask = get_corrected_filter_mask(
         Left,
         &out_indices,
-        &joined_batches.batch_ids,
+        &joined_batches.filter_metadata.batch_ids,
         &out_mask,
         output.num_rows(),
     )
@@ -2689,8 +2718,8 @@ async fn test_semi_join_filtered_mask() -> Result<()> {
         let schema = joined_batches.joined_batches.schema();
 
         let output = joined_batches.concat_batches(&schema)?;
-        let out_mask = joined_batches.filter_mask.finish();
-        let out_indices = joined_batches.row_indices.finish();
+        let out_mask = joined_batches.filter_metadata.filter_mask.finish();
+        let out_indices = joined_batches.filter_metadata.row_indices.finish();
 
         assert_eq!(
             get_corrected_filter_mask(
@@ -2791,7 +2820,7 @@ async fn test_semi_join_filtered_mask() -> Result<()> {
         let corrected_mask = get_corrected_filter_mask(
             join_type,
             &out_indices,
-            &joined_batches.batch_ids,
+            &joined_batches.filter_metadata.batch_ids,
             &out_mask,
             output.num_rows(),
         )
@@ -2864,8 +2893,8 @@ async fn test_anti_join_filtered_mask() -> Result<()> {
         let schema = joined_batches.joined_batches.schema();
 
         let output = joined_batches.concat_batches(&schema)?;
-        let out_mask = joined_batches.filter_mask.finish();
-        let out_indices = joined_batches.row_indices.finish();
+        let out_mask = joined_batches.filter_metadata.filter_mask.finish();
+        let out_indices = joined_batches.filter_metadata.row_indices.finish();
 
         assert_eq!(
             get_corrected_filter_mask(
@@ -2966,7 +2995,7 @@ async fn test_anti_join_filtered_mask() -> Result<()> {
         let corrected_mask = get_corrected_filter_mask(
             join_type,
             &out_indices,
-            &joined_batches.batch_ids,
+            &joined_batches.filter_metadata.batch_ids,
             &out_mask,
             output.num_rows(),
         )
@@ -3104,6 +3133,419 @@ fn test_partition_statistics() -> Result<()> {
     Ok(())
 }
 
+fn build_batches(
+    a: (&str, &[Vec<bool>]),
+    b: (&str, &[Vec<i32>]),
+    c: (&str, &[Vec<i32>]),
+) -> (Vec<RecordBatch>, SchemaRef) {
+    assert_eq!(a.1.len(), b.1.len());
+    let mut batches = vec![];
+
+    let schema = Arc::new(Schema::new(vec![
+        Field::new(a.0, DataType::Boolean, false),
+        Field::new(b.0, DataType::Int32, false),
+        Field::new(c.0, DataType::Int32, false),
+    ]));
+
+    for i in 0..a.1.len() {
+        batches.push(
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(BooleanArray::from(a.1[i].clone())),
+                    Arc::new(Int32Array::from(b.1[i].clone())),
+                    Arc::new(Int32Array::from(c.1[i].clone())),
+                ],
+            )
+            .unwrap(),
+        );
+    }
+    let schema = batches[0].schema();
+    (batches, schema)
+}
+
+fn build_batched_finish_barrier_table(
+    a: (&str, &[Vec<bool>]),
+    b: (&str, &[Vec<i32>]),
+    c: (&str, &[Vec<i32>]),
+) -> (Arc<BarrierExec>, Arc<TestMemoryExec>) {
+    let (batches, schema) = build_batches(a, b, c);
+
+    let memory_exec = TestMemoryExec::try_new_exec(
+        std::slice::from_ref(&batches),
+        Arc::clone(&schema),
+        None,
+    )
+    .unwrap();
+
+    let barrier_exec = Arc::new(
+        BarrierExec::new(vec![batches], schema)
+            .with_log(false)
+            .without_start_barrier()
+            .with_finish_barrier(),
+    );
+
+    (barrier_exec, memory_exec)
+}
+
+/// Concat and sort batches by all the columns so that the outputs of 
different join implementations can be compared
+fn prepare_record_batches_for_cmp(output: Vec<RecordBatch>) -> RecordBatch {
+    let output_batch = arrow::compute::concat_batches(output[0].schema_ref(), 
&output)
+        .expect("failed to concat batches");
+
+    // Sort on all columns to make sure we have a deterministic order for the 
assertion
+    let sort_columns = output_batch
+        .columns()
+        .iter()
+        .map(|c| SortColumn {
+            values: Arc::clone(c),
+            options: None,
+        })
+        .collect::<Vec<_>>();
+
+    let sorted_columns =
+        arrow::compute::lexsort(&sort_columns, None).expect("failed to sort");
+
+    RecordBatch::try_new(output_batch.schema(), sorted_columns)
+        .expect("failed to create batch")
+}
+
+#[expect(clippy::too_many_arguments)]
+async fn join_get_stream_and_get_expected(
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    oracle_left: Arc<dyn ExecutionPlan>,
+    oracle_right: Arc<dyn ExecutionPlan>,
+    on: JoinOn,
+    join_type: JoinType,
+    filter: Option<JoinFilter>,
+    batch_size: usize,
+) -> Result<(SendableRecordBatchStream, RecordBatch)> {
+    let sort_options = vec![SortOptions::default(); on.len()];
+    let null_equality = NullEquality::NullEqualsNothing;
+    let task_ctx = Arc::new(
+        TaskContext::default()
+            
.with_session_config(SessionConfig::default().with_batch_size(batch_size)),
+    );
+
+    let expected_output = {
+        let oracle = HashJoinExec::try_new(
+            oracle_left,
+            oracle_right,
+            on.clone(),
+            filter.clone(),
+            &join_type,
+            None,
+            PartitionMode::Partitioned,
+            null_equality,
+        )?;
+
+        let stream = oracle.execute(0, Arc::clone(&task_ctx))?;
+
+        let batches = common::collect(stream).await?;
+
+        prepare_record_batches_for_cmp(batches)
+    };
+
+    let join = SortMergeJoinExec::try_new(
+        left,
+        right,
+        on,
+        filter,
+        join_type,
+        sort_options,
+        null_equality,
+    )?;
+
+    let stream = join.execute(0, task_ctx)?;
+
+    Ok((stream, expected_output))
+}
+
+fn generate_data_for_emit_early_test(
+    batch_size: usize,
+    number_of_batches: usize,
+    join_type: JoinType,
+) -> (
+    Arc<BarrierExec>,
+    Arc<BarrierExec>,
+    Arc<TestMemoryExec>,
+    Arc<TestMemoryExec>,
+) {
+    let number_of_rows_per_batch = number_of_batches * batch_size;
+    // Prepare data
+    let left_a1 = (0..number_of_rows_per_batch as i32)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let left_b1 = (0..1000000)
+        .filter(|item| {
+            match join_type {
+                LeftAnti | RightAnti => {
+                    let remainder = item % (batch_size as i32);
+
+                    // Ensure there is one that matches and one that doesn't
+                    remainder == 0 || remainder == 1
+                }
+                // Have at least 1 that is not matching
+                _ => item % batch_size as i32 != 0,
+            }
+        })
+        .take(number_of_rows_per_batch)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+
+    let left_bool_col1 = left_a1
+        .clone()
+        .into_iter()
+        .map(|b| {
+            b.into_iter()
+                // Mostly true but have some false that not overlap with the 
right column
+                .map(|a| a % (batch_size as i32) != (batch_size as i32) - 2)
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+
+    let (left, left_memory) = build_batched_finish_barrier_table(
+        ("bool_col1", left_bool_col1.as_slice()),
+        ("b1", left_b1.as_slice()),
+        ("a1", left_a1.as_slice()),
+    );
+
+    let right_a2 = (0..number_of_rows_per_batch as i32)
+        .map(|item| item * 11)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let right_b1 = (0..1000000)
+        .filter(|item| {
+            match join_type {
+                LeftAnti | RightAnti => {
+                    let remainder = item % (batch_size as i32);
+
+                    // Ensure there is one that matches and one that doesn't
+                    remainder == 1 || remainder == 2
+                }
+                // Have at least 1 that is not matching
+                _ => item % batch_size as i32 != 1,
+            }
+        })
+        .take(number_of_rows_per_batch)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let right_bool_col2 = right_a2
+        .clone()
+        .into_iter()
+        .map(|b| {
+            b.into_iter()
+                // Mostly true but have some false that not overlap with the 
left column
+                .map(|a| a % (batch_size as i32) != (batch_size as i32) - 1)
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+
+    let (right, right_memory) = build_batched_finish_barrier_table(
+        ("bool_col2", right_bool_col2.as_slice()),
+        ("b1", right_b1.as_slice()),
+        ("a2", right_a2.as_slice()),
+    );
+
+    (left, right, left_memory, right_memory)
+}
+
+#[tokio::test]
+async fn test_should_emit_early_when_have_enough_data_to_emit() -> Result<()> {
+    for with_filtering in [false, true] {
+        let join_types = vec![
+            Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, 
RightMark,
+        ];
+        const BATCH_SIZE: usize = 10;
+        for join_type in join_types {
+            for output_batch_size in [
+                BATCH_SIZE / 3,
+                BATCH_SIZE / 2,
+                BATCH_SIZE,
+                BATCH_SIZE * 2,
+                BATCH_SIZE * 3,
+            ] {
+                // Make sure the number of batches is enough for all join type 
to emit some output
+                let number_of_batches = if output_batch_size <= BATCH_SIZE {
+                    100
+                } else {
+                    // Have enough batches
+                    (output_batch_size * 100) / BATCH_SIZE
+                };
+
+                let (left, right, left_memory, right_memory) =
+                    generate_data_for_emit_early_test(
+                        BATCH_SIZE,
+                        number_of_batches,
+                        join_type,
+                    );
+
+                let on = vec![(
+                    Arc::new(Column::new_with_schema("b1", &left.schema())?) 
as _,
+                    Arc::new(Column::new_with_schema("b1", &right.schema())?) 
as _,
+                )];
+
+                let join_filter = if with_filtering {
+                    let filter = JoinFilter::new(
+                        Arc::new(BinaryExpr::new(
+                            Arc::new(Column::new("bool_col1", 0)),
+                            Operator::And,
+                            Arc::new(Column::new("bool_col2", 1)),
+                        )),
+                        vec![
+                            ColumnIndex {
+                                index: 0,
+                                side: JoinSide::Left,
+                            },
+                            ColumnIndex {
+                                index: 0,
+                                side: JoinSide::Right,
+                            },
+                        ],
+                        Arc::new(Schema::new(vec![
+                            Field::new("bool_col1", DataType::Boolean, true),
+                            Field::new("bool_col2", DataType::Boolean, true),
+                        ])),
+                    );
+                    Some(filter)
+                } else {
+                    None
+                };
+
+                // select *
+                // from t1
+                // right join t2 on t1.b1 = t2.b1 and t1.bool_col1 AND 
t2.bool_col2
+                let (mut output_stream, expected) = 
join_get_stream_and_get_expected(
+                    Arc::clone(&left) as Arc<dyn ExecutionPlan>,
+                    Arc::clone(&right) as Arc<dyn ExecutionPlan>,
+                    left_memory as Arc<dyn ExecutionPlan>,
+                    right_memory as Arc<dyn ExecutionPlan>,
+                    on,
+                    join_type,
+                    join_filter,
+                    output_batch_size,
+                )
+                .await?;
+
+                let (output_batched, output_batches_after_finish) =
+                  consume_stream_until_finish_barrier_reached(left, right, 
&mut output_stream).await.unwrap_or_else(|e| panic!("Failed to consume stream 
for join type: '{join_type}' and with filtering '{with_filtering}': {e:?}"));
+
+                // It should emit more than that, but we are being generous
+                // so the test passes for all tested combinations
+                const MINIMUM_OUTPUT_BATCHES: usize = 5;
+                assert!(
+                    MINIMUM_OUTPUT_BATCHES <= number_of_batches / 5,
+                    "Make sure that the minimum output batches is realistic"
+                );
+                // Test to make sure that we are not waiting for input to be 
fully consumed to emit some output
+                assert!(
+                    output_batched.len() >= MINIMUM_OUTPUT_BATCHES,
+                    "[Sort Merge Join {join_type}] Stream must have at least 
emit {} batches, but only got {} batches",
+                    MINIMUM_OUTPUT_BATCHES,
+                    output_batched.len()
+                );
+
+                // Just sanity test to make sure we are still producing valid 
output
+                {
+                    let output = [output_batched, 
output_batches_after_finish].concat();
+                    let actual_prepared = 
prepare_record_batches_for_cmp(output);
+
+                    assert_eq!(actual_prepared.columns(), expected.columns());
+                }
+            }
+        }
+    }
+    Ok(())
+}
+
+/// Polls the stream until it ends, collecting the emitted batches split by
+/// whether they arrived before or after both finish barriers were reached.
+///
+/// If the stream stays pending for too long (5s) without emitting a batch,
+/// it returns an internal error to avoid hanging the test indefinitely.
+///
+/// Note: The left and right BarrierExec might be the input of the output 
stream
+async fn consume_stream_until_finish_barrier_reached(
+    left: Arc<BarrierExec>,
+    right: Arc<BarrierExec>,
+    output_stream: &mut SendableRecordBatchStream,
+) -> Result<(Vec<RecordBatch>, Vec<RecordBatch>)> {
+    let mut switch_to_finish_barrier = false;
+    let mut output_batched = vec![];
+    let mut after_finish_barrier_reached = vec![];
+    let mut background_task = JoinSet::new();
+
+    let mut start_time_since_last_ready = 
datafusion_common::instant::Instant::now();
+    loop {
+        let next_item = output_stream.next();
+
+        // Manual polling
+        let poll_output = futures::poll!(next_item);
+
+        // Wake up the stream to make sure it makes progress
+        tokio::task::yield_now().await;
+
+        match poll_output {
+            Poll::Ready(Some(Ok(batch))) => {
+                if batch.num_rows() == 0 {
+                    return internal_err!("join stream should not emit empty 
batch");
+                }
+                if switch_to_finish_barrier {
+                    after_finish_barrier_reached.push(batch);
+                } else {
+                    output_batched.push(batch);
+                }
+                start_time_since_last_ready = 
datafusion_common::instant::Instant::now();
+            }
+            Poll::Ready(Some(Err(e))) => return Err(e),
+            Poll::Ready(None) if !switch_to_finish_barrier => {
+                unreachable!("Stream should not end before manually finishing 
it")
+            }
+            Poll::Ready(None) => {
+                break;
+            }
+            Poll::Pending => {
+                if right.is_finish_barrier_reached()
+                    && left.is_finish_barrier_reached()
+                    && !switch_to_finish_barrier
+                {
+                    switch_to_finish_barrier = true;
+
+                    let right = Arc::clone(&right);
+                    background_task.spawn(async move {
+                        right.wait_finish().await;
+                    });
+                    let left = Arc::clone(&left);
+                    background_task.spawn(async move {
+                        left.wait_finish().await;
+                    });
+                }
+
+                // Make sure the test doesn't run forever
+                if start_time_since_last_ready.elapsed()
+                    > std::time::Duration::from_secs(5)
+                {
+                    return internal_err!(
+                        "Stream should have emitted data by now, but it's 
still pending. Output batches so far: {}",
+                        output_batched.len()
+                    );
+                }
+            }
+        }
+    }
+
+    Ok((output_batched, after_finish_barrier_reached))
+}
+
 /// Returns the column names on the schema
 fn columns(schema: &Schema) -> Vec<String> {
     schema.fields().iter().map(|f| f.name().clone()).collect()
diff --git a/datafusion/physical-plan/src/test/exec.rs 
b/datafusion/physical-plan/src/test/exec.rs
index 4507cccba0..c229318415 100644
--- a/datafusion/physical-plan/src/test/exec.rs
+++ b/datafusion/physical-plan/src/test/exec.rs
@@ -17,13 +17,6 @@
 
 //! Simple iterator over batches for use in testing
 
-use std::{
-    any::Any,
-    pin::Pin,
-    sync::{Arc, Weak},
-    task::{Context, Poll},
-};
-
 use crate::{
     DisplayAs, DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties,
     RecordBatchStream, SendableRecordBatchStream, Statistics, common,
@@ -33,6 +26,13 @@ use crate::{
     execution_plan::EmissionType,
     stream::{RecordBatchReceiverStream, RecordBatchStreamAdapter},
 };
+use std::sync::atomic::{AtomicUsize, Ordering};
+use std::{
+    any::Any,
+    pin::Pin,
+    sync::{Arc, Weak},
+    task::{Context, Poll},
+};
 
 use arrow::datatypes::{DataType, Field, Schema, SchemaRef};
 use arrow::record_batch::RecordBatch;
@@ -298,29 +298,91 @@ pub struct BarrierExec {
     schema: SchemaRef,
 
     /// all streams wait on this barrier to produce
-    barrier: Arc<Barrier>,
+    start_data_barrier: Option<Arc<Barrier>>,
+
+    /// each stream waits on this before returning `Poll::Ready(None)`
+    finish_barrier: Option<Arc<(Barrier, AtomicUsize)>>,
+
     cache: PlanProperties,
+
+    log: bool,
 }
 
 impl BarrierExec {
     /// Create a new exec with some number of partitions.
     pub fn new(data: Vec<Vec<RecordBatch>>, schema: SchemaRef) -> Self {
         // wait for all streams and the input
-        let barrier = Arc::new(Barrier::new(data.len() + 1));
+        let barrier = Some(Arc::new(Barrier::new(data.len() + 1)));
         let cache = Self::compute_properties(Arc::clone(&schema), &data);
         Self {
             data,
             schema,
-            barrier,
+            start_data_barrier: barrier,
             cache,
+            finish_barrier: None,
+            log: true,
         }
     }
 
+    pub fn with_log(mut self, log: bool) -> Self {
+        self.log = log;
+        self
+    }
+
+    pub fn without_start_barrier(mut self) -> Self {
+        self.start_data_barrier = None;
+        self
+    }
+
+    pub fn with_finish_barrier(mut self) -> Self {
+        let barrier = Arc::new((
+            // wait for all streams and the input
+            Barrier::new(self.data.len() + 1),
+            AtomicUsize::new(0),
+        ));
+
+        self.finish_barrier = Some(barrier);
+        self
+    }
+
     /// wait until all the input streams and this function is ready
     pub async fn wait(&self) {
-        println!("BarrierExec::wait waiting on barrier");
-        self.barrier.wait().await;
-        println!("BarrierExec::wait done waiting");
+        let barrier = &self
+            .start_data_barrier
+            .as_ref()
+            .expect("Must only be called when having a start barrier");
+        if self.log {
+            println!("BarrierExec::wait waiting on barrier");
+        }
+        barrier.wait().await;
+        if self.log {
+            println!("BarrierExec::wait done waiting");
+        }
+    }
+
+    pub async fn wait_finish(&self) {
+        let (barrier, _) = &self
+            .finish_barrier
+            .as_deref()
+            .expect("Must only be called when having a finish barrier");
+
+        if self.log {
+            println!("BarrierExec::wait_finish waiting on barrier");
+        }
+        barrier.wait().await;
+        if self.log {
+            println!("BarrierExec::wait_finish done waiting");
+        }
+    }
+
+    /// Return true if the finish barrier has been reached in all partitions
+    pub fn is_finish_barrier_reached(&self) -> bool {
+        let (_, reached_finish) = self
+            .finish_barrier
+            .as_deref()
+            .expect("Must only be called when having finish barrier");
+
+        reached_finish.load(Ordering::Relaxed) == self.data.len()
     }
 
     /// This function creates the cache object that stores the plan properties 
such as schema, equivalence properties, ordering, partitioning, etc.
@@ -391,17 +453,32 @@ impl ExecutionPlan for BarrierExec {
 
         // task simply sends data in order after barrier is reached
         let data = self.data[partition].clone();
-        let b = Arc::clone(&self.barrier);
+        let start_barrier = self.start_data_barrier.as_ref().map(Arc::clone);
+        let finish_barrier = self.finish_barrier.as_ref().map(Arc::clone);
+        let log = self.log;
         let tx = builder.tx();
         builder.spawn(async move {
-            println!("Partition {partition} waiting on barrier");
-            b.wait().await;
+            if let Some(barrier) = start_barrier {
+                if log {
+                    println!("Partition {partition} waiting on barrier");
+                }
+                barrier.wait().await;
+            }
             for batch in data {
-                println!("Partition {partition} sending batch");
+                if log {
+                    println!("Partition {partition} sending batch");
+                }
                 if let Err(e) = tx.send(Ok(batch)).await {
                     println!("ERROR batch via barrier stream stream: {e}");
                 }
             }
+            if let Some((barrier, reached_finish)) = finish_barrier.as_deref() 
{
+                if log {
+                    println!("Partition {partition} waiting on finish 
barrier");
+                }
+                reached_finish.fetch_add(1, Ordering::Relaxed);
+                barrier.wait().await;
+            }
 
             Ok(())
         });


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]


Reply via email to