EmilyMatt commented on code in PR #20482:
URL: https://github.com/apache/datafusion/pull/20482#discussion_r2843182828
##########
datafusion/physical-plan/src/joins/sort_merge_join/tests.rs:
##########
@@ -3130,6 +3133,420 @@ fn test_partition_statistics() -> Result<()> {
Ok(())
}
+/// Builds one `RecordBatch` per chunk from three parallel column inputs: a
+/// non-nullable `Boolean` column followed by two non-nullable `Int32` columns.
+///
+/// Each argument is `(column_name, chunks)`; chunk `i` of every column becomes
+/// batch `i`. Returns the batches together with their shared schema. The schema
+/// is returned even when no chunks are supplied.
+///
+/// # Panics
+///
+/// Panics if the three columns do not have the same number of chunks, or if
+/// the chunks at the same index have different lengths (batch creation fails).
+fn build_batches(
+    a: (&str, &[Vec<bool>]),
+    b: (&str, &[Vec<i32>]),
+    c: (&str, &[Vec<i32>]),
+) -> (Vec<RecordBatch>, SchemaRef) {
+    assert_eq!(
+        a.1.len(),
+        b.1.len(),
+        "columns `a` and `b` must have the same number of chunks"
+    );
+    assert_eq!(
+        a.1.len(),
+        c.1.len(),
+        "columns `a` and `c` must have the same number of chunks"
+    );
+
+    // Every batch has the same layout, so build the schema once and share it.
+    let schema: SchemaRef = Arc::new(Schema::new(vec![
+        Field::new(a.0, DataType::Boolean, false),
+        Field::new(b.0, DataType::Int32, false),
+        Field::new(c.0, DataType::Int32, false),
+    ]));
+
+    let batches: Vec<RecordBatch> = a
+        .1
+        .iter()
+        .zip(b.1)
+        .zip(c.1)
+        .map(|((a_values, b_values), c_values)| {
+            RecordBatch::try_new(
+                Arc::clone(&schema),
+                vec![
+                    Arc::new(BooleanArray::from(a_values.clone())),
+                    Arc::new(Int32Array::from(b_values.clone())),
+                    Arc::new(Int32Array::from(c_values.clone())),
+                ],
+            )
+            .unwrap()
+        })
+        .collect();
+
+    (batches, schema)
+}
+
+/// Builds the batches described by `a`, `b` and `c` (see `build_batches`) and
+/// exposes them as two single-partition sources over identical data:
+/// - a `BarrierExec` configured with `with_log(false)`,
+///   `without_start_barrier()` and `with_finish_barrier()`, used as the input
+///   under test;
+/// - a `TestMemoryExec` with no projection, used as the oracle input.
+fn build_batched_finish_barrier_table(
+    a: (&str, &[Vec<bool>]),
+    b: (&str, &[Vec<i32>]),
+    c: (&str, &[Vec<i32>]),
+) -> (Arc<BarrierExec>, Arc<TestMemoryExec>) {
+    let (partition, schema) = build_batches(a, b, c);
+
+    // Oracle side: a plain in-memory source holding the single partition.
+    let memory = TestMemoryExec::try_new_exec(
+        std::slice::from_ref(&partition),
+        Arc::clone(&schema),
+        None,
+    )
+    .unwrap();
+
+    // Test side: the same partition, wrapped with a finish barrier.
+    let barrier = BarrierExec::new(vec![partition], schema)
+        .with_log(false)
+        .without_start_barrier()
+        .with_finish_barrier();
+
+    (Arc::new(barrier), memory)
+}
+
+/// Concatenates `output` into a single batch and sorts it lexicographically by
+/// every column, so results produced by different join implementations can be
+/// compared regardless of the order in which their rows were emitted.
+///
+/// # Panics
+///
+/// Panics if `output` is empty (the schema is taken from the first batch), or
+/// if concatenating or sorting the batches fails.
+fn prepare_record_batches_for_cmp(output: Vec<RecordBatch>) -> RecordBatch {
+    let first = output
+        .first()
+        .expect("at least one batch is required to determine the schema");
+    let output_batch = arrow::compute::concat_batches(first.schema_ref(), &output)
+        .expect("failed to concat batches");
+
+    // Sort on all columns to make sure we have a deterministic order for the assertion
+    let sort_columns = output_batch
+        .columns()
+        .iter()
+        .map(|c| SortColumn {
+            values: Arc::clone(c),
+            options: None,
+        })
+        .collect::<Vec<_>>();
+
+    let sorted_columns =
+        arrow::compute::lexsort(&sort_columns, None).expect("failed to sort");
+
+    RecordBatch::try_new(output_batch.schema(), sorted_columns)
+        .expect("failed to create batch")
+}
+
+/// Builds a `SortMergeJoinExec` over `left`/`right` and returns its output
+/// stream for partition 0 (not yet polled), along with the expected result.
+///
+/// The expected result comes from running a `HashJoinExec` with the same `on`,
+/// `filter` and `join_type` over `oracle_left`/`oracle_right` to completion. Its
+/// output is then normalized with `prepare_record_batches_for_cmp`.
+///
+/// Both joins share one task context whose session batch size is `batch_size`.
+/// Both use `NullEquality::NullEqualsNothing`. The sort-merge join gets one
+/// `SortOptions::default()` per join key.
+#[expect(clippy::too_many_arguments)]
+async fn join_get_stream_and_get_expected(
+    left: Arc<dyn ExecutionPlan>,
+    right: Arc<dyn ExecutionPlan>,
+    oracle_left: Arc<dyn ExecutionPlan>,
+    oracle_right: Arc<dyn ExecutionPlan>,
+    on: JoinOn,
+    join_type: JoinType,
+    filter: Option<JoinFilter>,
+    batch_size: usize,
+) -> Result<(SendableRecordBatchStream, RecordBatch)> {
+    let null_equality = NullEquality::NullEqualsNothing;
+    let session_config = SessionConfig::default().with_batch_size(batch_size);
+    let task_ctx = Arc::new(TaskContext::default().with_session_config(session_config));
+
+    // Run the hash join oracle to completion and normalize its output.
+    let oracle = HashJoinExec::try_new(
+        oracle_left,
+        oracle_right,
+        on.clone(),
+        filter.clone(),
+        &join_type,
+        None,
+        PartitionMode::Partitioned,
+        null_equality,
+        false,
+    )?;
+    let oracle_batches =
+        common::collect(oracle.execute(0, Arc::clone(&task_ctx))?).await?;
+    let expected_output = prepare_record_batches_for_cmp(oracle_batches);
+
+    // Build the join under test; the caller is responsible for polling its stream.
+    let sort_options = vec![SortOptions::default(); on.len()];
+    let join = SortMergeJoinExec::try_new(
+        left,
+        right,
+        on,
+        filter,
+        join_type,
+        sort_options,
+        null_equality,
+    )?;
+    let stream = join.execute(0, task_ctx)?;
+
+    Ok((stream, expected_output))
+}
+
+/// Generates the left and right inputs for the "emit early" test.
+///
+/// Each side has `number_of_batches` batches of `batch_size` rows. Each side is
+/// returned twice: as a `BarrierExec` (input for the sort-merge join) and as a
+/// `TestMemoryExec` (input for the hash join oracle).
+///
+/// Columns per side:
+/// - `b1`: the ascending join key.
+///   - For `LeftAnti`/`RightAnti`, only values whose remainder modulo
+///     `batch_size` is 0 or 1 (left) or 1 or 2 (right) are kept, so some keys
+///     match and some don't.
+///   - For every other join type, the left side skips remainder 0 and the
+///     right side skips remainder 1, so some keys have no match.
+/// - `a1` (left): `0, 1, 2, ...`; `a2` (right): `0, 11, 22, ...`.
+/// - `bool_col1` / `bool_col2`: mostly `true`. A value is `false` where
+///   `a % batch_size` equals `batch_size - 2` (left) or `batch_size - 1`
+///   (right), so the `false` values do not overlap.
+///
+/// NOTE(review): the key filters assume `batch_size >= 3` so that every
+/// remainder used above actually occurs — confirm if smaller sizes are needed.
+fn generate_data_for_emit_early_test(
+    batch_size: usize,
+    number_of_batches: usize,
+    join_type: JoinType,
+) -> (
+    Arc<BarrierExec>,
+    Arc<BarrierExec>,
+    Arc<TestMemoryExec>,
+    Arc<TestMemoryExec>,
+) {
+    // Total number of rows on each side (across all batches).
+    let total_rows = number_of_batches * batch_size;
+    let modulus = batch_size as i32;
+    let is_anti = matches!(join_type, LeftAnti | RightAnti);
+
+    // Left side
+    let left_a1 = chunk_into_batches(0..total_rows as i32, batch_size);
+    // `0..` is unbounded; `take` stops it once enough keys have been produced.
+    let left_b1 = chunk_into_batches(
+        (0..)
+            .filter(|item: &i32| {
+                let remainder = item % modulus;
+                if is_anti {
+                    // Make sure to have keys that match and keys that don't
+                    remainder == 0 || remainder == 1
+                } else {
+                    // Have at least one key that does not match
+                    remainder != 0
+                }
+            })
+            .take(total_rows),
+        batch_size,
+    );
+    let left_bool_col1 = left_a1
+        .iter()
+        .map(|batch| {
+            batch
+                .iter()
+                // Mostly true, with some false values that don't overlap with the right column
+                .map(|&a| a % modulus != modulus - 2)
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+
+    let (left, left_memory) = build_batched_finish_barrier_table(
+        ("bool_col1", left_bool_col1.as_slice()),
+        ("b1", left_b1.as_slice()),
+        ("a1", left_a1.as_slice()),
+    );
+
+    // Right side
+    let right_a2 =
+        chunk_into_batches((0..total_rows as i32).map(|item| item * 11), batch_size);
+    let right_b1 = chunk_into_batches(
+        (0..)
+            .filter(|item: &i32| {
+                let remainder = item % modulus;
+                if is_anti {
+                    // Make sure to have keys that match and keys that don't
+                    remainder == 1 || remainder == 2
+                } else {
+                    // Have at least one key that does not match
+                    remainder != 1
+                }
+            })
+            .take(total_rows),
+        batch_size,
+    );
+    let right_bool_col2 = right_a2
+        .iter()
+        .map(|batch| {
+            batch
+                .iter()
+                // Mostly true, with some false values that don't overlap with the left column
+                .map(|&a| a % modulus != modulus - 1)
+                .collect::<Vec<_>>()
+        })
+        .collect::<Vec<_>>();
+
+    let (right, right_memory) = build_batched_finish_barrier_table(
+        ("bool_col2", right_bool_col2.as_slice()),
+        ("b1", right_b1.as_slice()),
+        ("a2", right_a2.as_slice()),
+    );
+
+    (left, right, left_memory, right_memory)
+}
+
+/// Splits `values` into consecutive `Vec`s of at most `batch_size` items each.
+fn chunk_into_batches<T>(values: impl Iterator<Item = T>, batch_size: usize) -> Vec<Vec<T>> {
+    values
+        .chunks(batch_size)
+        .into_iter()
+        .map(|chunk| chunk.collect::<Vec<_>>())
+        .collect::<Vec<_>>()
+}
+
+/// Checks that the sort-merge join emits output before its inputs have been
+/// fully consumed.
+///
+/// The test covers every listed join type, with and without a join filter,
+/// and several output batch sizes. Each run also checks the complete output
+/// against a hash join oracle.
+#[tokio::test]
+async fn test_should_emit_early_when_have_enough_data_to_emit() -> Result<()> {
+    // Number of rows in each generated input batch.
+    const BATCH_SIZE: usize = 10;
+    // It should emit more than that, but we are being generous
+    // to make sure the test passes for all join types.
+    const MINIMUM_OUTPUT_BATCHES: usize = 5;
+
+    let join_types = [
+        Inner, Left, Right, RightSemi, Full, LeftSemi, LeftAnti, LeftMark, RightMark,
+    ];
+
+    for with_filtering in [false, true] {
+        for join_type in join_types {
+            for output_batch_size in [
+                BATCH_SIZE / 3,
+                BATCH_SIZE / 2,
+                BATCH_SIZE,
+                BATCH_SIZE * 2,
+                BATCH_SIZE * 3,
+            ] {
+                // Make sure the number of input batches is enough for every join type
+                // to emit some output
+                let number_of_batches = if output_batch_size <= BATCH_SIZE {
+                    100
+                } else {
+                    // Scale the input with the output batch size
+                    (output_batch_size * 100) / BATCH_SIZE
+                };
+                assert!(
+                    MINIMUM_OUTPUT_BATCHES <= number_of_batches / 5,
+                    "Make sure that the minimum output batches is realistic"
+                );
+
+                let (left, right, left_memory, right_memory) =
+                    generate_data_for_emit_early_test(BATCH_SIZE, number_of_batches, join_type);
+
+                let on = vec![(
+                    Arc::new(Column::new_with_schema("b1", &left.schema())?) as _,
+                    Arc::new(Column::new_with_schema("b1", &right.schema())?) as _,
+                )];
+
+                // Optional extra join condition: `t1.bool_col1 AND t2.bool_col2`
+                let join_filter = with_filtering.then(|| {
+                    JoinFilter::new(
+                        Arc::new(BinaryExpr::new(
+                            Arc::new(Column::new("bool_col1", 0)),
+                            Operator::And,
+                            Arc::new(Column::new("bool_col2", 1)),
+                        )),
+                        vec![
+                            ColumnIndex {
+                                index: 0,
+                                side: JoinSide::Left,
+                            },
+                            ColumnIndex {
+                                index: 0,
+                                side: JoinSide::Right,
+                            },
+                        ],
+                        Arc::new(Schema::new(vec![
+                            Field::new("bool_col1", DataType::Boolean, true),
+                            Field::new("bool_col2", DataType::Boolean, true),
+                        ])),
+                    )
+                });
+
+                // select *
+                // from t1
+                // <join_type> join t2 on t1.b1 = t2.b1 [and t1.bool_col1 AND t2.bool_col2]
+                let (mut output_stream, expected) = join_get_stream_and_get_expected(
+                    Arc::clone(&left) as Arc<dyn ExecutionPlan>,
+                    Arc::clone(&right) as Arc<dyn ExecutionPlan>,
+                    left_memory as Arc<dyn ExecutionPlan>,
+                    right_memory as Arc<dyn ExecutionPlan>,
+                    on,
+                    join_type,
+                    join_filter,
+                    // Run the join with the output batch size under test
+                    output_batch_size,
+                )
+                .await?;
+
+                let (output_batched, output_batches_after_finish) =
+                    consume_stream_until_finish_barrier_reached(left, right, &mut output_stream)
+                        .await
+                        .unwrap_or_else(|e| {
+                            panic!(
+                                "Failed to consume stream for join type: '{join_type}' and with filtering '{with_filtering}': {e:?}"
+                            )
+                        });
+
+                // Make sure we do not wait for the input to be fully consumed before
+                // emitting output
+                assert!(
+                    output_batched.len() >= MINIMUM_OUTPUT_BATCHES,
+                    "[Sort Merge Join {join_type}, filtering: {with_filtering}, output batch size: {output_batch_size}] Stream must have emitted at least {} batches before the finish barrier, but only got {} batches",
+                    MINIMUM_OUTPUT_BATCHES,
+                    output_batched.len()
+                );
+
+                // Sanity check that the output is still correct
+                let output = [output_batched, output_batches_after_finish].concat();
+                let actual_prepared = prepare_record_batches_for_cmp(output);
+                assert_eq!(
+                    actual_prepared.columns(),
+                    expected.columns(),
+                    "[Sort Merge Join {join_type}, filtering: {with_filtering}, output batch size: {output_batch_size}] output differs from the hash join oracle"
+                );
+            }
+        }
+    }
+    Ok(())
+}
+
+/// Polls the stream until both barriers are reached,
+/// collecting the emitted batches along the way.
+///
+/// If the stream is pending for too long (5s) without emitting any batches,
+/// it panics to avoid hanging the test indefinitely.
+///
+/// Note: The left and right BarrierExec might be the input of the output
stream
+async fn consume_stream_until_finish_barrier_reached(
Review Comment:
This whole function should be implemented as a manual future that is
awaited. That lets you poll the stream directly and avoids the concurrency
pitfalls of `poll!` + `yield`, whose interleaving is somewhat unreliable.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]