vbarua commented on code in PR #13931:
URL: https://github.com/apache/datafusion/pull/13931#discussion_r1899713365


##########
datafusion/substrait/src/logical_plan/producer.rs:
##########
@@ -998,450 +1304,418 @@ pub fn make_binary_op_scalar_func(
 /// Convert DataFusion Expr to Substrait Rex
 ///
 /// # Arguments
-///
-/// * `expr` - DataFusion expression to be parse into a Substrait expression
-/// * `schema` - DataFusion input schema for looking up field qualifiers
-/// * `col_ref_offset` - Offset for calculating Substrait field reference 
indices.
-///                     This should only be set by caller with more than one 
input relations i.e. Join.
-///                     Substrait expects one set of indices when joining two 
relations.
-///                     Let's say `left` and `right` have `m` and `n` columns, 
respectively. The `right`
-///                     relation will have column indices from `0` to `n-1`, 
however, Substrait will expect
-///                     the `right` indices to be offset by the `left`. This 
means Substrait will expect to
-///                     evaluate the join condition expression on indices [0 
.. n-1, n .. n+m-1]. For example:
-///                     ```SELECT *
-///                        FROM t1
-///                        JOIN t2
-///                        ON t1.c1 = t2.c0;```
-///                     where t1 consists of columns [c0, c1, c2], and t2 = 
columns [c0, c1]
-///                     the join condition should become
-///                     `col_ref(1) = col_ref(3 + 0)`
-///                     , where `3` is the number of `left` columns 
(`col_ref_offset`) and `0` is the index
-///                     of the join key column from `right`
-/// * `extensions` - Substrait extension info. Contains registered function 
information
-#[allow(deprecated)]
+/// * `producer` - SubstraitProducer implementation which handles the 
actual conversion
+/// * `expr` - DataFusion expression to convert into a Substrait expression
+/// * `schema` - DataFusion input schema for looking up columns
 pub fn to_substrait_rex(
-    state: &dyn SubstraitPlanningState,
+    producer: &mut impl SubstraitProducer,
     expr: &Expr,
     schema: &DFSchemaRef,
-    col_ref_offset: usize,
-    extensions: &mut Extensions,
 ) -> Result<Expression> {
     match expr {
-        Expr::InList(InList {
-            expr,
-            list,
-            negated,
-        }) => {
-            let substrait_list = list
-                .iter()
-                .map(|x| to_substrait_rex(state, x, schema, col_ref_offset, 
extensions))
-                .collect::<Result<Vec<Expression>>>()?;
-            let substrait_expr =
-                to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
-
-            let substrait_or_list = Expression {
-                rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList 
{
-                    value: Some(Box::new(substrait_expr)),
-                    options: substrait_list,
-                }))),
-            };
-
-            if *negated {
-                let function_anchor = 
extensions.register_function("not".to_string());
-
-                Ok(Expression {
-                    rex_type: Some(RexType::ScalarFunction(ScalarFunction {
-                        function_reference: function_anchor,
-                        arguments: vec![FunctionArgument {
-                            arg_type: Some(ArgType::Value(substrait_or_list)),
-                        }],
-                        output_type: None,
-                        args: vec![],
-                        options: vec![],
-                    })),
-                })
-            } else {
-                Ok(substrait_or_list)
-            }
+        Expr::Alias(expr) => producer.consume_alias(expr, schema),
+        Expr::Column(expr) => producer.consume_column(expr, schema),
+        Expr::Literal(expr) => producer.consume_literal(expr),
+        Expr::BinaryExpr(expr) => producer.consume_binary_expr(expr, schema),
+        Expr::Like(expr) => producer.consume_like(expr, schema),
+        Expr::SimilarTo(_) => not_impl_err!("SimilarTo is not supported"),
+        Expr::Not(_) => producer.consume_unary_expr(expr, schema),
+        Expr::IsNotNull(_) => producer.consume_unary_expr(expr, schema),
+        Expr::IsNull(_) => producer.consume_unary_expr(expr, schema),
+        Expr::IsTrue(_) => producer.consume_unary_expr(expr, schema),
+        Expr::IsFalse(_) => producer.consume_unary_expr(expr, schema),
+        Expr::IsUnknown(_) => producer.consume_unary_expr(expr, schema),
+        Expr::IsNotTrue(_) => producer.consume_unary_expr(expr, schema),
+        Expr::IsNotFalse(_) => producer.consume_unary_expr(expr, schema),
+        Expr::IsNotUnknown(_) => producer.consume_unary_expr(expr, schema),
+        Expr::Negative(_) => producer.consume_unary_expr(expr, schema),
+        Expr::Between(expr) => producer.consume_between(expr, schema),
+        Expr::Case(expr) => producer.consume_case(expr, schema),
+        Expr::Cast(expr) => producer.consume_cast(expr, schema),
+        Expr::TryCast(expr) => producer.consume_try_cast(expr, schema),
+        Expr::ScalarFunction(expr) => producer.consume_scalar_function(expr, 
schema),
+        Expr::AggregateFunction(_) => {
+            internal_err!(
+                "AggregateFunction should only be encountered as part of a 
LogicalPlan::Aggregate"
+            )
         }
-        Expr::ScalarFunction(fun) => {
-            let mut arguments: Vec<FunctionArgument> = vec![];
-            for arg in &fun.args {
-                arguments.push(FunctionArgument {
-                    arg_type: Some(ArgType::Value(to_substrait_rex(
-                        state,
-                        arg,
-                        schema,
-                        col_ref_offset,
-                        extensions,
-                    )?)),
-                });
-            }
+        Expr::WindowFunction(expr) => producer.consume_window_function(expr, 
schema),
+        Expr::InList(expr) => producer.consume_in_list(expr, schema),
+        Expr::InSubquery(expr) => producer.consume_in_subquery(expr, schema),
+        _ => not_impl_err!("Cannot convert {expr:?} to Substrait"),
+    }
+}
 
-            let function_anchor = 
extensions.register_function(fun.name().to_string());
-            Ok(Expression {
-                rex_type: Some(RexType::ScalarFunction(ScalarFunction {
-                    function_reference: function_anchor,
-                    arguments,
-                    output_type: None,
-                    args: vec![],
-                    options: vec![],
-                })),
-            })
-        }
-        Expr::Between(Between {
-            expr,
-            negated,
-            low,
-            high,
-        }) => {
-            if *negated {
-                // `expr NOT BETWEEN low AND high` can be translated into 
(expr < low OR high < expr)
-                let substrait_expr =
-                    to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
-                let substrait_low =
-                    to_substrait_rex(state, low, schema, col_ref_offset, 
extensions)?;
-                let substrait_high =
-                    to_substrait_rex(state, high, schema, col_ref_offset, 
extensions)?;
-
-                let l_expr = make_binary_op_scalar_func(
-                    &substrait_expr,
-                    &substrait_low,
-                    Operator::Lt,
-                    extensions,
-                );
-                let r_expr = make_binary_op_scalar_func(
-                    &substrait_high,
-                    &substrait_expr,
-                    Operator::Lt,
-                    extensions,
-                );
+pub fn from_in_list(
+    producer: &mut impl SubstraitProducer,
+    in_list: &InList,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let InList {
+        expr,
+        list,
+        negated,
+    } = in_list;
+    let substrait_list = list
+        .iter()
+        .map(|x| producer.consume_expr(x, schema))
+        .collect::<Result<Vec<Expression>>>()?;
+    let substrait_expr = producer.consume_expr(expr, schema)?;
+
+    let substrait_or_list = Expression {
+        rex_type: Some(RexType::SingularOrList(Box::new(SingularOrList {
+            value: Some(Box::new(substrait_expr)),
+            options: substrait_list,
+        }))),
+    };
 
-                Ok(make_binary_op_scalar_func(
-                    &l_expr,
-                    &r_expr,
-                    Operator::Or,
-                    extensions,
-                ))
-            } else {
-                // `expr BETWEEN low AND high` can be translated into (low <= 
expr AND expr <= high)
-                let substrait_expr =
-                    to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
-                let substrait_low =
-                    to_substrait_rex(state, low, schema, col_ref_offset, 
extensions)?;
-                let substrait_high =
-                    to_substrait_rex(state, high, schema, col_ref_offset, 
extensions)?;
-
-                let l_expr = make_binary_op_scalar_func(
-                    &substrait_low,
-                    &substrait_expr,
-                    Operator::LtEq,
-                    extensions,
-                );
-                let r_expr = make_binary_op_scalar_func(
-                    &substrait_expr,
-                    &substrait_high,
-                    Operator::LtEq,
-                    extensions,
-                );
+    if *negated {
+        let function_anchor = producer.register_function("not".to_string());
 
-                Ok(make_binary_op_scalar_func(
-                    &l_expr,
-                    &r_expr,
-                    Operator::And,
-                    extensions,
-                ))
-            }
-        }
-        Expr::Column(col) => {
-            let index = schema.index_of_column(col)?;
-            substrait_field_ref(index + col_ref_offset)
-        }
-        Expr::BinaryExpr(BinaryExpr { left, op, right }) => {
-            let l = to_substrait_rex(state, left, schema, col_ref_offset, 
extensions)?;
-            let r = to_substrait_rex(state, right, schema, col_ref_offset, 
extensions)?;
+        #[allow(deprecated)]
+        Ok(Expression {
+            rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+                function_reference: function_anchor,
+                arguments: vec![FunctionArgument {
+                    arg_type: Some(ArgType::Value(substrait_or_list)),
+                }],
+                output_type: None,
+                args: vec![],
+                options: vec![],
+            })),
+        })
+    } else {
+        Ok(substrait_or_list)
+    }
+}
 
-            Ok(make_binary_op_scalar_func(&l, &r, *op, extensions))
-        }
-        Expr::Case(Case {
-            expr,
-            when_then_expr,
-            else_expr,
-        }) => {
-            let mut ifs: Vec<IfClause> = vec![];
-            // Parse base
-            if let Some(e) = expr {
-                // Base expression exists
-                ifs.push(IfClause {
-                    r#if: Some(to_substrait_rex(
-                        state,
-                        e,
-                        schema,
-                        col_ref_offset,
-                        extensions,
-                    )?),
-                    then: None,
-                });
-            }
-            // Parse `when`s
-            for (r#if, then) in when_then_expr {
-                ifs.push(IfClause {
-                    r#if: Some(to_substrait_rex(
-                        state,
-                        r#if,
-                        schema,
-                        col_ref_offset,
-                        extensions,
-                    )?),
-                    then: Some(to_substrait_rex(
-                        state,
-                        then,
-                        schema,
-                        col_ref_offset,
-                        extensions,
-                    )?),
-                });
-            }
+pub fn from_scalar_function(
+    producer: &mut impl SubstraitProducer,
+    fun: &expr::ScalarFunction,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let mut arguments: Vec<FunctionArgument> = vec![];
+    for arg in &fun.args {
+        arguments.push(FunctionArgument {
+            arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, 
schema)?)),
+        });
+    }
 
-            // Parse outer `else`
-            let r#else: Option<Box<Expression>> = match else_expr {
-                Some(e) => Some(Box::new(to_substrait_rex(
-                    state,
-                    e,
-                    schema,
-                    col_ref_offset,
-                    extensions,
-                )?)),
-                None => None,
-            };
+    let function_anchor = producer.register_function(fun.name().to_string());
+    #[allow(deprecated)]
+    Ok(Expression {
+        rex_type: Some(RexType::ScalarFunction(ScalarFunction {
+            function_reference: function_anchor,
+            arguments,
+            output_type: None,
+            options: vec![],
+            args: vec![],
+        })),
+    })
+}
 
-            Ok(Expression {
-                rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else 
}))),
-            })
-        }
-        Expr::Cast(Cast { expr, data_type }) => Ok(Expression {
-            rex_type: Some(RexType::Cast(Box::new(
-                substrait::proto::expression::Cast {
-                    r#type: Some(to_substrait_type(data_type, true)?),
-                    input: Some(Box::new(to_substrait_rex(
-                        state,
-                        expr,
-                        schema,
-                        col_ref_offset,
-                        extensions,
-                    )?)),
-                    failure_behavior: FailureBehavior::ThrowException.into(),
-                },
-            ))),
-        }),
-        Expr::TryCast(TryCast { expr, data_type }) => Ok(Expression {
-            rex_type: Some(RexType::Cast(Box::new(
-                substrait::proto::expression::Cast {
-                    r#type: Some(to_substrait_type(data_type, true)?),
-                    input: Some(Box::new(to_substrait_rex(
-                        state,
-                        expr,
-                        schema,
-                        col_ref_offset,
-                        extensions,
-                    )?)),
-                    failure_behavior: FailureBehavior::ReturnNull.into(),
-                },
-            ))),
-        }),
-        Expr::Literal(value) => to_substrait_literal_expr(value, extensions),
-        Expr::Alias(Alias { expr, .. }) => {
-            to_substrait_rex(state, expr, schema, col_ref_offset, extensions)
-        }
-        Expr::WindowFunction(WindowFunction {
-            fun,
-            args,
-            partition_by,
-            order_by,
-            window_frame,
-            null_treatment: _,
-        }) => {
-            // function reference
-            let function_anchor = 
extensions.register_function(fun.to_string());
-            // arguments
-            let mut arguments: Vec<FunctionArgument> = vec![];
-            for arg in args {
-                arguments.push(FunctionArgument {
-                    arg_type: Some(ArgType::Value(to_substrait_rex(
-                        state,
-                        arg,
-                        schema,
-                        col_ref_offset,
-                        extensions,
-                    )?)),
-                });
-            }
-            // partition by expressions
-            let partition_by = partition_by
-                .iter()
-                .map(|e| to_substrait_rex(state, e, schema, col_ref_offset, 
extensions))
-                .collect::<Result<Vec<_>>>()?;
-            // order by expressions
-            let order_by = order_by
-                .iter()
-                .map(|e| substrait_sort_field(state, e, schema, extensions))
-                .collect::<Result<Vec<_>>>()?;
-            // window frame
-            let bounds = to_substrait_bounds(window_frame)?;
-            let bound_type = to_substrait_bound_type(window_frame)?;
-            Ok(make_substrait_window_function(
-                function_anchor,
-                arguments,
-                partition_by,
-                order_by,
-                bounds,
-                bound_type,
-            ))
-        }
-        Expr::Like(Like {
-            negated,
-            expr,
-            pattern,
-            escape_char,
-            case_insensitive,
-        }) => make_substrait_like_expr(
-            state,
-            *case_insensitive,
-            *negated,
-            expr,
-            pattern,
-            *escape_char,
-            schema,
-            col_ref_offset,
-            extensions,
-        ),
-        Expr::InSubquery(InSubquery {
-            expr,
-            subquery,
-            negated,
-        }) => {
-            let substrait_expr =
-                to_substrait_rex(state, expr, schema, col_ref_offset, 
extensions)?;
-
-            let subquery_plan =
-                to_substrait_rel(subquery.subquery.as_ref(), state, 
extensions)?;
-
-            let substrait_subquery = Expression {
-                rex_type: Some(RexType::Subquery(Box::new(Subquery {
-                    subquery_type: Some(
-                        
substrait::proto::expression::subquery::SubqueryType::InPredicate(
-                            Box::new(InPredicate {
-                                needles: (vec![substrait_expr]),
-                                haystack: Some(subquery_plan),
-                            }),
-                        ),
+pub fn from_between(
+    producer: &mut impl SubstraitProducer,
+    between: &Between,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let Between {
+        expr,
+        negated,
+        low,
+        high,
+    } = between;
+    if *negated {
+        // `expr NOT BETWEEN low AND high` can be translated into (expr < low 
OR high < expr)
+        let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?;
+        let substrait_low = producer.consume_expr(low.as_ref(), schema)?;
+        let substrait_high = producer.consume_expr(high.as_ref(), schema)?;
+
+        let l_expr = make_binary_op_scalar_func(
+            producer,
+            &substrait_expr,
+            &substrait_low,
+            Operator::Lt,
+        );
+        let r_expr = make_binary_op_scalar_func(
+            producer,
+            &substrait_high,
+            &substrait_expr,
+            Operator::Lt,
+        );
+
+        Ok(make_binary_op_scalar_func(
+            producer,
+            &l_expr,
+            &r_expr,
+            Operator::Or,
+        ))
+    } else {
+        // `expr BETWEEN low AND high` can be translated into (low <= expr AND 
expr <= high)
+        let substrait_expr = producer.consume_expr(expr.as_ref(), schema)?;
+        let substrait_low = producer.consume_expr(low.as_ref(), schema)?;
+        let substrait_high = producer.consume_expr(high.as_ref(), schema)?;
+
+        let l_expr = make_binary_op_scalar_func(
+            producer,
+            &substrait_low,
+            &substrait_expr,
+            Operator::LtEq,
+        );
+        let r_expr = make_binary_op_scalar_func(
+            producer,
+            &substrait_expr,
+            &substrait_high,
+            Operator::LtEq,
+        );
+
+        Ok(make_binary_op_scalar_func(
+            producer,
+            &l_expr,
+            &r_expr,
+            Operator::And,
+        ))
+    }
+}
+pub fn from_column(col: &Column, schema: &DFSchemaRef) -> Result<Expression> {
+    let index = schema.index_of_column(col)?;
+    substrait_field_ref(index)
+}
+
+pub fn from_binary_expr(
+    producer: &mut impl SubstraitProducer,
+    expr: &BinaryExpr,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let BinaryExpr { left, op, right } = expr;
+    let l = producer.consume_expr(left, schema)?;
+    let r = producer.consume_expr(right, schema)?;
+    Ok(make_binary_op_scalar_func(producer, &l, &r, *op))
+}
+pub fn from_case(
+    producer: &mut impl SubstraitProducer,
+    case: &Case,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let Case {
+        expr,
+        when_then_expr,
+        else_expr,
+    } = case;
+    let mut ifs: Vec<IfClause> = vec![];
+    // Parse base
+    if let Some(e) = expr {
+        // Base expression exists
+        ifs.push(IfClause {
+            r#if: Some(producer.consume_expr(e, schema)?),
+            then: None,
+        });
+    }
+    // Parse `when`s
+    for (r#if, then) in when_then_expr {
+        ifs.push(IfClause {
+            r#if: Some(producer.consume_expr(r#if, schema)?),
+            then: Some(producer.consume_expr(then, schema)?),
+        });
+    }
+
+    // Parse outer `else`
+    let r#else: Option<Box<Expression>> = match else_expr {
+        Some(e) => Some(Box::new(to_substrait_rex(producer, e, schema)?)),
+        None => None,
+    };
+
+    Ok(Expression {
+        rex_type: Some(RexType::IfThen(Box::new(IfThen { ifs, r#else }))),
+    })
+}
+
+pub fn from_cast(
+    producer: &mut impl SubstraitProducer,
+    cast: &Cast,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let Cast { expr, data_type } = cast;
+    Ok(Expression {
+        rex_type: Some(RexType::Cast(Box::new(
+            substrait::proto::expression::Cast {
+                r#type: Some(to_substrait_type(data_type, true)?),
+                input: Some(Box::new(to_substrait_rex(producer, expr, 
schema)?)),
+                failure_behavior: FailureBehavior::ThrowException.into(),
+            },
+        ))),
+    })
+}
+
+pub fn from_try_cast(
+    producer: &mut impl SubstraitProducer,
+    cast: &TryCast,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let TryCast { expr, data_type } = cast;
+    Ok(Expression {
+        rex_type: Some(RexType::Cast(Box::new(
+            substrait::proto::expression::Cast {
+                r#type: Some(to_substrait_type(data_type, true)?),
+                input: Some(Box::new(to_substrait_rex(producer, expr, 
schema)?)),
+                failure_behavior: FailureBehavior::ReturnNull.into(),
+            },
+        ))),
+    })
+}
+
+pub fn from_literal(
+    producer: &mut impl SubstraitProducer,
+    value: &ScalarValue,
+) -> Result<Expression> {
+    to_substrait_literal_expr(producer, value)
+}
+
+pub fn from_alias(
+    producer: &mut impl SubstraitProducer,
+    alias: &Alias,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    producer.consume_expr(alias.expr.as_ref(), schema)
+}
+
+pub fn from_window_function(
+    producer: &mut impl SubstraitProducer,
+    window_fn: &WindowFunction,
+    schema: &DFSchemaRef,
+) -> Result<Expression> {
+    let WindowFunction {
+        fun,
+        args,
+        partition_by,
+        order_by,
+        window_frame,
+        null_treatment: _,
+    } = window_fn;
+    // function reference
+    let function_anchor = producer.register_function(fun.to_string());
+    // arguments
+    let mut arguments: Vec<FunctionArgument> = vec![];
+    for arg in args {
+        arguments.push(FunctionArgument {
+            arg_type: Some(ArgType::Value(to_substrait_rex(producer, arg, 
schema)?)),

Review Comment:
   Good catch, everything should go through `producer.handle_expr`. I missed 
this one, and a couple of other ones. I've updated all of them.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org
For additional commands, e-mail: github-h...@datafusion.apache.org

Reply via email to