alamb commented on code in PR #16398: URL: https://github.com/apache/datafusion/pull/16398#discussion_r2152534664
########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -912,25 +948,7 @@ impl PlanProperties { /// 2. CoalescePartitionsExec for collapsing all of the partitions into one without ordering guarantee /// 3. SortPreservingMergeExec for collapsing all of the sorted partitions into one with ordering guarantee pub fn need_data_exchange(plan: Arc<dyn ExecutionPlan>) -> bool { - if let Some(repartition) = plan.as_any().downcast_ref::<RepartitionExec>() { - !matches!( - repartition.properties().output_partitioning(), - Partitioning::RoundRobinBatch(_) - ) - } else if let Some(coalesce) = plan.as_any().downcast_ref::<CoalescePartitionsExec>() - { - coalesce.input().output_partitioning().partition_count() > 1 - } else if let Some(sort_preserving_merge) = - plan.as_any().downcast_ref::<SortPreservingMergeExec>() - { - sort_preserving_merge - .input() - .output_partitioning() - .partition_count() - > 1 - } else { - false - } + plan.properties().evaluation_type == EvaluationType::Lazy Review Comment: that is certainly a lot nicer and easier to understand ########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -743,6 +733,38 @@ pub enum EmissionType { Both, } +/// Represents whether an operator's `Stream` has been implemented to actively cooperate with the +/// Tokio scheduler or not. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SchedulingType { + /// The stream generated by [`execute`](ExecutionPlan::execute) does not actively participate in + /// cooperative scheduling. This means the implementation of the `Stream` returned by + /// [`ExecutionPlan::execute`] does not contain explicit cooperative yield points. + Blocking, + /// The stream generated by [`execute`](ExecutionPlan::execute) actively participates in + /// cooperative scheduling by consuming task budget when it was able to produce a + /// [`RecordBatch`]. Please refer to the [`coop`](crate::coop) module for more details. + Cooperative, +} + +/// Represents how an operator's `Stream` implementation generates `RecordBatch`es. +/// +/// Most operators in DataFusion generate `RecordBatch`es when asked to do so by a call to +/// `Stream::poll_next`. This is known as demand-driven or lazy evaluation. +/// +/// Some operators like `Repartition` need to drive `RecordBatch` generation themselves though. This +/// is known as data-driven or eager evaluation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EvaluationType { + /// The stream generated by [`execute`](ExecutionPlan::execute) only generates `RecordBatch` + /// instances when it is demanded by invoking `Stream::poll_next`. + Lazy, + /// The stream generated by [`execute`](ExecutionPlan::execute) eagerly generates `RecordBatch` + /// in one or more spawned Tokio tasks. Eager evaluation is only started the first time + /// `Stream::poll_next` is called. Review Comment: ```suggestion /// `Stream::poll_next` is called. Hash aggregation and HashJoin are examples of such operators ``` ########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -743,6 +733,38 @@ pub enum EmissionType { Both, } +/// Represents whether an operator's `Stream` has been implemented to actively cooperate with the +/// Tokio scheduler or not. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SchedulingType { + /// The stream generated by [`execute`](ExecutionPlan::execute) does not actively participate in + /// cooperative scheduling. 
This means the implementation of the `Stream` returned by + /// [`ExecutionPlan::execute`] does not contain explicit cooperative yield points. Review Comment: I think it would also be helpful here to give an explicit example of what a cooperative yield point is. For example ```suggestion /// [`ExecutionPlan::execute`] does not contain explicit cooperative yield points such as /// `await` or [`tokio::task::yield_now`]. ``` ########## datafusion/core/tests/execution/coop.rs: ########## @@ -0,0 +1,747 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Int64Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::SortOptions; +use datafusion::functions_aggregate::sum; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; +use datafusion::physical_plan; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion::physical_plan::execution_plan::Boundedness; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{DataFusionError, JoinType, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr_common::operator::Operator; +use datafusion_expr_common::operator::Operator::{Divide, Eq, Gt, Modulo}; +use datafusion_functions_aggregate::min_max; +use datafusion_physical_expr::expressions::{ + binary, col, lit, BinaryExpr, Column, Literal, +}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::union::InterleaveExec; +use futures::StreamExt; +use parking_lot::RwLock; +use rstest::rstest; +use std::error::Error; +use std::fmt::Formatter; +use std::ops::Range; +use std::sync::Arc; +use std::task::Poll; +use std::time::Duration; +use tokio::runtime::{Handle, Runtime}; +use tokio::select; + +#[derive(Debug)] +struct RangeBatchGenerator { + schema: SchemaRef, + value_range: Range<i64>, + boundedness: Boundedness, + batch_size: usize, + poll_count: usize, +} + +impl std::fmt::Display for RangeBatchGenerator {
+ fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + // Display current counter + write!(f, "InfiniteGenerator(counter={})", self.poll_count) + } +} + +impl LazyBatchGenerator for RangeBatchGenerator { + fn boundedness(&self) -> Boundedness { + self.boundedness + } + + /// Generate the next RecordBatch. + fn generate_next_batch(&mut self) -> datafusion_common::Result<Option<RecordBatch>> { + self.poll_count += 1; + + let mut builder = Int64Array::builder(self.batch_size); + for _ in 0..self.batch_size { + match self.value_range.next() { + None => break, + Some(v) => builder.append_value(v), + } + } + let array = builder.finish(); + + if array.is_empty() { + return Ok(None); + } + + let batch = + RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; + Ok(Some(batch)) + } +} + +fn make_lazy_exec(column_name: &str, pretend_infinite: bool) -> LazyMemoryExec { + make_lazy_exec_with_range(column_name, i64::MIN..i64::MAX, pretend_infinite) +} + +fn make_lazy_exec_with_range( + column_name: &str, + range: Range<i64>, + pretend_infinite: bool, +) -> LazyMemoryExec { + let schema = Arc::new(Schema::new(vec![Field::new( + column_name, + DataType::Int64, + false, + )])); + + let boundedness = if pretend_infinite { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + Boundedness::Bounded + }; + + // Instantiate the generator with the batch and limit + let gen = RangeBatchGenerator { + schema: Arc::clone(&schema), + boundedness, + value_range: range, + batch_size: 8192, + poll_count: 0, + }; + + // Wrap the generator in a trait object behind Arc<RwLock<_>> + let generator: Arc<RwLock<dyn LazyBatchGenerator>> = Arc::new(RwLock::new(gen)); + + // Create a LazyMemoryExec with one partition using our generator + let mut exec = LazyMemoryExec::try_new(schema, vec![generator]).unwrap(); + + exec.add_ordering(vec![PhysicalSortExpr::new( + Arc::new(Column::new(column_name, 0)), + SortOptions::new(false, true), + )]); + + exec +} + +#[rstest] +#[tokio::test] +async fn agg_no_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation without grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new( + AggregateExprBuilder::new( + sum::sum_udaf(), + vec![col("value", &inf.schema())?], + ) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation with grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary(value_col.clone(), Divide, lit(1000000i64), &inf.schema())?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![]), + vec![Arc::new( + AggregateExprBuilder::new(sum::sum_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + 
+#[rstest] +#[tokio::test] +async fn agg_grouped_topk_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up a top-k aggregation + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary(value_col.clone(), Divide, lit(1000000i64), &inf.schema())?; + + let aggr = Arc::new( + AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![(group, "group".to_string())], + vec![], + vec![vec![false]], + ), + vec![Arc::new( + AggregateExprBuilder::new(min_max::max_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("max") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )? + .with_limit(Some(100)), + ); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn sort_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a SortExec that will not be able to finish in time because input is very large + let sort_expr = PhysicalSortExpr::new( + col("value", &inf.schema())?, + SortOptions { + descending: true, + nulls_first: true, + }, + ); + + let lex_ordering = LexOrdering::new(vec![sort_expr]).unwrap(); + let sort_exec = Arc::new(SortExec::new(lex_ordering, inf.clone())); + + query_yields(sort_exec, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn sort_merge_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec_with_range( + "value1", + i64::MIN..0, + pretend_infinite, + )); + let inf2 = Arc::new(make_lazy_exec_with_range( + "value2", + 0..i64::MAX, + pretend_infinite, + )); + + // set up a SortMergeJoinExec that will take a long time skipping left side content to find + // the first right side match + let join = Arc::new(SortMergeJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + JoinType::Inner, + vec![inf1.properties().eq_properties.output_ordering().unwrap()[0].options], + true, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a FilterExec that will filter out entire batches + let filter_expr = binary( + col("value", &inf.schema())?, + Operator::Lt, + lit(i64::MIN), + &inf.schema(), + )?; + let filter = Arc::new(FilterExec::try_new(filter_expr, inf.clone())?); + + query_yields(filter, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_reject_all_batches_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // Create a Session, Schema, and an 8K-row RecordBatch + let session_ctx = SessionContext::new(); + + // Wrap this batch in an InfiniteExec + let infinite = make_lazy_exec_with_range("value", i64::MIN..0, pretend_infinite); + + // 2b) Construct 
a FilterExec that is always false: “value > 10000” (no rows pass) + let false_predicate = Arc::new(BinaryExpr::new( + Arc::new(Column::new("value", 0)), + Gt, + Arc::new(Literal::new(ScalarValue::Int64(Some(0)))), + )); + let filtered = Arc::new(FilterExec::try_new(false_predicate, Arc::new(infinite))?); + + // Use CoalesceBatchesExec to guarantee each Filter pull always yields an 8192-row batch + let coalesced = Arc::new(CoalesceBatchesExec::new(filtered, 8_192)); + + query_yields(coalesced, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +#[ignore = "Fails unless using Tokio based coop implementation"] +async fn interleave_then_filter_all_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // Build a session and a schema with one i64 column. + let session_ctx = SessionContext::new(); + + // Create multiple infinite sources, each filtered by a different threshold. + // This ensures InterleaveExec has many children. + let mut infinite_children = vec![]; + + // Use 32 distinct thresholds (each >0 and <8 192) to force 32 infinite inputs + for thr in 1..32 { + // One infinite exec: + let mut inf = make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Now repartition so that all children share identical Hash partitioning + // on “value” into 1 bucket. This is required for InterleaveExec::try_new. + let exprs = vec![Arc::new(Column::new("value", 0)) as _]; + let partitioning = Partitioning::Hash(exprs, 1); + inf.try_set_partitioning(partitioning)?; + + // Apply a FilterExec: “(value / 8192) % thr == 0”. + let filter_expr = binary( + binary( + binary( + col("value", &inf.schema())?, + Divide, + lit(8192i64), + &inf.schema(), + )?, + Modulo, + lit(thr as i64), + &inf.schema(), + )?, + Eq, + lit(0i64), + &inf.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?); + + infinite_children.push(filtered as _); + } + + // Build an InterleaveExec over all infinite children. + let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?); + + // Wrap the InterleaveExec in a FilterExec that always returns false, + // ensuring that no rows are ever emitted. + let always_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false)))); + let filtered_interleave = Arc::new(FilterExec::try_new(always_false, interleave)?); + + query_yields(filtered_interleave, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +#[ignore = "Fails unless using Tokio based coop implementation"] +async fn interleave_then_aggregate_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // Build session, schema, and a sample batch. + let session_ctx = SessionContext::new(); + + // Create N infinite sources, each filtered by a different predicate. + // That way, the InterleaveExec will have multiple children. + let mut infinite_children = vec![]; + + // Use 32 distinct thresholds (each >0 and <8 192) to force 32 infinite inputs + for thr in 1..32 { + // One infinite exec: + let mut inf = make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Now repartition so that all children share identical Hash partitioning + // on “value” into 1 bucket. This is required for InterleaveExec::try_new. + let exprs = vec![Arc::new(Column::new("value", 0)) as _]; + let partitioning = Partitioning::Hash(exprs, 1); + inf.try_set_partitioning(partitioning)?; + + // Apply a FilterExec: “(value / 8192) % thr == 0”. 
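+ // (Values are consecutive and each batch holds 8192 rows, so `value / 8192` is in effect a batch index; each child keeps only every `thr`-th batch and rejects the rest outright.)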
+ let filter_expr = binary( + binary( + binary( + col("value", &inf.schema())?, + Divide, + lit(8192i64), + &inf.schema(), + )?, + Modulo, + lit(thr as i64), + &inf.schema(), + )?, + Eq, + lit(0i64), + &inf.schema(), + )?; + let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?); + + infinite_children.push(filtered as _); + } + + // Build an InterleaveExec over all N children. + // Since each child now has Partitioning::Hash([col "value"], 1), InterleaveExec::try_new succeeds. + let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?); + let interleave_schema = interleave.schema(); + + // Build a global AggregateExec that sums “value” over all rows. + // Because we use `AggregateMode::Single` with no GROUP BY columns, this plan will + // only produce one “final” row once all inputs finish. But our inputs never finish, + // so we should never get any output. + let aggregate_expr = AggregateExprBuilder::new( + sum::sum_udaf(), + vec![Arc::new(Column::new("value", 0))], + ) + .schema(interleave_schema.clone()) + .alias("total") + .build()?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![], // no GROUP BY columns + vec![], // no GROUP BY expressions + vec![], // no GROUP BY physical expressions + ), + vec![Arc::new(aggregate_expr)], + vec![None], // no “distinct” flags + interleave, + interleave_schema, + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // Session, schema, and a single 8 K‐row batch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) 
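+ // (With the ranges used here, -10..10 on the left and 0..i64::MAX on the right, the keys 0 through 9 appear on both sides.)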
+ let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Create Join keys → join on “value” = “value” + let left_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))]; + let right_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))]; + + // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition + let coalesced_left = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); + let coalesced_right = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); + + let part_left = Partitioning::Hash(left_keys, 1); + let part_right = Partitioning::Hash(right_keys, 1); + + let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); + let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + + // Build an Inner HashJoinExec → left.value = right.value + let join = Arc::new(HashJoinExec::try_new( + hashed_left, + hashed_right, + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + true, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn join_agg_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // Session, schema, and a single 8 K‐row batch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // 2b) Create Join keys → join on “value” = “value” + let left_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))]; + let right_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))]; + + // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition + let coalesced_left = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192)); + let coalesced_right = + Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192)); + + let part_left = Partitioning::Hash(left_keys, 1); + let part_right = Partitioning::Hash(right_keys, 1); + + let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?); + let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?); + + // Build an Inner HashJoinExec → left.value = right.value + let join = Arc::new(HashJoinExec::try_new( + hashed_left, + hashed_right, + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + None, + &JoinType::Inner, + None, + PartitionMode::CollectLeft, + true, + )?); + + // Project only one column (“value” from the left side) because we just want to sum that + let input_schema = join.schema(); + + let proj_expr = vec![( + Arc::new(Column::new_with_schema("value", &input_schema)?) 
as _, + "value".to_string(), + )]; + + let projection = Arc::new(ProjectionExec::try_new(proj_expr, join)?); + let projection_schema = projection.schema(); + + let output_fields = vec![Field::new("total", DataType::Int64, true)]; + let output_schema = Arc::new(Schema::new(output_fields)); + + // 4) Global aggregate (Single) over “value” + let aggregate_expr = AggregateExprBuilder::new( + sum::sum_udaf(), + vec![Arc::new(Column::new_with_schema( + "value", + &projection.schema(), + )?)], + ) + .schema(output_schema) + .alias("total") + .build()?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new(aggregate_expr)], + vec![None], + projection, + projection_schema, + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn hash_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec("value1", pretend_infinite)); + let inf2 = Arc::new(make_lazy_exec("value2", pretend_infinite)); + + // set up a HashJoinExec that will take a long time in the build phase + let join = Arc::new(HashJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + &JoinType::Left, + None, + PartitionMode::CollectLeft, + true, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn hash_join_without_repartition_and_no_agg( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // Create Session, schema, and an 8K-row RecordBatch for each side + let session_ctx = SessionContext::new(); + + // on the right side, we’ll shift each value by +1 so that not everything joins, + // but plenty of matching keys exist (e.g. 0 on left matches 1 on right, etc.) + let infinite_left = make_lazy_exec_with_range("value", -10..10, false); + let infinite_right = + make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite); + + // Directly feed `infinite_left` and `infinite_right` into HashJoinExec. + // Do not use aggregation or repartition. + let join = Arc::new(HashJoinExec::try_new( + Arc::new(infinite_left), + Arc::new(infinite_right), + vec![( + Arc::new(Column::new("value", 0)), + Arc::new(Column::new("value", 0)), + )], + /* filter */ None, + &JoinType::Inner, + /* output64 */ None, + // Using CollectLeft is fine—just avoid RepartitionExec’s partitioned channels. 
+ PartitionMode::CollectLeft, + /* build_left */ true, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +async fn query_yields( + plan: Arc<dyn ExecutionPlan>, + task_ctx: Arc<TaskContext>, +) -> Result<(), Box<dyn Error>> { + // Run plan through EnsureCooperative + let optimized = + EnsureCooperative::new().optimize(plan, task_ctx.session_config().options())?; + + // Get the stream + let mut stream = physical_plan::execute_stream(optimized, task_ctx)?; + + // Create an independent executor pool + let child_runtime = Runtime::new()?; + + // Spawn a task that tries to poll the stream + // The task returns Ready when the stream yielded with either Ready or Pending + let join_handle = child_runtime.spawn(std::future::poll_fn(move |cx| { + match stream.poll_next_unpin(cx) { + Poll::Ready(Some(Ok(_))) => Poll::Ready(Poll::Ready(Ok(()))), + Poll::Ready(Some(Err(e))) => Poll::Ready(Poll::Ready(Err(e))), + Poll::Ready(None) => Poll::Ready(Poll::Ready(Ok(()))), + Poll::Pending => Poll::Ready(Poll::Pending), + } + })); + + let abort_handle = join_handle.abort_handle(); + + // Now select on the join handle of the task running in the child executor with a timeout + let yielded = select! { Review Comment: Very minor: I think I would find this test easier to follow if we used an explicit enum rather than reusing `Poll`. Maybe something like ```diff --- a/datafusion/core/tests/execution/coop.rs +++ b/datafusion/core/tests/execution/coop.rs @@ -708,12 +708,17 @@ async fn query_yields( // Spawn a task that tries to poll the stream // The task returns Ready when the stream yielded with either Ready or Pending + enum Yielded { + // the task yielded with Ready or Pending + ReadyOrPending, + Err(DataFusionError), + } let join_handle = child_runtime.spawn(std::future::poll_fn(move |cx| { match stream.poll_next_unpin(cx) { - Poll::Ready(Some(Ok(_))) => Poll::Ready(Poll::Ready(Ok(()))), - Poll::Ready(Some(Err(e))) => Poll::Ready(Poll::Ready(Err(e))), - Poll::Ready(None) => Poll::Ready(Poll::Ready(Ok(()))), - Poll::Pending => Poll::Ready(Poll::Pending), + Poll::Ready(Some(Ok(_))) => Poll::Ready(Yielded::ReadyOrPending), + Poll::Ready(Some(Err(e))) => Poll::Ready(Yielded::Err(e)), + Poll::Ready(None) => Poll::Ready(Yielded::ReadyOrPending), + Poll::Pending => Poll::Ready(Yielded::ReadyOrPending), } })); @@ -723,10 +728,8 @@ async fn query_yields( let yielded = select! { result = join_handle => { match result { - Ok(Pending) => Ok(()), - // The task yielded which is ok - Ok(Ready(Ok(_))) => Ok(()), - Ok(Ready(Err(e))) => Err(e), + Ok(Yielded::ReadyOrPending) => Ok(()), + Ok(Yielded::Err(err)) => Err(err), Err(_) => Err(DataFusionError::Execution("join error".into())), } }, ``` ########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -881,6 +907,16 @@ impl PlanProperties { self } + pub fn with_scheduling_type(mut self, scheduling_type: SchedulingType) -> Self { + self.scheduling_type = scheduling_type; + self + } + + pub fn with_evaluation_type(mut self, drive_type: EvaluationType) -> Self { Review Comment: ```suggestion /// Set the [`EvaluationType`]. /// /// Defaults to [`EvaluationType::Lazy`] pub fn with_evaluation_type(mut self, drive_type: EvaluationType) -> Self { ``` ########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -743,6 +733,38 @@ pub enum EmissionType { Both, } +/// Represents whether an operator's `Stream` has been implemented to actively cooperate with the +/// Tokio scheduler or not.
+#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SchedulingType { + /// The stream generated by [`execute`](ExecutionPlan::execute) does not actively participate in + /// cooperative scheduling. This means the implementation of the `Stream` returned by + /// [`ExecutionPlan::execute`] does not contain explicit cooperative yield points. + Blocking, + /// The stream generated by [`execute`](ExecutionPlan::execute) actively participates in + /// cooperative scheduling by consuming task budget when it was able to produce a + /// [`RecordBatch`]. Please refer to the [`coop`](crate::coop) module for more details. + Cooperative, +} + +/// Represents how an operator's `Stream` implementation generates `RecordBatch`es. +/// +/// Most operators in DataFusion generate `RecordBatch`es when asked to do so by a call to +/// `Stream::poll_next`. This is known as demand-driven or lazy evaluation. +/// +/// Some operators like `Repartition` need to drive `RecordBatch` generation themselves though. This +/// is known as data-driven or eager evaluation. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum EvaluationType { + /// The stream generated by [`execute`](ExecutionPlan::execute) only generates `RecordBatch` + /// instances when it is demanded by invoking `Stream::poll_next`. Review Comment: Examples I think would help here too. For example ```suggestion /// instances when it is demanded by invoking `Stream::poll_next`. /// Filter and Projection are examples of such Lazy operators. ``` ########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -881,6 +907,16 @@ impl PlanProperties { self } + pub fn with_scheduling_type(mut self, scheduling_type: SchedulingType) -> Self { Review Comment: I think it would help to document the defaults here ```suggestion /// Set the [`SchedulingType`]. /// /// Defaults to [`SchedulingType::Blocking`] pub fn with_scheduling_type(mut self, scheduling_type: SchedulingType) -> Self { ``` ########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -743,6 +733,38 @@ pub enum EmissionType { Both, } +/// Represents whether an operator's `Stream` has been implemented to actively cooperate with the +/// Tokio scheduler or not. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SchedulingType { + /// The stream generated by [`execute`](ExecutionPlan::execute) does not actively participate in + /// cooperative scheduling. This means the implementation of the `Stream` returned by + /// [`ExecutionPlan::execute`] does not contain explicit cooperative yield points. + Blocking, + /// The stream generated by [`execute`](ExecutionPlan::execute) actively participates in + /// cooperative scheduling by consuming task budget when it was able to produce a + /// [`RecordBatch`]. Please refer to the [`coop`](crate::coop) module for more details. + Cooperative, +} + +/// Represents how an operator's `Stream` implementation generates `RecordBatch`es. +/// +/// Most operators in DataFusion generate `RecordBatch`es when asked to do so by a call to +/// `Stream::poll_next`. This is known as demand-driven or lazy evaluation. +/// +/// Some operators like `Repartition` need to drive `RecordBatch` generation themselves though. This +/// is known as data-driven or eager evaluation. 
Review Comment: It would also be super helpful here to document what the implications of returning each `EvaluationType` are ########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -270,11 +268,13 @@ pub trait ExecutionPlan: Debug + DisplayAs + Send + Sync { /// batch is superlinear. See this [general guideline][async-guideline] for more context /// on this point, which explains why one should avoid spending a long time without /// reaching an `await`/yield point in asynchronous runtimes. - /// This can be achieved by manually returning [`Poll::Pending`] and setting up wakers - /// appropriately, or the use of [`tokio::task::yield_now()`] when appropriate. + /// This can be achieved by using the utilities from the [`coop`](crate::coop) module, by Review Comment: - Once we have https://github.com/apache/datafusion-site/pull/75 done it will be great to link to it from these docs as well :) ########## datafusion/physical-plan/src/execution_plan.rs: ########## @@ -743,6 +733,38 @@ pub enum EmissionType { Both, } +/// Represents whether an operator's `Stream` has been implemented to actively cooperate with the +/// Tokio scheduler or not. +#[derive(Debug, Clone, Copy, PartialEq, Eq)] +pub enum SchedulingType { + /// The stream generated by [`execute`](ExecutionPlan::execute) does not actively participate in + /// cooperative scheduling. This means the implementation of the `Stream` returned by + /// [`ExecutionPlan::execute`] does not contain explicit cooperative yield points. Review Comment: I think it would also help to document here what the implications of returning one vs the other. I think the implication is that if an `ExecutionPlan` returns `Blocking` that DataFusion's planner will automatically wrap the stream in `make_cooperative` for example ########## datafusion/physical-plan/src/coop.rs: ########## @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for improved cooperative scheduling. +//! +//! # Cooperative scheduling +//! +//! A single call to `poll_next` on a top-level `Stream` may potentially do a lot of work before it +//! returns a `Poll::Pending`. Think for instance of calculating an aggregation over a large dataset. +//! If an operator tree runs for a long period of time without yielding back to the Tokio executor, +//! it can starve other tasks waiting on that executor to execute them. +//! Additionally, this prevents the query execution from being cancelled. +//! +//! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield +//! points using the utilities in this module. For most operators this is **not** necessary. The +//! built-in DataFusion operators that generate (rather than manipulate; for instance `DataSourceExec`) +//! 
or repartition `RecordBatch`es (for instance, `RepartitionExec`) contain yield points that will +//! make most operator trees yield as appropriate. +//! +//! There are a couple of types of operators that should insert yield points: +//! - New source operators that do not make use of Tokio resources +//! - Exchange like operators that do not use Tokio's `Channel` implementation to pass data between +//! tasks + +#[cfg(feature = "tokio_coop_fallback")] +use futures::Future; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::execution_plan::CardinalityEffect::{self, Equal}; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, +}; +use arrow::record_batch::RecordBatch; +use arrow_schema::Schema; +use datafusion_common::{internal_err, Result, Statistics}; +use datafusion_execution::TaskContext; + +use crate::execution_plan::SchedulingType; +use crate::stream::RecordBatchStreamAdapter; +use futures::{Stream, StreamExt}; + +/// A stream that passes record batches through unchanged while cooperating with the Tokio runtime. Review Comment: I recommend adding a link to the upstream tokio ticket - https://github.com/tokio-rs/tokio/pull/7405 - as a comment. That way, if someone runs across this code in the future, they can check the upstream ticket; if it has been merged and released, they can potentially update this code to use the upstream API. ########## datafusion/physical-plan/src/coop.rs: ########## @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for improved cooperative scheduling. +//! +//! # Cooperative scheduling +//! +//! A single call to `poll_next` on a top-level `Stream` may potentially do a lot of work before it +//! returns a `Poll::Pending`. Think for instance of calculating an aggregation over a large dataset. +//! If an operator tree runs for a long period of time without yielding back to the Tokio executor, +//! it can starve other tasks waiting on that executor to execute them. +//! Additionally, this prevents the query execution from being cancelled. +//! +//! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield +//! points using the utilities in this module. For most operators this is **not** necessary. The +//! built-in DataFusion operators that generate (rather than manipulate; for instance `DataSourceExec`) +//! or repartition `RecordBatch`es (for instance, `RepartitionExec`) contain yield points that will +//! make most operator trees yield as appropriate. +//! +//! There are a couple of types of operators that should insert yield points: +//! - New source operators that do not make use of Tokio resources +//!
- Exchange like operators that do not use Tokio's `Channel` implementation to pass data between +//! tasks + +#[cfg(feature = "tokio_coop_fallback")] +use futures::Future; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::execution_plan::CardinalityEffect::{self, Equal}; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, +}; +use arrow::record_batch::RecordBatch; +use arrow_schema::Schema; +use datafusion_common::{internal_err, Result, Statistics}; +use datafusion_execution::TaskContext; + +use crate::execution_plan::SchedulingType; +use crate::stream::RecordBatchStreamAdapter; +use futures::{Stream, StreamExt}; + +/// A stream that passes record batches through unchanged while cooperating with the Tokio runtime. +/// It consumes cooperative scheduling budget for each returned [`RecordBatch`], +/// allowing other tasks to execute when the budget is exhausted. Review Comment: I think it would be useful to note that cooperating here means yielding back to the scheduler ```suggestion /// It consumes cooperative scheduling budget for each returned [`RecordBatch`], /// and returns control to the tokio scheduler periodically, allowing other tasks to /// execute or the plan to be cancelled when the budget is exhausted. ``` ########## datafusion/physical-plan/src/coop.rs: ########## @@ -0,0 +1,325 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! Utilities for improved cooperative scheduling. +//! +//! # Cooperative scheduling +//! +//! A single call to `poll_next` on a top-level `Stream` may potentially do a lot of work before it +//! returns a `Poll::Pending`. Think for instance of calculating an aggregation over a large dataset. +//! If an operator tree runs for a long period of time without yielding back to the Tokio executor, +//! it can starve other tasks waiting on that executor to execute them. +//! Additionally, this prevents the query execution from being cancelled. +//! +//! To ensure that `Stream` implementations yield regularly, operators can insert explicit yield +//! points using the utilities in this module. For most operators this is **not** necessary. The +//! built-in DataFusion operators that generate (rather than manipulate; for instance `DataSourceExec`) +//! or repartition `RecordBatch`es (for instance, `RepartitionExec`) contain yield points that will +//! make most operator trees yield as appropriate. +//! +//! There are a couple of types of operators that should insert yield points: +//! - New source operators that do not make use of Tokio resources +//! - Exchange like operators that do not use Tokio's `Channel` implementation to pass data between +//!
tasks + +#[cfg(feature = "tokio_coop_fallback")] +use futures::Future; +use std::any::Any; +use std::pin::Pin; +use std::sync::Arc; +use std::task::{Context, Poll}; + +use crate::execution_plan::CardinalityEffect::{self, Equal}; +use crate::{ + DisplayAs, DisplayFormatType, ExecutionPlan, PlanProperties, RecordBatchStream, + SendableRecordBatchStream, +}; +use arrow::record_batch::RecordBatch; +use arrow_schema::Schema; +use datafusion_common::{internal_err, Result, Statistics}; +use datafusion_execution::TaskContext; + +use crate::execution_plan::SchedulingType; +use crate::stream::RecordBatchStreamAdapter; +use futures::{Stream, StreamExt}; + +/// A stream that passes record batches through unchanged while cooperating with the Tokio runtime. +/// It consumes cooperative scheduling budget for each returned [`RecordBatch`], +/// allowing other tasks to execute when the budget is exhausted. +pub struct CooperativeStream<T> +where + T: RecordBatchStream + Unpin, +{ + inner: T, + #[cfg(not(any(feature = "tokio_coop", feature = "tokio_coop_fallback")))] + budget: u8, +} + +#[cfg(not(any(feature = "tokio_coop", feature = "tokio_coop_fallback")))] +// Magic value that matches Tokio's task budget value +const YIELD_FREQUENCY: u8 = 128; + +impl<T> CooperativeStream<T> +where + T: RecordBatchStream + Unpin, +{ + /// Creates a new `CooperativeStream` that wraps the provided stream. + /// The resulting stream will cooperate with the Tokio scheduler by consuming a unit of + /// scheduling budget when the wrapped `Stream` returns a record batch. + pub fn new(inner: T) -> Self { + Self { + inner, + #[cfg(not(any(feature = "tokio_coop", feature = "tokio_coop_fallback")))] + budget: YIELD_FREQUENCY, + } + } +} + +impl<T> Stream for CooperativeStream<T> +where + T: RecordBatchStream + Unpin, +{ + type Item = Result<RecordBatch>; + + fn poll_next( + mut self: Pin<&mut Self>, + cx: &mut Context<'_>, + ) -> Poll<Option<Self::Item>> { + #[cfg(all(feature = "tokio_coop", not(feature = "tokio_coop_fallback")))] + { + let coop = std::task::ready!(tokio::task::coop::poll_proceed(cx)); + let value = self.inner.poll_next_unpin(cx); + if value.is_ready() { + coop.made_progress(); + } + value + } + + #[cfg(feature = "tokio_coop_fallback")] + { + if !tokio::task::coop::has_budget_remaining() { + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + let value = self.inner.poll_next_unpin(cx); + if value.is_ready() { + // This is a temporary placeholder implementation + let consume = tokio::task::consume_budget(); + let consume_ref = std::pin::pin!(consume); + let _ = consume_ref.poll(cx); + } + value + } + + #[cfg(not(any(feature = "tokio_coop", feature = "tokio_coop_fallback")))] + { + if self.budget == 0 { + self.budget = YIELD_FREQUENCY; + cx.waker().wake_by_ref(); + return Poll::Pending; + } + + let value = { self.inner.poll_next_unpin(cx) }; + + if value.is_ready() { + self.budget -= 1; + } else { + self.budget = YIELD_FREQUENCY; + } + value + } + } +} + +impl<T> RecordBatchStream for CooperativeStream<T> +where + T: RecordBatchStream + Unpin, +{ + fn schema(&self) -> Arc<Schema> { + self.inner.schema() + } +} + +/// An execution plan decorator that enables cooperative multitasking. +/// It wraps the streams produced by its input execution plan using the [`make_cooperative`] function, +/// which makes the stream participate in Tokio cooperative scheduling. 
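+/// All other plan properties are inherited unchanged from the input plan.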
+#[derive(Debug)] +pub struct CooperativeExec { + input: Arc<dyn ExecutionPlan>, + properties: PlanProperties, +} + +impl CooperativeExec { + /// Creates a new `CooperativeExec` operator that wraps the given input execution plan. + pub fn new(input: Arc<dyn ExecutionPlan>) -> Self { + let properties = input + .properties() + .clone() + .with_scheduling_type(SchedulingType::Cooperative); + + Self { input, properties } + } + + /// Returns a reference to the wrapped input execution plan. + pub fn input(&self) -> &Arc<dyn ExecutionPlan> { + &self.input + } +} + +impl DisplayAs for CooperativeExec { + fn fmt_as( + &self, + _t: DisplayFormatType, + f: &mut std::fmt::Formatter<'_>, + ) -> std::fmt::Result { + write!(f, "CooperativeExec") + } +} + +impl ExecutionPlan for CooperativeExec { + fn name(&self) -> &str { + "CooperativeExec" + } + + fn as_any(&self) -> &dyn Any { + self + } + + fn schema(&self) -> Arc<Schema> { + self.input.schema() + } + + fn properties(&self) -> &PlanProperties { + &self.properties + } + + fn maintains_input_order(&self) -> Vec<bool> { + self.input.maintains_input_order() + } + + fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> { + vec![&self.input] + } + + fn with_new_children( + self: Arc<Self>, + mut children: Vec<Arc<dyn ExecutionPlan>>, + ) -> Result<Arc<dyn ExecutionPlan>> { + if children.len() != 1 { + return internal_err!("CooperativeExec requires exactly one child"); + } + Ok(Arc::new(CooperativeExec::new(children.swap_remove(0)))) + } + + fn execute( + &self, + partition: usize, + task_ctx: Arc<TaskContext>, + ) -> Result<SendableRecordBatchStream> { + let child_stream = self.input.execute(partition, task_ctx)?; + Ok(make_cooperative(child_stream)) + } + + fn partition_statistics(&self, partition: Option<usize>) -> Result<Statistics> { + self.input.partition_statistics(partition) + } + + fn supports_limit_pushdown(&self) -> bool { + true + } + + fn cardinality_effect(&self) -> CardinalityEffect { + Equal + } +} + +/// Creates a [`CooperativeStream`] wrapper around the given [`RecordBatchStream`]. +/// This wrapper collaborates with the Tokio cooperative scheduler by consuming a unit of +/// scheduling budget for each returned record batch. Review Comment: ```suggestion /// scheduling budget for each returned record batch, and yielding when the budget is exhausted. ``` ########## datafusion/core/tests/execution/coop.rs: ########## @@ -0,0 +1,747 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. 
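+ +//! End-to-end tests that run physical plans over effectively infinite inputs through +//! [`EnsureCooperative`] and assert that the resulting streams still yield to the Tokio +//! scheduler within a bounded amount of time.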
+ +use arrow::array::{Int64Array, RecordBatch}; Review Comment: I double checked that this test covers the code changes in this PR by breaking LazyMemoryExec deliberately and running the tests ```diff (venv) andrewlamb@Andrews-MacBook-Pro-2:~/Software/datafusion2$ git diff diff --git a/datafusion/physical-plan/src/memory.rs b/datafusion/physical-plan/src/memory.rs index 3e5ea32a4..e790b062a 100644 --- a/datafusion/physical-plan/src/memory.rs +++ b/datafusion/physical-plan/src/memory.rs @@ -314,7 +314,8 @@ impl ExecutionPlan for LazyMemoryExec { generator: Arc::clone(&self.batch_generators[partition]), baseline_metrics, }; - Ok(Box::pin(cooperative(stream))) + //Ok(Box::pin(cooperative(stream))) + Ok(Box::pin(stream)) } fn metrics(&self) -> Option<MetricsSet> { ``` Many of the tests in this file failed with ``` Error: Execution("time out") ``` As expected 👍 ########## datafusion/core/tests/execution/coop.rs: ########## @@ -0,0 +1,722 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +use arrow::array::{Int64Array, RecordBatch}; +use arrow::datatypes::{DataType, Field, Schema, SchemaRef}; +use arrow_schema::SortOptions; +use datafusion::functions_aggregate::sum; +use datafusion::physical_expr::aggregate::AggregateExprBuilder; +use datafusion::physical_plan; +use datafusion::physical_plan::aggregates::{ + AggregateExec, AggregateMode, PhysicalGroupBy, +}; +use datafusion::physical_plan::execution_plan::Boundedness; +use datafusion::physical_plan::ExecutionPlan; +use datafusion::prelude::SessionContext; +use datafusion_common::{JoinType, ScalarValue}; +use datafusion_execution::TaskContext; +use datafusion_expr_common::operator::Operator; +use datafusion_expr_common::operator::Operator::Gt; +use datafusion_functions_aggregate::min_max; +use datafusion_physical_expr::expressions::{ + binary, col, lit, BinaryExpr, Column, Literal, +}; +use datafusion_physical_expr::Partitioning; +use datafusion_physical_expr_common::physical_expr::PhysicalExpr; +use datafusion_physical_expr_common::sort_expr::{LexOrdering, PhysicalSortExpr}; +use datafusion_physical_optimizer::ensure_coop::EnsureCooperative; +use datafusion_physical_optimizer::PhysicalOptimizerRule; +use datafusion_physical_plan::coalesce_batches::CoalesceBatchesExec; +use datafusion_physical_plan::filter::FilterExec; +use datafusion_physical_plan::joins::{HashJoinExec, PartitionMode, SortMergeJoinExec}; +use datafusion_physical_plan::memory::{LazyBatchGenerator, LazyMemoryExec}; +use datafusion_physical_plan::projection::ProjectionExec; +use datafusion_physical_plan::repartition::RepartitionExec; +use datafusion_physical_plan::sorts::sort::SortExec; +use datafusion_physical_plan::union::InterleaveExec; +use futures::StreamExt; +use parking_lot::RwLock; +use rstest::rstest; +use 
std::error::Error; +use std::fmt::Formatter; +use std::ops::Range; +use std::sync::Arc; +use std::task::Poll; +use tokio::runtime::{Handle, Runtime}; +use tokio::select; + +#[derive(Debug)] +struct RangeBatchGenerator { + schema: SchemaRef, + value_range: Range<i64>, + boundedness: Boundedness, + batch_size: usize, + poll_count: usize, +} + +impl std::fmt::Display for RangeBatchGenerator { + fn fmt(&self, f: &mut Formatter) -> std::fmt::Result { + // Display current counter + write!(f, "InfiniteGenerator(counter={})", self.poll_count) + } +} + +impl LazyBatchGenerator for RangeBatchGenerator { + fn boundedness(&self) -> Boundedness { + self.boundedness + } + + /// Generate the next RecordBatch. + fn generate_next_batch(&mut self) -> datafusion_common::Result<Option<RecordBatch>> { + self.poll_count += 1; + + let mut builder = Int64Array::builder(self.batch_size); + for _ in 0..self.batch_size { + match self.value_range.next() { + None => break, + Some(v) => builder.append_value(v), + } + } + let array = builder.finish(); + + if array.is_empty() { + return Ok(None); + } + + let batch = + RecordBatch::try_new(Arc::clone(&self.schema), vec![Arc::new(array)])?; + Ok(Some(batch)) + } +} + +fn make_lazy_exec(column_name: &str, pretend_infinite: bool) -> LazyMemoryExec { + make_lazy_exec_with_range(column_name, i64::MIN..i64::MAX, pretend_infinite) +} + +fn make_lazy_exec_with_range( + column_name: &str, + range: Range<i64>, + pretend_infinite: bool, +) -> LazyMemoryExec { + let schema = Arc::new(Schema::new(vec![Field::new( + column_name, + DataType::Int64, + false, + )])); + + let boundedness = if pretend_infinite { + Boundedness::Unbounded { + requires_infinite_memory: false, + } + } else { + Boundedness::Bounded + }; + + // Instantiate the generator with the batch and limit + let gen = RangeBatchGenerator { + schema: Arc::clone(&schema), + boundedness, + value_range: range, + batch_size: 8192, + poll_count: 0, + }; + + // Wrap the generator in a trait object behind Arc<RwLock<_>> + let generator: Arc<RwLock<dyn LazyBatchGenerator>> = Arc::new(RwLock::new(gen)); + + // Create a LazyMemoryExec with one partition using our generator + let mut exec = LazyMemoryExec::try_new(schema, vec![generator]).unwrap(); + + exec.add_ordering(vec![PhysicalSortExpr::new( + Arc::new(Column::new(column_name, 0)), + SortOptions::new(false, true), + )]); + + exec +} + +#[rstest] +#[tokio::test] +async fn agg_no_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation without grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![], vec![], vec![]), + vec![Arc::new( + AggregateExprBuilder::new( + sum::sum_udaf(), + vec![col("value", &inf.schema())?], + ) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouping_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up an aggregation with grouping + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary( + value_col.clone(), + Operator::Divide, + lit(1000000i64), + 
&inf.schema(), + )?; + + let aggr = Arc::new(AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new(vec![(group, "group".to_string())], vec![], vec![]), + vec![Arc::new( + AggregateExprBuilder::new(sum::sum_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("sum") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )?); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn agg_grouped_topk_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up a top-k aggregation + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + let value_col = col("value", &inf.schema())?; + let group = binary( + value_col.clone(), + Operator::Divide, + lit(1000000i64), + &inf.schema(), + )?; + + let aggr = Arc::new( + AggregateExec::try_new( + AggregateMode::Single, + PhysicalGroupBy::new( + vec![(group, "group".to_string())], + vec![], + vec![vec![false]], + ), + vec![Arc::new( + AggregateExprBuilder::new(min_max::max_udaf(), vec![value_col.clone()]) + .schema(inf.schema()) + .alias("max") + .build()?, + )], + vec![None], + inf.clone(), + inf.schema(), + )? + .with_limit(Some(100)), + ); + + query_yields(aggr, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn sort_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a SortExec that will not be able to finish in time because input is very large + let sort_expr = PhysicalSortExpr::new( + col("value", &inf.schema())?, + SortOptions { + descending: true, + nulls_first: true, + }, + ); + + let lex_ordering = LexOrdering::new(vec![sort_expr]).unwrap(); + let sort_exec = Arc::new(SortExec::new(lex_ordering, inf.clone())); + + query_yields(sort_exec, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn sort_merge_join_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up the join sources + let inf1 = Arc::new(make_lazy_exec_with_range( + "value1", + i64::MIN..0, + pretend_infinite, + )); + let inf2 = Arc::new(make_lazy_exec_with_range( + "value2", + 0..i64::MAX, + pretend_infinite, + )); + + // set up a SortMergeJoinExec that will take a long time skipping left side content to find + // the first right side match + let join = Arc::new(SortMergeJoinExec::try_new( + inf1.clone(), + inf2.clone(), + vec![( + col("value1", &inf1.schema())?, + col("value2", &inf2.schema())?, + )], + None, + JoinType::Inner, + vec![inf1.properties().eq_properties.output_ordering().unwrap()[0].options], + true, + )?); + + query_yields(join, session_ctx.task_ctx()).await +} + +#[rstest] +#[tokio::test] +async fn filter_yields( + #[values(false, true)] pretend_infinite: bool, +) -> Result<(), Box<dyn Error>> { + // build session + let session_ctx = SessionContext::new(); + + // set up the infinite source + let inf = Arc::new(make_lazy_exec("value", pretend_infinite)); + + // set up a FilterExec that will filter out entire batches + let filter_expr = binary( + col("value", &inf.schema())?, + Operator::Lt, + lit(i64::MIN), + &inf.schema(), + )?; + let filter = Arc::new(FilterExec::try_new(filter_expr, inf.clone())?); + + 
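+ // (No i64 value satisfies `value < i64::MIN`, so the filter output stays empty forever and the stream can only yield via cooperative scheduling.)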
+    query_yields(filter, session_ctx.task_ctx()).await
+}
+
+#[rstest]
+#[tokio::test]
+async fn filter_reject_all_batches_yields(
+    #[values(false, true)] pretend_infinite: bool,
+) -> Result<(), Box<dyn Error>> {
+    // build session
+    let session_ctx = SessionContext::new();
+
+    // set up a source that only produces negative values
+    let infinite = make_lazy_exec_with_range("value", i64::MIN..0, pretend_infinite);
+
+    // Construct a FilterExec that is always false: “value > 0” (no rows pass)
+    let false_predicate = Arc::new(BinaryExpr::new(
+        Arc::new(Column::new("value", 0)),
+        Gt,
+        Arc::new(Literal::new(ScalarValue::Int64(Some(0)))),
+    ));
+    let filtered = Arc::new(FilterExec::try_new(false_predicate, Arc::new(infinite))?);
+
+    // Add CoalesceBatchesExec, which keeps pulling from the filter while trying
+    // to assemble an 8192-row batch that can never materialize
+    let coalesced = Arc::new(CoalesceBatchesExec::new(filtered, 8_192));
+
+    query_yields(coalesced, session_ctx.task_ctx()).await
+}
+
+#[rstest]
+#[tokio::test]
+async fn interleave_yields(
+    #[values(false, true)] pretend_infinite: bool,
+) -> Result<(), Box<dyn Error>> {
+    // build session
+    let session_ctx = SessionContext::new();
+
+    // Create multiple infinite sources, each filtered by a different threshold.
+    // This ensures InterleaveExec has many children.
+    let mut infinite_children = vec![];
+    // Use 32 distinct thresholds (each > 0 and < 8192) for 32 infinite inputs.
+    let thresholds = (0..32).map(|i| 8191 - (i * 256) as i64);
+
+    for thr in thresholds {
+        // One infinite exec:
+        let mut inf = make_lazy_exec("value", pretend_infinite);
+
+        // Now repartition so that all children share identical Hash partitioning
+        // on “value” into 1 bucket. This is required for InterleaveExec::try_new.
+        let exprs = vec![Arc::new(Column::new("value", 0)) as _];
+        let partitioning = Partitioning::Hash(exprs, 1);
+        inf.try_set_partitioning(partitioning)?;
+
+        // Apply a FilterExec: “value > thr”.
+        let filter_expr = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("value", 0)),
+            Gt,
+            Arc::new(Literal::new(ScalarValue::Int64(Some(thr)))),
+        ));
+        let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?);
+
+        infinite_children.push(filtered as _);
+    }
+
+    // Build an InterleaveExec over all infinite children.
+    let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?);
+
+    // Wrap the InterleaveExec in a FilterExec that always returns false,
+    // ensuring that no rows are ever emitted.
+    let always_false = Arc::new(Literal::new(ScalarValue::Boolean(Some(false))));
+    let filtered_interleave = Arc::new(FilterExec::try_new(always_false, interleave)?);
+
+    query_yields(filtered_interleave, session_ctx.task_ctx()).await
+}
+
+#[rstest]
+#[tokio::test]
+async fn interleave_agg_yields(
+    #[values(false, true)] pretend_infinite: bool,
+) -> Result<(), Box<dyn Error>> {
+    // build session
+    let session_ctx = SessionContext::new();
+
+    // Create N infinite sources, each filtered by a different predicate.
+    // That way, the InterleaveExec will have multiple children.
+    let mut infinite_children = vec![];
+    // Use 32 distinct thresholds (each > 0 and < 8192) to force 32 infinite inputs
+    let thresholds = (0..32).map(|i| 8_192 - 1 - (i * 256) as i64);
+
+    for thr in thresholds {
+        // One infinite exec:
+        let mut inf = make_lazy_exec("value", pretend_infinite);
+
+        // Now repartition so that all children share identical Hash partitioning
+        // on “value” into 1 bucket.
+        // This is required for InterleaveExec::try_new.
+        let exprs = vec![Arc::new(Column::new("value", 0)) as _];
+        let partitioning = Partitioning::Hash(exprs, 1);
+        inf.try_set_partitioning(partitioning)?;
+
+        // Apply a FilterExec: “value > thr”.
+        let filter_expr = Arc::new(BinaryExpr::new(
+            Arc::new(Column::new("value", 0)),
+            Gt,
+            Arc::new(Literal::new(ScalarValue::Int64(Some(thr)))),
+        ));
+        let filtered = Arc::new(FilterExec::try_new(filter_expr, Arc::new(inf))?);
+
+        infinite_children.push(filtered as _);
+    }
+
+    // Build an InterleaveExec over all N children.
+    // Since each child now has Partitioning::Hash([col "value"], 1), InterleaveExec::try_new succeeds.
+    let interleave = Arc::new(InterleaveExec::try_new(infinite_children)?);
+    let interleave_schema = interleave.schema();
+
+    // Build a global AggregateExec that sums “value” over all rows.
+    // Because we use `AggregateMode::Single` with no GROUP BY columns, this plan will
+    // only produce one “final” row once all inputs finish. When the inputs pretend
+    // to be infinite they never finish, so no output is ever produced.
+    let aggregate_expr = AggregateExprBuilder::new(
+        sum::sum_udaf(),
+        vec![Arc::new(Column::new("value", 0))],
+    )
+    .schema(interleave_schema.clone())
+    .alias("total")
+    .build()?;
+
+    let aggr = Arc::new(AggregateExec::try_new(
+        AggregateMode::Single,
+        PhysicalGroupBy::new(
+            vec![], // no GROUP BY expressions
+            vec![], // no GROUP BY null expressions
+            vec![], // no grouping sets
+        ),
+        vec![Arc::new(aggregate_expr)],
+        vec![None], // no filter expressions
+        interleave,
+        interleave_schema,
+    )?);
+
+    query_yields(aggr, session_ctx.task_ctx()).await
+}
+
+#[rstest]
+#[tokio::test]
+async fn join_yields(
+    #[values(false, true)] pretend_infinite: bool,
+) -> Result<(), Box<dyn Error>> {
+    // build session
+    let session_ctx = SessionContext::new();
+
+    // the left side covers -10..10 while the right side covers 0..i64::MAX, so not
+    // everything joins, but plenty of matching keys exist (0..10 appears on both sides)
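+    // The left side is always bounded, so the CollectLeft build phase below can
+    // complete; the potentially unbounded right side then exercises yielding
+    // while the join probes for matches.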
+    let infinite_left = make_lazy_exec_with_range("value", -10..10, false);
+    let infinite_right =
+        make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite);
+
+    // Create join keys → join on “value” = “value”
+    let left_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))];
+    let right_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))];
+
+    // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition
+    let coalesced_left =
+        Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192));
+    let coalesced_right =
+        Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192));
+
+    let part_left = Partitioning::Hash(left_keys, 1);
+    let part_right = Partitioning::Hash(right_keys, 1);
+
+    let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?);
+    let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?);
+
+    // Build an Inner HashJoinExec → left.value = right.value
+    let join = Arc::new(HashJoinExec::try_new(
+        hashed_left,
+        hashed_right,
+        vec![(
+            Arc::new(Column::new("value", 0)),
+            Arc::new(Column::new("value", 0)),
+        )],
+        None,
+        &JoinType::Inner,
+        None,
+        PartitionMode::CollectLeft,
+        true,
+    )?);
+
+    query_yields(join, session_ctx.task_ctx()).await
+}
+
+#[rstest]
+#[tokio::test]
+async fn join_agg_yields(
+    #[values(false, true)] pretend_infinite: bool,
+) -> Result<(), Box<dyn Error>> {
+    // build session
+    let session_ctx = SessionContext::new();
+
+    // the left side covers -10..10 while the right side covers 0..i64::MAX, so not
+    // everything joins, but plenty of matching keys exist (0..10 appears on both sides)
+    let infinite_left = make_lazy_exec_with_range("value", -10..10, false);
+    let infinite_right =
+        make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite);
+
+    // Create join keys → join on “value” = “value”
+    let left_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))];
+    let right_keys: Vec<Arc<dyn PhysicalExpr>> = vec![Arc::new(Column::new("value", 0))];
+
+    // Wrap each side in CoalesceBatches + Repartition so they are both hashed into 1 partition
+    let coalesced_left =
+        Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_left), 8_192));
+    let coalesced_right =
+        Arc::new(CoalesceBatchesExec::new(Arc::new(infinite_right), 8_192));
+
+    let part_left = Partitioning::Hash(left_keys, 1);
+    let part_right = Partitioning::Hash(right_keys, 1);
+
+    let hashed_left = Arc::new(RepartitionExec::try_new(coalesced_left, part_left)?);
+    let hashed_right = Arc::new(RepartitionExec::try_new(coalesced_right, part_right)?);
+
+    // Build an Inner HashJoinExec → left.value = right.value
+    let join = Arc::new(HashJoinExec::try_new(
+        hashed_left,
+        hashed_right,
+        vec![(
+            Arc::new(Column::new("value", 0)),
+            Arc::new(Column::new("value", 0)),
+        )],
+        None,
+        &JoinType::Inner,
+        None,
+        PartitionMode::CollectLeft,
+        true,
+    )?);
+
+    // Project only one column (“value” from the left side) because we just want to sum that
+    let input_schema = join.schema();
+
+    let proj_expr = vec![(
+        Arc::new(Column::new_with_schema("value", &input_schema)?)
+            as _,
+        "value".to_string(),
+    )];
+
+    let projection = Arc::new(ProjectionExec::try_new(proj_expr, join)?);
+    let projection_schema = projection.schema();
+
+    let output_fields = vec![Field::new("total", DataType::Int64, true)];
+    let output_schema = Arc::new(Schema::new(output_fields));
+
+    // Global aggregate (Single) over “value”
+    let aggregate_expr = AggregateExprBuilder::new(
+        sum::sum_udaf(),
+        vec![Arc::new(Column::new_with_schema(
+            "value",
+            &projection.schema(),
+        )?)],
+    )
+    .schema(output_schema)
+    .alias("total")
+    .build()?;
+
+    let aggr = Arc::new(AggregateExec::try_new(
+        AggregateMode::Single,
+        PhysicalGroupBy::new(vec![], vec![], vec![]),
+        vec![Arc::new(aggregate_expr)],
+        vec![None],
+        projection,
+        projection_schema,
+    )?);
+
+    query_yields(aggr, session_ctx.task_ctx()).await
+}
+
+#[rstest]
+#[tokio::test]
+async fn hash_join_yields(
+    #[values(false, true)] pretend_infinite: bool,
+) -> Result<(), Box<dyn Error>> {
+    // build session
+    let session_ctx = SessionContext::new();
+
+    // set up the join sources
+    let inf1 = Arc::new(make_lazy_exec("value1", pretend_infinite));
+    let inf2 = Arc::new(make_lazy_exec("value2", pretend_infinite));
+
+    // set up a HashJoinExec that will take a long time in the build phase
+    let join = Arc::new(HashJoinExec::try_new(
+        inf1.clone(),
+        inf2.clone(),
+        vec![(
+            col("value1", &inf1.schema())?,
+            col("value2", &inf2.schema())?,
+        )],
+        None,
+        &JoinType::Left,
+        None,
+        PartitionMode::CollectLeft,
+        true,
+    )?);
+
+    query_yields(join, session_ctx.task_ctx()).await
+}
+
+#[rstest]
+#[tokio::test]
+async fn hash_join_without_repartition_and_no_agg(
+    #[values(false, true)] pretend_infinite: bool,
+) -> Result<(), Box<dyn Error>> {
+    // build session
+    let session_ctx = SessionContext::new();
+
+    // the left side covers -10..10 while the right side covers 0..i64::MAX, so not
+    // everything joins, but plenty of matching keys exist (0..10 appears on both sides)
+    let infinite_left = make_lazy_exec_with_range("value", -10..10, false);
+    let infinite_right =
+        make_lazy_exec_with_range("value", 0..i64::MAX, pretend_infinite);
+
+    // Directly feed `infinite_left` and `infinite_right` into HashJoinExec.
+    // Do not use aggregation or repartition.
+    let join = Arc::new(HashJoinExec::try_new(
+        Arc::new(infinite_left),
+        Arc::new(infinite_right),
+        vec![(
+            Arc::new(Column::new("value", 0)),
+            Arc::new(Column::new("value", 0)),
+        )],
+        /* filter */ None,
+        &JoinType::Inner,
+        /* projection */ None,
+        // Using CollectLeft is fine; we just avoid RepartitionExec’s partitioned channels.
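+        // Without RepartitionExec there are no channel hand-offs that implicitly
+        // yield, so this exercises the join stream’s own cooperative yielding.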
+        PartitionMode::CollectLeft,
+        /* null_equals_null */ true,
+    )?);
+
+    query_yields(join, session_ctx.task_ctx()).await
+}
+
+async fn query_yields(
+    plan: Arc<dyn ExecutionPlan>,
+    task_ctx: Arc<TaskContext>,
+) -> Result<(), Box<dyn Error>> {
+    // Run the plan through EnsureCooperative
+    let optimized =
+        EnsureCooperative::new().optimize(plan, task_ctx.session_config().options())?;
+
+    // Get the stream
+    let mut stream = physical_plan::execute_stream(optimized, task_ctx)?;
+
+    // Create an independent executor pool
+    let child_runtime = Runtime::new()?;
+
+    // Spawn a task that polls the stream once. The task completes as soon as the
+    // first poll returns, regardless of whether that poll was Ready or Pending.
+    let join_handle = child_runtime.spawn(std::future::poll_fn(move |cx| {
+        match stream.poll_next_unpin(cx) {
+            Poll::Ready(_) => Poll::Ready(Poll::Ready(())),
+            Poll::Pending => Poll::Ready(Poll::Pending),
+        }
+    }));
+
+    let abort_handle = join_handle.abort_handle();
+
+    // Now select on the join handle of the task running in the child executor with a timeout
+    let yielded = select! {
+        _ = join_handle => true,
+        _ = tokio::time::sleep(tokio::time::Duration::from_secs(10)) => false
+    };
+
+    // Try to abort the poll task and shut down the child runtime
+    abort_handle.abort();
+    Handle::current().spawn_blocking(move || {
+        drop(child_runtime);
+    });
+
+    // Finally, check if poll_next yielded
+    assert!(yielded, "Task did not yield in a timely fashion");
+    Ok(())
+}

Review Comment:
   maybe @2010YOUY01 or @ding-young have some ideas