benbellick commented on code in PR #21193:
URL: https://github.com/apache/datafusion/pull/21193#discussion_r3182715538


##########
datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs:
##########
@@ -35,7 +38,68 @@ pub fn from_higher_order_function(
     fun: &expr::HigherOrderFunction,
     schema: &DFSchemaRef,
 ) -> datafusion::common::Result<Expression> {
-    from_function(producer, fun.name(), &fun.args, schema)
+    let mut lambda_parameters = fun.lambda_parameters(schema)?.into_iter();
+
+    let num_lambdas = fun
+        .args
+        .iter()
+        .filter(|arg| matches!(arg, Expr::Lambda(_)))
+        .count();
+
+    if lambda_parameters.len() != num_lambdas {
+        return substrait_err!(
+            "{} returned {} lambdas but {num_lambdas} expected",
+            fun.name(),
+            lambda_parameters.len()
+        );
+    }
+
+    let arguments = fun
+        .args
+        .iter()
+        .map(|arg| {
+            let arg = match arg {
+                Expr::Lambda(l) => {
+                    let lambda_parameters =
+                        lambda_parameters.next().ok_or_else(|| {
+                            internal_datafusion_err!(
+                                "lambda_parameters len should have been 
checked above"
+                            )
+                        })?;
+
+                    let named_lambda_parameters =
+                        std::iter::zip(&l.params, lambda_parameters)
+                            .map(|(name, parameter)| parameter.renamed(name))
+                            .collect();
+
+                    producer.push_lambda_parameters(named_lambda_parameters)?;
+
+                    let arg = producer.handle_expr(arg, schema)?;

Review Comment:
   Minor thing, but if `handle_expr` fails for some reason, we will leave the 
lambda parameters on the producer stack, where they were pushed by:
   ```rust
   producer.push_lambda_parameters(named_lambda_parameters)?;
   ```
   
   I'm not sure it matters because I think we dispose of the whole producer 
struct on failure, but just wanted to call attention to it.



##########
datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs:
##########
@@ -594,6 +662,124 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
         let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
         Ok(LogicalPlan::Extension(Extension { node: plan }))
     }
+
+    fn with_lambda_parameters(
+        &self,
+        lambda_parameters: &[Type],
+        input_schema: &DFSchema,
+    ) -> datafusion::common::Result<(Vec<String>, Self)> {
+        let (names, lambda_consumer) = 
self.lambda_consumer.with_lambda_parameters(
+            self,
+            lambda_parameters,
+            input_schema,
+        )?;
+
+        Ok((
+            names,
+            Self {
+                extensions: self.extensions,
+                state: self.state,
+                outer_schemas: 
RwLock::new(self.outer_schemas.read().unwrap().clone()),
+                lambda_consumer,
+            },
+        ))
+    }
+
+    fn lambda_variable(
+        &self,
+        steps_out: usize,
+        field_idx: usize,
+    ) -> datafusion::common::Result<Expr> {
+        self.lambda_consumer.lambda_variable(steps_out, field_idx)
+    }
+}
+
+/// Default implementation of lambda related methods of the 
[SubstraitConsumer] trait
+///
+/// Can be embedded into a custom [SubstraitConsumer] to implement them
+pub struct DefaultSubstraitLambdaConsumer {

Review Comment:
   Is there a reason this is public? This feels like an implementation detail 
of the default lambda-handling logic. What about:
   ```suggestion
   struct LambdaConsumerState {
   ```



##########
datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs:
##########
@@ -594,6 +662,124 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
         let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
         Ok(LogicalPlan::Extension(Extension { node: plan }))
     }
+
+    fn with_lambda_parameters(
+        &self,
+        lambda_parameters: &[Type],
+        input_schema: &DFSchema,
+    ) -> datafusion::common::Result<(Vec<String>, Self)> {
+        let (names, lambda_consumer) = 
self.lambda_consumer.with_lambda_parameters(
+            self,
+            lambda_parameters,
+            input_schema,
+        )?;
+
+        Ok((
+            names,
+            Self {
+                extensions: self.extensions,
+                state: self.state,
+                outer_schemas: 
RwLock::new(self.outer_schemas.read().unwrap().clone()),
+                lambda_consumer,
+            },
+        ))
+    }
+
+    fn lambda_variable(
+        &self,
+        steps_out: usize,
+        field_idx: usize,
+    ) -> datafusion::common::Result<Expr> {
+        self.lambda_consumer.lambda_variable(steps_out, field_idx)
+    }
+}
+
+/// Default implementation of lambda related methods of the 
[SubstraitConsumer] trait
+///
+/// Can be embedded into a custom [SubstraitConsumer] to implement them
+pub struct DefaultSubstraitLambdaConsumer {
+    lambdas_parameters: VecDeque<Vec<FieldRef>>,

Review Comment:
   I _think_ these are the parameters from innermost to outermost. Is that 
correct? If so, it would be great to add a comment here spelling that out. 



##########
datafusion/substrait/src/logical_plan/consumer/substrait_consumer.rs:
##########
@@ -594,6 +662,124 @@ impl SubstraitConsumer for DefaultSubstraitConsumer<'_> {
         let plan = plan.with_exprs_and_inputs(plan.expressions(), inputs)?;
         Ok(LogicalPlan::Extension(Extension { node: plan }))
     }
+
+    fn with_lambda_parameters(
+        &self,
+        lambda_parameters: &[Type],
+        input_schema: &DFSchema,
+    ) -> datafusion::common::Result<(Vec<String>, Self)> {
+        let (names, lambda_consumer) = 
self.lambda_consumer.with_lambda_parameters(
+            self,
+            lambda_parameters,
+            input_schema,
+        )?;
+
+        Ok((
+            names,
+            Self {
+                extensions: self.extensions,
+                state: self.state,
+                outer_schemas: 
RwLock::new(self.outer_schemas.read().unwrap().clone()),
+                lambda_consumer,
+            },
+        ))
+    }
+
+    fn lambda_variable(
+        &self,
+        steps_out: usize,
+        field_idx: usize,
+    ) -> datafusion::common::Result<Expr> {
+        self.lambda_consumer.lambda_variable(steps_out, field_idx)
+    }
+}
+
+/// Default implementation of lambda related methods of the 
[SubstraitConsumer] trait
+///
+/// Can be embedded into a custom [SubstraitConsumer] to implement them
+pub struct DefaultSubstraitLambdaConsumer {
+    lambdas_parameters: VecDeque<Vec<FieldRef>>,
+    next_lambda_parameter: usize,
+}
+
+impl Default for DefaultSubstraitLambdaConsumer {
+    fn default() -> Self {
+        Self::new()
+    }
+}
+
+impl DefaultSubstraitLambdaConsumer {
+    pub fn new() -> Self {
+        Self {
+            lambdas_parameters: VecDeque::new(),
+            next_lambda_parameter: 0,
+        }
+    }
+
+    pub fn with_lambda_parameters(
+        &self,
+        consumer: &impl SubstraitConsumer,
+        lambda_parameters: &[Type],
+        input_schema: &DFSchema,
+    ) -> datafusion::common::Result<(Vec<String>, Self)> {
+        let mut next_lambda_parameter = self.next_lambda_parameter;
+
+        let lambda_parameters = lambda_parameters
+            .iter()
+            .map(|ty| {
+                loop {
+                    let default_name = format!("p{next_lambda_parameter}");

Review Comment:
   What do you think about extracting out the name generation into a helper 
called something like `next_lambda_parameter_name`? This way, the name 
generation can be separated and it can have a docstring which clarifies the 
strategy (which I believe is to always just take the name `pN` where `N` is 
`next_lambda_parameter`, but skip `N` if there is a collision).



##########
datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs:
##########
@@ -45,6 +45,20 @@ pub async fn from_scalar_function(
     let fn_name = substrait_fun_name(fn_signature);
     let args = from_substrait_func_args(consumer, &f.arguments, 
input_schema).await?;
 
+    let udlf_func = consumer

Review Comment:
   Perhaps a clearer name for this variable would be `udf_lambda_func` or 
something that makes the lambda more explicit? 



##########
datafusion/substrait/src/logical_plan/producer/expr/scalar_function.rs:
##########
@@ -35,7 +38,68 @@ pub fn from_higher_order_function(
     fun: &expr::HigherOrderFunction,
     schema: &DFSchemaRef,
 ) -> datafusion::common::Result<Expression> {
-    from_function(producer, fun.name(), &fun.args, schema)
+    let mut lambda_parameters = fun.lambda_parameters(schema)?.into_iter();
+
+    let num_lambdas = fun
+        .args
+        .iter()
+        .filter(|arg| matches!(arg, Expr::Lambda(_)))
+        .count();
+
+    if lambda_parameters.len() != num_lambdas {
+        return substrait_err!(
+            "{} returned {} lambdas but {num_lambdas} expected",
+            fun.name(),
+            lambda_parameters.len()
+        );
+    }
+
+    let arguments = fun
+        .args
+        .iter()
+        .map(|arg| {
+            let arg = match arg {
+                Expr::Lambda(l) => {
+                    let lambda_parameters =
+                        lambda_parameters.next().ok_or_else(|| {
+                            internal_datafusion_err!(
+                                "lambda_parameters len should have been 
checked above"
+                            )
+                        })?;
+
+                    let named_lambda_parameters =
+                        std::iter::zip(&l.params, lambda_parameters)
+                            .map(|(name, parameter)| parameter.renamed(name))
+                            .collect();
+
+                    producer.push_lambda_parameters(named_lambda_parameters)?;
+
+                    let arg = producer.handle_expr(arg, schema)?;

Review Comment:
   I think it is slightly more readable if we just pass in the 
already-extracted lambda `l`. 
   ```suggestion
                       let arg = producer.handle_lambda(l, schema)?;
   ```



##########
datafusion/substrait/src/logical_plan/consumer/expr/scalar_function.rs:
##########


Review Comment:
   There is a TODO on line 33 that says
   ```
   //TODO: handle higher order functions, as they are also encoded as scalar 
functions
   ```
   which can be removed here



##########
datafusion/substrait/tests/cases/roundtrip_logical_plan.rs:
##########


Review Comment:
   These tests are great!
   
   One additional thing that might be useful is a small number of tests that 
exercise the producer and consumer independently. The roundtrip tests verify 
that the producer and consumer are internally consistent with each other, but 
they don’t make it as obvious what Substrait representation we expect to 
support.
   
   There is some precedent for both styles:
   - Consumer-side tests that load Substrait JSON and convert it to a 
DataFusion plan: 
https://github.com/apache/datafusion/blob/fa9ada36871074b622f3ca67fcdaf34d7a1efdbc/datafusion/substrait/tests/cases/emit_kind_tests.rs#L34
   - Producer-side tests that call `to_substrait_plan` and inspect the 
generated proto: 
https://github.com/apache/datafusion/blob/fa9ada36871074b622f3ca67fcdaf34d7a1efdbc/datafusion/substrait/tests/cases/serialize.rs#L114-L126
   
   It might be nice to add one or two similar tests for lambdas, so the 
expected Substrait shape for `Lambda` / `LambdaParameterReference` is 
documented by the tests.
   



##########
datafusion/substrait/src/logical_plan/producer/substrait_producer.rs:
##########
@@ -471,4 +551,109 @@ impl SubstraitProducer for DefaultSubstraitProducer<'_> {
             rel_type: Some(rel_type),
         }))
     }
+
+    fn push_lambda_parameters(
+        &mut self,
+        lambda_parameters: Vec<FieldRef>,
+    ) -> datafusion::common::Result<()> {
+        let lambda_parameters_map = lambda_parameters_map(self, 
lambda_parameters)?;
+
+        self.lambda_producer
+            .push_lambda_parameters(lambda_parameters_map);
+
+        Ok(())
+    }
+
+    fn pop_lambda_parameters(&mut self) -> datafusion::common::Result<()> {
+        self.lambda_producer.pop_lambda_parameters()
+    }
+
+    fn lambda_variable(&self, name: &str) -> datafusion::common::Result<(u32, 
i32)> {
+        self.lambda_producer.lambda_variable(name)
+    }
+
+    fn lambda_parameter_type(
+        &self,
+        name: &str,
+    ) -> datafusion::common::Result<substrait::proto::Type> {
+        self.lambda_producer.lambda_parameter_type(name)
+    }
+}
+
+/// Default implementation of lambda related methods of the 
[SubstraitProducer] trait
+///
+/// Can be embedded into a custom [SubstraitProducer] to implement them
+pub struct DefaultSubstraitLambdaProducer {

Review Comment:
   Same comment as on the consumer side. I wonder if we can just keep this 
private, since its usage in implementing this producer is an implementation 
detail. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]

Reply via email to