Weijun-H commented on code in PR #20482:
URL: https://github.com/apache/datafusion/pull/20482#discussion_r2844830181


##########
datafusion/physical-plan/src/joins/sort_merge_join/tests.rs:
##########
@@ -3130,6 +3133,420 @@ fn test_partition_statistics() -> Result<()> {
     Ok(())
 }
 
+fn build_batches(
+    a: (&str, &[Vec<bool>]),
+    b: (&str, &[Vec<i32>]),
+    c: (&str, &[Vec<i32>]),
+) -> (Vec<RecordBatch>, SchemaRef) {
+    assert_eq!(a.1.len(), b.1.len());
+    let mut batches = vec![];
+
+    for i in 0..a.1.len() {
+        let schema = Schema::new(vec![
+            Field::new(a.0, DataType::Boolean, false),
+            Field::new(b.0, DataType::Int32, false),
+            Field::new(c.0, DataType::Int32, false),
+        ]);
+
+        batches.push(
+            RecordBatch::try_new(
+                Arc::new(schema),
+                vec![
+                    Arc::new(BooleanArray::from(a.1[i].clone())),
+                    Arc::new(Int32Array::from(b.1[i].clone())),
+                    Arc::new(Int32Array::from(c.1[i].clone())),
+                ],
+            )
+            .unwrap(),
+        );
+    }
+    let schema = batches[0].schema();
+    (batches, schema)
+}
+
+fn build_batched_finish_barrier_table(
+    a: (&str, &[Vec<bool>]),
+    b: (&str, &[Vec<i32>]),
+    c: (&str, &[Vec<i32>]),
+) -> (Arc<BarrierExec>, Arc<TestMemoryExec>) {
+    let (batches, schema) = build_batches(a, b, c);
+
+    let memory_exec = TestMemoryExec::try_new_exec(
+        std::slice::from_ref(&batches),
+        Arc::clone(&schema),
+        None,
+    )
+    .unwrap();
+
+    let barrier_exec = Arc::new(
+        BarrierExec::new(vec![batches], schema)
+            .with_log(false)
+            .without_start_barrier()
+            .with_finish_barrier(),
+    );
+
+    (barrier_exec, memory_exec)
+}
+
+/// Concat and sort batches by all the columns to make sure we can compare them with different join
+fn prepare_record_batches_for_cmp(output: Vec<RecordBatch>) -> RecordBatch {
+    let output_batch = arrow::compute::concat_batches(output[0].schema_ref(), &output)
+        .expect("failed to concat batches");
+
+    // Sort on all columns to make sure we have a deterministic order for the assertion
+    let sort_columns = output_batch
+        .columns()
+        .iter()
+        .map(|c| SortColumn {
+            values: Arc::clone(c),
+            options: None,
+        })
+        .collect::<Vec<_>>();
+
+    let sorted_columns =
+        arrow::compute::lexsort(&sort_columns, None).expect("failed to sort");
+
+    RecordBatch::try_new(output_batch.schema(), sorted_columns)
+        .expect("failed to create batch")
+}
+
+#[expect(clippy::too_many_arguments)]
+async fn join_get_stream_and_get_expected(
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    oracle_left: Arc<dyn ExecutionPlan>,
+    oracle_right: Arc<dyn ExecutionPlan>,
+    on: JoinOn,
+    join_type: JoinType,
+    filter: Option<JoinFilter>,
+    batch_size: usize,
+) -> Result<(SendableRecordBatchStream, RecordBatch)> {
+    let sort_options = vec![SortOptions::default(); on.len()];
+    let null_equality = NullEquality::NullEqualsNothing;
+    let task_ctx = Arc::new(
+        TaskContext::default()
+            .with_session_config(SessionConfig::default().with_batch_size(batch_size)),
+    );
+
+    let expected_output = {
+        let oracle = HashJoinExec::try_new(
+            oracle_left,
+            oracle_right,
+            on.clone(),
+            filter.clone(),
+            &join_type,
+            None,
+            PartitionMode::Partitioned,
+            null_equality,
+            false,
+        )?;
+
+        let stream = oracle.execute(0, Arc::clone(&task_ctx))?;
+
+        let batches = common::collect(stream).await?;
+
+        prepare_record_batches_for_cmp(batches)
+    };
+
+    let join = SortMergeJoinExec::try_new(
+        left,
+        right,
+        on,
+        filter,
+        join_type,
+        sort_options,
+        null_equality,
+    )?;
+
+    let stream = join.execute(0, task_ctx)?;
+
+    Ok((stream, expected_output))
+}
+
+fn generate_data_for_emit_early_test(
+    batch_size: usize,
+    number_of_batches: usize,
+    join_type: JoinType,
+) -> (
+    Arc<BarrierExec>,
+    Arc<BarrierExec>,
+    Arc<TestMemoryExec>,
+    Arc<TestMemoryExec>,
+) {
+    let number_of_rows_per_batch = number_of_batches * batch_size;
+    // Prepare data
+    let left_a1 = (0..number_of_rows_per_batch as i32)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let left_b1 = (0..1000000)
+        .filter(|item| {
+            match join_type {
+                LeftAnti | RightAnti => {
+                    let remainder = item % (batch_size as i32);
+
+                    // Make sure to have one that match and one that don't
+                    remainder == 0 || remainder == 1
+                }
+                // Have at least 1 that is not matching
+                _ => item % batch_size as i32 != 0,
+            }
+        })
+        .take(number_of_rows_per_batch)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+
+    let left_bool_col1 = left_a1
+        .clone()
+        .into_iter()
+        .map(|b| {
+            b.into_iter()
+                // Mostly true but have some false that not overlap with the right column
+                .map(|a| a % (batch_size as i32) != (batch_size as i32) - 2)
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+
+    let (left, left_memory) = build_batched_finish_barrier_table(
+        ("bool_col1", left_bool_col1.as_slice()),
+        ("b1", left_b1.as_slice()),
+        ("a1", left_a1.as_slice()),
+    );
+
+    let right_a2 = (0..number_of_rows_per_batch as i32)
+        .map(|item| item * 11)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let right_b1 = (0..1000000)
+        .filter(|item| {
+            match join_type {
+                LeftAnti | RightAnti => {
+                    let remainder = item % (batch_size as i32);
+
+                    // Make sure to have one that match and one that don't
+                    remainder == 1 || remainder == 2
+                }
+                // Have at least 1 that is not matching
+                _ => item % batch_size as i32 != 1,
+            }
+        })
+        .take(number_of_rows_per_batch)
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>();
+    let right_bool_col2 = right_a2
+        .clone()
+        .into_iter()
+        .map(|b| {
+            b.into_iter()
+                // Mostly true but have some false that not overlap with the left column
+                .map(|a| a % (batch_size as i32) != (batch_size as i32) - 1)
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+
+    let (right, right_memory) = build_batched_finish_barrier_table(
+        ("bool_col2", right_bool_col2.as_slice()),
+        ("b1", right_b1.as_slice()),
+        ("a2", right_a2.as_slice()),
+    );
+
+    (left, right, left_memory, right_memory)
+}
+
+#[tokio::test]
+async fn test_should_emit_early_when_have_enough_data_to_emit() -> Result<()> {
+    for with_filtering in [false, true] {
+        let join_types = vec![
+            Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark,
+        ];
+        const BATCH_SIZE: usize = 10;
+        for join_type in join_types {
+            for output_batch_size in [
+                BATCH_SIZE / 3,
+                BATCH_SIZE / 2,
+                BATCH_SIZE,
+                BATCH_SIZE * 2,
+                BATCH_SIZE * 3,
+            ] {
+                // Make sure the number of batches is enough for all join type to emit some output
+                let number_of_batches = if output_batch_size <= BATCH_SIZE {
+                    100
+                } else {
+                    // Have enough batches
+                    (output_batch_size * 100) / BATCH_SIZE
+                };
+
+                let (left, right, left_memory, right_memory) =
+                    generate_data_for_emit_early_test(
+                        BATCH_SIZE,
+                        number_of_batches,
+                        join_type,
+                    );
+
+                let on = vec![(
+                    Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                    Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+                )];
+
+                let join_filter = if with_filtering {
+                    let filter = JoinFilter::new(
+                        Arc::new(BinaryExpr::new(
+                            Arc::new(Column::new("bool_col1", 0)),
+                            Operator::And,
+                            Arc::new(Column::new("bool_col2", 1)),
+                        )),
+                        vec![
+                            ColumnIndex {
+                                index: 0,
+                                side: JoinSide::Left,
+                            },
+                            ColumnIndex {
+                                index: 0,
+                                side: JoinSide::Right,
+                            },
+                        ],
+                        Arc::new(Schema::new(vec![
+                            Field::new("bool_col1", DataType::Boolean, true),
+                            Field::new("bool_col2", DataType::Boolean, true),
+                        ])),
+                    );
+                    Some(filter)
+                } else {
+                    None
+                };
+
+                // select *
+                // from t1
+                // right join t2 on t1.b1 = t2.b1 and t1.bool_col1 AND t2.bool_col2
+                let (mut output_stream, expected) = join_get_stream_and_get_expected(
+                    Arc::clone(&left) as Arc<dyn ExecutionPlan>,
+                    Arc::clone(&right) as Arc<dyn ExecutionPlan>,
+                    left_memory as Arc<dyn ExecutionPlan>,
+                    right_memory as Arc<dyn ExecutionPlan>,
+                    on,
+                    join_type,
+                    join_filter,
+                    BATCH_SIZE,

Review Comment:
   I think the test intent and the actual coverage are slightly mismatched here.
   
   `output_batch_size` is used, but it currently only affects `number_of_batches`
   (i.e. total input size / test workload). The join itself is still executed with
   the fixed `BATCH_SIZE` constant, so every `output_batch_size` iteration runs the
   join with the same configured batch size.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to