mertak-synnada commented on code in PR #14411:
URL: https://github.com/apache/datafusion/pull/14411#discussion_r1944579206


##########
datafusion/physical-plan/src/repartition/on_demand_repartition.rs:
##########
@@ -0,0 +1,1320 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+//! This file implements the [`OnDemandRepartitionExec`] operator, which maps N input
+//! partitions to M output partitions based on a partitioning scheme, optionally
+//! maintaining the order of the input rows in the output.
+
+use std::pin::Pin;
+use std::sync::Arc;
+use std::task::{Context, Poll};
+use std::{any::Any, vec};
+
+use super::metrics::{ExecutionPlanMetricsSet, MetricsSet};
+use super::{
+    BatchPartitioner, DisplayAs, ExecutionPlanProperties, MaybeBatch, 
RecordBatchStream,
+    RepartitionExecBase, RepartitionMetrics, SendableRecordBatchStream,
+};
+use crate::common::SharedMemoryReservation;
+use crate::execution_plan::CardinalityEffect;
+use crate::metrics::BaselineMetrics;
+use crate::repartition::distributor_channels::{
+    DistributionReceiver, DistributionSender,
+};
+use crate::repartition::RepartitionExecStateBuilder;
+use crate::sorts::streaming_merge::StreamingMergeBuilder;
+use crate::stream::RecordBatchStreamAdapter;
+use crate::{DisplayFormatType, ExecutionPlan, Partitioning, PlanProperties, 
Statistics};
+
+use arrow::datatypes::SchemaRef;
+use arrow::record_batch::RecordBatch;
+use async_channel::{Receiver, Sender};
+
+use datafusion_common::{internal_datafusion_err, Result};
+use datafusion_common_runtime::SpawnedTask;
+use datafusion_execution::memory_pool::MemoryConsumer;
+use datafusion_execution::TaskContext;
+
+use datafusion_common::HashMap;
+use futures::stream::Stream;
+use futures::{ready, FutureExt, StreamExt, TryStreamExt};
+use log::{debug, trace};
+use parking_lot::Mutex;
+
+type PartitionChannels = (Vec<Sender<usize>>, Vec<Receiver<usize>>);
+
+#[derive(Debug, Clone)]
+pub struct OnDemandRepartitionExec {
+    base: RepartitionExecBase,
+    /// Channel to send partition number to the downstream task
+    partition_channels: Arc<tokio::sync::OnceCell<Mutex<PartitionChannels>>>,
+}
+
+impl OnDemandRepartitionExec {
+    /// Input execution plan
+    pub fn input(&self) -> &Arc<dyn ExecutionPlan> {
+        &self.base.input
+    }
+
+    /// Partitioning scheme to use
+    pub fn partitioning(&self) -> &Partitioning {
+        &self.base.cache.partitioning
+    }
+
+    /// Get preserve_order flag of the RepartitionExecutor
+    /// `true` means `SortPreservingRepartitionExec`, `false` means 
`RepartitionExec`
+    pub fn preserve_order(&self) -> bool {
+        self.base.preserve_order
+    }
+
+    /// Specify if this reparititoning operation should preserve the order of
+    /// rows from its input when producing output. Preserving order is more
+    /// expensive at runtime, so should only be set if the output of this
+    /// operator can take advantage of it.
+    ///
+    /// If the input is not ordered, or has only one partition, this is a no 
op,
+    /// and the node remains a `RepartitionExec`.
+    pub fn with_preserve_order(mut self) -> Self {
+        self.base = self.base.with_preserve_order();
+        self
+    }
+
+    /// Get name used to display this Exec
+    pub fn name(&self) -> &str {
+        "OnDemandRepartitionExec"
+    }
+}
+
+impl DisplayAs for OnDemandRepartitionExec {
+    fn fmt_as(
+        &self,
+        t: DisplayFormatType,
+        f: &mut std::fmt::Formatter,
+    ) -> std::fmt::Result {
+        match t {
+            DisplayFormatType::Default | DisplayFormatType::Verbose => {
+                write!(
+                    f,
+                    "{}: partitioning={}, input_partitions={}",
+                    self.name(),
+                    self.partitioning(),
+                    self.base.input.output_partitioning().partition_count()
+                )?;
+
+                if self.base.preserve_order {
+                    write!(f, ", preserve_order=true")?;
+                }
+
+                if let Some(sort_exprs) = self.base.sort_exprs() {
+                    write!(f, ", sort_exprs={}", sort_exprs.clone())?;
+                }
+                Ok(())
+            }
+        }
+    }
+}
+
+impl ExecutionPlan for OnDemandRepartitionExec {
+    fn name(&self) -> &'static str {
+        "OnDemandRepartitionExec"
+    }
+
+    /// Return a reference to Any that can be used for downcasting
+    fn as_any(&self) -> &dyn Any {
+        self
+    }
+
+    fn properties(&self) -> &PlanProperties {
+        &self.base.cache
+    }
+
+    fn children(&self) -> Vec<&Arc<dyn ExecutionPlan>> {
+        vec![&self.base.input]
+    }
+
+    fn with_new_children(
+        self: Arc<Self>,
+        mut children: Vec<Arc<dyn ExecutionPlan>>,
+    ) -> Result<Arc<dyn ExecutionPlan>> {
+        let mut repartition = OnDemandRepartitionExec::try_new(
+            children.swap_remove(0),
+            self.partitioning().clone(),
+        )?;
+        if self.base.preserve_order {
+            repartition = repartition.with_preserve_order();
+        }
+        Ok(Arc::new(repartition))
+    }
+
+    fn benefits_from_input_partitioning(&self) -> Vec<bool> {
+        vec![matches!(self.partitioning(), Partitioning::Hash(_, _))]
+    }
+
+    fn maintains_input_order(&self) -> Vec<bool> {
+        RepartitionExecBase::maintains_input_order_helper(
+            self.input(),
+            self.base.preserve_order,
+        )
+    }
+
+    fn execute(
+        &self,
+        partition: usize,
+        context: Arc<TaskContext>,
+    ) -> Result<SendableRecordBatchStream> {
+        trace!(
+            "Start {}::execute for partition: {}",
+            self.name(),
+            partition
+        );
+
+        let lazy_state = Arc::clone(&self.base.state);
+        let partition_channels = Arc::clone(&self.partition_channels);
+        let input = Arc::clone(&self.base.input);
+        let partitioning = self.partitioning().clone();
+        let metrics = self.base.metrics.clone();
+        let preserve_order = self.base.preserve_order;
+        let name = self.name().to_owned();
+        let schema = self.schema();
+        let schema_captured = Arc::clone(&schema);
+
+        // Get existing ordering to use for merging
+        let sort_exprs = self.base.sort_exprs().cloned().unwrap_or_default();
+
+        let stream = futures::stream::once(async move {
+            let num_input_partitions = 
input.output_partitioning().partition_count();
+
+            let input_captured = Arc::clone(&input);
+            let metrics_captured = metrics.clone();
+            let name_captured = name.clone();
+            let context_captured = Arc::clone(&context);
+            let partition_channels = partition_channels
+                .get_or_init(|| async move {
+                    let (txs, rxs) = if preserve_order {
+                        (0..num_input_partitions)
+                            .map(|_| async_channel::bounded(2))
+                            .unzip::<_, _, Vec<_>, Vec<_>>()
+                    } else {
+                        let (tx, rx) = async_channel::bounded(2);
+                        (vec![tx], vec![rx])
+                    };
+                    Mutex::new((txs, rxs))
+                })
+                .await;
+            let (partition_txs, partition_rxs) = {
+                let channel = partition_channels.lock();
+                (channel.0.clone(), channel.1.clone())
+            };
+
+            let state = lazy_state
+                .get_or_init(|| async move {
+                    Mutex::new(
+                        RepartitionExecStateBuilder::new()
+                            .enable_pull_based(true)
+                            .partition_receivers(partition_rxs.clone())
+                            .build(
+                                input_captured,
+                                partitioning.clone(),
+                                metrics_captured,
+                                preserve_order,
+                                name_captured,
+                                context_captured,
+                            ),
+                    )
+                })
+                .await;
+
+            // lock scope
+            let (mut rx, reservation, abort_helper) = {
+                // lock mutexes
+                let mut state = state.lock();
+
+                // now return stream for the specified *output* partition 
which will
+                // read from the channel
+                let (_tx, rx, reservation) = state
+                    .channels
+                    .remove(&partition)
+                    .expect("partition not used yet");
+
+                (rx, reservation, Arc::clone(&state.abort_helper))
+            };
+
+            trace!(
+                "Before returning stream in {}::execute for partition: {}",
+                name,
+                partition
+            );
+
+            if preserve_order {
+                // Store streams from all the input partitions:
+                let input_streams = rx
+                    .into_iter()
+                    .enumerate()
+                    .map(|(i, receiver)| {
+                        // sender should be partition-wise
+                        Box::pin(OnDemandPerPartitionStream {
+                            schema: Arc::clone(&schema_captured),
+                            receiver,
+                            _drop_helper: Arc::clone(&abort_helper),
+                            reservation: Arc::clone(&reservation),
+                            sender: partition_txs[i].clone(),
+                            partition,
+                            is_requested: false,
+                        }) as SendableRecordBatchStream
+                    })
+                    .collect::<Vec<_>>();
+                // Note that receiver size (`rx.len()`) and 
`num_input_partitions` are same.
+
+                // Merge streams (while preserving ordering) coming from
+                // input partitions to this partition:
+                let fetch = None;
+                let merge_reservation =
+                    MemoryConsumer::new(format!("{}[Merge {partition}]", name))
+                        .register(context.memory_pool());
+                StreamingMergeBuilder::new()
+                    .with_streams(input_streams)
+                    .with_schema(schema_captured)
+                    .with_expressions(&sort_exprs)
+                    .with_metrics(BaselineMetrics::new(&metrics, partition))
+                    .with_batch_size(context.session_config().batch_size())
+                    .with_fetch(fetch)
+                    .with_reservation(merge_reservation)
+                    .build()
+            } else {
+                Ok(Box::pin(OnDemandRepartitionStream {
+                    num_input_partitions,
+                    num_input_partitions_processed: 0,
+                    schema: input.schema(),
+                    input: rx.swap_remove(0),
+                    _drop_helper: abort_helper,
+                    reservation,
+                    sender: partition_txs[0].clone(),
+                    partition,
+                    is_requested: false,
+                }) as SendableRecordBatchStream)
+            }
+        })
+        .try_flatten();
+        let stream = RecordBatchStreamAdapter::new(schema, stream);
+        Ok(Box::pin(stream))
+    }
+
+    fn metrics(&self) -> Option<MetricsSet> {
+        Some(self.base.metrics.clone_inner())
+    }
+
+    fn statistics(&self) -> Result<Statistics> {
+        self.base.input.statistics()
+    }
+
+    fn cardinality_effect(&self) -> CardinalityEffect {
+        CardinalityEffect::Equal
+    }
+}
+
+impl OnDemandRepartitionExec {
+    /// Create a new RepartitionExec, that produces output `partitioning`, and
+    /// does not preserve the order of the input (see 
[`Self::with_preserve_order`]
+    /// for more details)
+    pub fn try_new(
+        input: Arc<dyn ExecutionPlan>,
+        partitioning: Partitioning,
+    ) -> Result<Self> {
+        let preserve_order = false;
+        let cache = RepartitionExecBase::compute_properties(
+            &input,
+            partitioning.clone(),
+            preserve_order,
+        );
+        Ok(OnDemandRepartitionExec {
+            base: RepartitionExecBase {
+                input,
+                state: Default::default(),
+                metrics: ExecutionPlanMetricsSet::new(),
+                preserve_order,
+                cache,
+            },
+            partition_channels: Default::default(),
+        })
+    }
+
+    async fn process_input(
+        input: Arc<dyn ExecutionPlan>,
+        partition: usize,
+        buffer_tx: tokio::sync::mpsc::Sender<RecordBatch>,
+        context: Arc<TaskContext>,
+    ) -> Result<()> {
+        let mut stream = input.execute(partition, context).map_err(|e| {
+            internal_datafusion_err!(
+                "Error executing input partition {} for on demand 
repartitioning: {}",
+                partition,
+                e
+            )
+        })?;
+        while let Some(batch) = stream.next().await {
+            buffer_tx.send(batch?).await.map_err(|e| {
+                internal_datafusion_err!(
+                    "Error sending batch to buffer channel for partition {}: 
{}",
+                    partition,
+                    e
+                )
+            })?;
+        }
+        debug!(
+            "On demand input partition {} processing finished",
+            partition
+        );
+
+        Ok(())
+    }
+
+    /// Pulls data from the specified input plan, feeding it to the
+    /// output partitions based on the desired partitioning
+    ///
+    /// txs hold the output sending channels for each output partition
+    pub(crate) async fn pull_from_input(
+        input: Arc<dyn ExecutionPlan>,
+        partition: usize,
+        mut output_channels: HashMap<
+            usize,
+            (DistributionSender<MaybeBatch>, SharedMemoryReservation),
+        >,
+        partitioning: Partitioning,
+        output_partition_rx: Receiver<usize>,
+        metrics: RepartitionMetrics,
+        context: Arc<TaskContext>,
+    ) -> Result<()> {
+        let _ = BatchPartitioner::try_new(
+            partitioning.clone(),
+            metrics.repartition_time.clone(),
+        )?;
+
+        // execute the child operator in a separate task
+        let (buffer_tx, mut buffer_rx) = 
tokio::sync::mpsc::channel::<RecordBatch>(2);
+        let processing_task = SpawnedTask::spawn(Self::process_input(
+            Arc::clone(&input),
+            partition,
+            buffer_tx,
+            Arc::clone(&context),
+        ));
+
+        // While there are still outputs to send to, keep pulling inputs
+        let mut batches_until_yield = partitioning.partition_count();
+        while !output_channels.is_empty() {
+            // Input is done
+            let batch = match buffer_rx.recv().await {
+                Some(result) => result,
+                None => break,
+            };
+
+            // Get the partition number from the output partition
+            let partition = output_partition_rx.recv().await.map_err(|e| {
+                internal_datafusion_err!(
+                    "Error receiving partition number from output partition: 
{}",
+                    e
+                )
+            })?;
+
+            let size = batch.get_array_memory_size();
+
+            let timer = metrics.send_time[partition].timer();
+            // if there is still a receiver, send to it
+            if let Some((tx, reservation)) = 
output_channels.get_mut(&partition) {
+                reservation.lock().try_grow(size)?;
+
+                if tx.send(Some(Ok(batch))).await.is_err() {
+                    // If the other end has hung up, it was an early shutdown 
(e.g. LIMIT)
+                    reservation.lock().shrink(size);
+                    output_channels.remove(&partition);
+                }
+            }
+            timer.done();
+
+            // If the input stream is endless, we may spin forever and
+            // never yield back to tokio.  See
+            // https://github.com/apache/datafusion/issues/5278.
+            //
+            // However, yielding on every batch causes a bottleneck
+            // when running with multiple cores. See
+            // https://github.com/apache/datafusion/issues/6290
+            //
+            // Thus, heuristically yield after producing num_partition
+            // batches
+            //
+            // In round robin this is ideal as each input will get a
+            // new batch. In hash partitioning it may yield too often
+            // on uneven distributions even if some partition can not
+            // make progress, but parallelism is going to be limited
+            // in that case anyways

Review Comment:
   I meant the comments, since they are about round-robin and hash partitioning.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to