vbarua commented on code in PR #13931: URL: https://github.com/apache/datafusion/pull/13931#discussion_r1899710875
########## datafusion/substrait/src/logical_plan/producer.rs: ########## @@ -101,14 +105,330 @@ use substrait::{ version, }; -use super::state::SubstraitPlanningState; +/// This trait is used to produce Substrait plans, converting them from DataFusion Logical Plans. +/// It can be implemented by users to allow for custom handling of relations, expressions, etc. +/// +/// Combined with the [crate::logical_plan::consumer::SubstraitConsumer] this allows for fully +/// customizable Substrait serde. +/// +/// # Example Usage +/// +/// ``` +/// # use std::sync::Arc; +/// # use substrait::proto::{Expression, Rel}; +/// # use substrait::proto::rel::RelType; +/// # use datafusion::common::DFSchemaRef; +/// # use datafusion::error::Result; +/// # use datafusion::execution::SessionState; +/// # use datafusion::logical_expr::{Between, Extension, Projection}; +/// # use datafusion_substrait::extensions::Extensions; +/// # use datafusion_substrait::logical_plan::producer::{from_projection, SubstraitProducer}; +/// +/// struct CustomSubstraitProducer { +/// extensions: Extensions, +/// state: Arc<SessionState>, +/// } +/// +/// impl SubstraitProducer for CustomSubstraitProducer { +/// +/// fn register_function(&mut self, signature: String) -> u32 { +/// self.extensions.register_function(signature) +/// } +/// +/// fn get_extensions(self) -> Extensions { +/// self.extensions +/// } +/// +/// // You can set additional metadata on the Rels you produce +/// fn consume_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> { +/// let mut rel = from_projection(self, plan)?; +/// match rel.rel_type { +/// Some(RelType::Project(mut project)) => { +/// let mut project = project.clone(); +/// // set common metadata or advanced extension +/// project.common = None; +/// project.advanced_extension = None; +/// Ok(Box::new(Rel { +/// rel_type: Some(RelType::Project(project)), +/// })) +/// } +/// rel_type => Ok(Box::new(Rel { rel_type })), +/// } +/// } +/// +/// // You can tweak how 
you convert expressions for your target system +/// fn consume_between(&mut self, between: &Between, schema: &DFSchemaRef) -> Result<Expression> { +/// // add your own encoding for Between +/// todo!() +/// } +/// +/// // You can fully control how you convert UserDefinedLogicalNodes into Substrait +/// fn consume_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> { +/// // implement your own serializer into Substrait +/// todo!() +/// } +/// } +/// ``` +pub trait SubstraitProducer: Send + Sync + Sized { + /// Within a Substrait plan, functions are referenced using function anchors that are stored at + /// the top level of the [Plan] within + /// [ExtensionFunction](substrait::proto::extensions::simple_extension_declaration::ExtensionFunction) + /// messages. + /// + /// When given a function signature, this method should return the existing anchor for it if + /// there is one. Otherwise, it should generate a new anchor. + fn register_function(&mut self, signature: String) -> u32; + + /// Consume the producer to generate the [Extensions] for the Substrait plan based on the + /// functions that have been registered + fn get_extensions(self) -> Extensions; + + // Logical Plan Methods + // There is one method per LogicalPlan to allow for easy overriding of producer behaviour. + // These methods have default implementations calling the common handler code, to allow for users + // to re-use common handling logic. 
+ + fn consume_plan(&mut self, plan: &LogicalPlan) -> Result<Box<Rel>> { + to_substrait_rel(self, plan) + } + + fn consume_projection(&mut self, plan: &Projection) -> Result<Box<Rel>> { + from_projection(self, plan) + } + + fn consume_filter(&mut self, plan: &Filter) -> Result<Box<Rel>> { + from_filter(self, plan) + } + + fn consume_window(&mut self, plan: &Window) -> Result<Box<Rel>> { + from_window(self, plan) + } + + fn consume_aggregate(&mut self, plan: &Aggregate) -> Result<Box<Rel>> { + from_aggregate(self, plan) + } + + fn consume_sort(&mut self, plan: &Sort) -> Result<Box<Rel>> { + from_sort(self, plan) + } + + fn consume_join(&mut self, plan: &Join) -> Result<Box<Rel>> { + from_join(self, plan) + } + + fn consume_repartition(&mut self, plan: &Repartition) -> Result<Box<Rel>> { + from_repartition(self, plan) + } + + fn consume_union(&mut self, plan: &Union) -> Result<Box<Rel>> { + from_union(self, plan) + } + + fn consume_table_scan(&mut self, plan: &TableScan) -> Result<Box<Rel>> { + from_table_scan(self, plan) + } + + fn consume_empty_relation(&mut self, plan: &EmptyRelation) -> Result<Box<Rel>> { + from_empty_relation(plan) + } + + fn consume_subquery_alias(&mut self, plan: &SubqueryAlias) -> Result<Box<Rel>> { + from_subquery_alias(self, plan) + } + + fn consume_limit(&mut self, plan: &Limit) -> Result<Box<Rel>> { + from_limit(self, plan) + } + + fn consume_values(&mut self, plan: &Values) -> Result<Box<Rel>> { + from_values(self, plan) + } + + fn consume_distinct(&mut self, plan: &Distinct) -> Result<Box<Rel>> { + from_distinct(self, plan) + } + + fn consume_extension(&mut self, _plan: &Extension) -> Result<Box<Rel>> { + substrait_err!("Specify handling for LogicalPlan::Extension by implementing the SubstraitProducer trait") + } + + // Expression Methods + // There is one method per DataFusion Expr to allow for easy overriding of producer behaviour + // These methods have default implementations calling the common handler code, to allow for users + // 
to re-use common handling logic. + + fn consume_expr(&mut self, expr: &Expr, schema: &DFSchemaRef) -> Result<Expression> { + to_substrait_rex(self, expr, schema) + } + + fn consume_alias( + &mut self, + alias: &Alias, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_alias(self, alias, schema) + } + + fn consume_column( + &mut self, + column: &Column, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_column(column, schema) + } + + fn consume_literal(&mut self, value: &ScalarValue) -> Result<Expression> { + from_literal(self, value) + } + + fn consume_binary_expr( + &mut self, + expr: &BinaryExpr, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_binary_expr(self, expr, schema) + } + + fn consume_like(&mut self, like: &Like, schema: &DFSchemaRef) -> Result<Expression> { + from_like(self, like, schema) + } + + /// For handling Not, IsNotNull, IsNull, IsTrue, IsFalse, IsUnknown, IsNotTrue, IsNotFalse, IsNotUnknown, Negative + fn consume_unary_expr( + &mut self, + expr: &Expr, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_unary_expr(self, expr, schema) + } + + fn consume_between( + &mut self, + between: &Between, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_between(self, between, schema) + } + + fn consume_case(&mut self, case: &Case, schema: &DFSchemaRef) -> Result<Expression> { + from_case(self, case, schema) + } + + fn consume_cast(&mut self, cast: &Cast, schema: &DFSchemaRef) -> Result<Expression> { + from_cast(self, cast, schema) + } + + fn consume_try_cast( + &mut self, + cast: &TryCast, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_try_cast(self, cast, schema) + } + + fn consume_scalar_function( + &mut self, + scalar_fn: &expr::ScalarFunction, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_scalar_function(self, scalar_fn, schema) + } + + fn consume_aggregate_function( + &mut self, + agg_fn: &expr::AggregateFunction, + schema: &DFSchemaRef, + ) -> Result<Measure> { + 
from_aggregate_function(self, agg_fn, schema) + } + + fn consume_window_function( + &mut self, + window_fn: &WindowFunction, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_window_function(self, window_fn, schema) + } + + fn consume_in_list( + &mut self, + in_list: &InList, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_in_list(self, in_list, schema) + } + + fn consume_in_subquery( + &mut self, + in_subquery: &InSubquery, + schema: &DFSchemaRef, + ) -> Result<Expression> { + from_in_subquery(self, in_subquery, schema) + } +} + +struct DefaultSubstraitProducer<'a> { + extensions: Extensions, + state: &'a SessionState, +} + +impl<'a> DefaultSubstraitProducer<'a> { + pub fn new(state: &'a SessionState) -> Self { + DefaultSubstraitProducer { + extensions: Extensions::default(), + state, + } + } +} + +impl SubstraitProducer for DefaultSubstraitProducer<'_> { + fn register_function(&mut self, fn_name: String) -> u32 { + self.extensions.register_function(fn_name) + } + + fn get_extensions(self) -> Extensions { + self.extensions + } + + fn consume_extension(&mut self, plan: &Extension) -> Result<Box<Rel>> { + let extension_bytes = self + .state Review Comment: Actually, this makes me want to only store the SerializerRegistry. If we need the state (or other data) we can always add it in later. Part of the reason to switch to the producer trait is that we can modify the internal details of the DefaultSubstraitProducer without it impacting users. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment.
To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org