vbarua commented on code in PR #12462:
URL: https://github.com/apache/datafusion/pull/12462#discussion_r1759608677
##########
datafusion/substrait/tests/cases/consumer_integration.rs:
##########
@@ -24,569 +24,435 @@
#[cfg(test)]
mod tests {
+ use crate::utils::test::add_plan_schemas_to_ctx;
use datafusion::common::Result;
- use datafusion::execution::options::CsvReadOptions;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;
- async fn create_context(files: Vec<(&str, &str)>) ->
Result<SessionContext> {
- let ctx = SessionContext::new();
- for (table_name, file_path) in files {
- ctx.register_csv(table_name, file_path, CsvReadOptions::default())
- .await?;
- }
- Ok(ctx)
- }
- #[tokio::test]
- async fn tpch_test_1() -> Result<()> {
- let ctx = create_context(vec![(
- "FILENAME_PLACEHOLDER_0",
- "tests/testdata/tpch/lineitem.csv",
- )])
- .await?;
- let path = "tests/testdata/tpch_substrait_plans/query_1.json";
Review Comment:
I noticed a lot of duplication so I encapsulated the test setup code in
`tpch_plan_to_string`.
##########
datafusion/substrait/tests/utils.rs:
##########
@@ -37,13 +46,29 @@ pub mod test {
.expect("failed to parse json")
}
- pub fn add_plan_schemas_to_ctx(ctx: SessionContext, plan: &Plan) ->
SessionContext {
- let schemas = TestSchemaCollector::collect_schemas(plan);
- for (table_reference, table) in schemas {
- ctx.register_table(table_reference, table)
- .expect("Failed to register table");
+ pub fn add_plan_schemas_to_ctx(
+ ctx: SessionContext,
+ plan: &Plan,
+ ) -> Result<SessionContext> {
+ let schemas = TestSchemaCollector::collect_schemas(plan)?;
+ let mut schema_map: HashMap<TableReference, Arc<dyn TableProvider>> =
+ HashMap::new();
+ for (table_reference, table) in schemas.into_iter() {
+ let schema = table.schema();
+ if let Some(existing_table) =
+ schema_map.insert(table_reference.clone(), table)
+ {
+ if existing_table.schema() != schema {
+ return substrait_err!(
+ "Substrait plan contained the same table {} with
different schemas.\nSchema 1: {}\nSchema 2: {}",
+ table_reference, existing_table.schema(), schema);
Review Comment:
I'm checking for the case of the same table having different schemas because
I've been bitten by it before due to schema pruning shenanigans and it took a
while to figure out what was happening.
##########
datafusion/substrait/tests/utils.rs:
##########
@@ -97,90 +127,362 @@ pub mod test {
},
};
- let substrait_schema = read
- .base_schema
- .as_ref()
- .expect("No base schema found for NamedTable");
+ let substrait_schema =
+ read.base_schema.as_ref().ok_or(substrait_datafusion_err!(
+ "No base schema found for NamedTable: {}",
+ table_reference
+ ))?;
let empty_extensions = Extensions {
functions: Default::default(),
types: Default::default(),
type_variations: Default::default(),
};
let df_schema =
- from_substrait_named_struct(substrait_schema,
&empty_extensions)
- .expect(
- "Unable to generate DataFusion schema from Substrait
NamedStruct",
- )
+ from_substrait_named_struct(substrait_schema,
&empty_extensions)?
.replace_qualifier(table_reference.clone());
let table = EmptyTable::new(df_schema.inner().clone());
self.schemas.push((table_reference, Arc::new(table)));
+ Ok(())
}
- fn collect_schemas_from_rel(&mut self, rel: &Rel) {
- match rel.rel_type.as_ref().unwrap() {
- RelType::Read(r) => match r.read_type.as_ref().unwrap() {
- // Virtual Tables do not contribute to the schema
- ReadType::VirtualTable(_) => (),
- ReadType::LocalFiles(_) => todo!(),
- ReadType::NamedTable(nt) => self.collect_named_table(r,
nt),
- ReadType::ExtensionTable(_) => todo!(),
- },
- RelType::Filter(f) => self.apply(f.input.as_ref().map(|b|
b.as_ref())),
- RelType::Fetch(f) => self.apply(f.input.as_ref().map(|b|
b.as_ref())),
- RelType::Aggregate(a) => self.apply(a.input.as_ref().map(|b|
b.as_ref())),
- RelType::Sort(s) => self.apply(s.input.as_ref().map(|b|
b.as_ref())),
+ fn collect_schemas_from_rel(&mut self, rel: &Rel) -> Result<()> {
Review Comment:
meta: writing traversals like this is quite painful. It would be good to get
some visitation machinery added to substrait-rust.
##########
datafusion/substrait/tests/cases/consumer_integration.rs:
##########
@@ -24,569 +24,435 @@
#[cfg(test)]
mod tests {
+ use crate::utils::test::add_plan_schemas_to_ctx;
use datafusion::common::Result;
- use datafusion::execution::options::CsvReadOptions;
use datafusion::prelude::SessionContext;
use datafusion_substrait::logical_plan::consumer::from_substrait_plan;
use std::fs::File;
use std::io::BufReader;
use substrait::proto::Plan;
- async fn create_context(files: Vec<(&str, &str)>) ->
Result<SessionContext> {
- let ctx = SessionContext::new();
- for (table_name, file_path) in files {
- ctx.register_csv(table_name, file_path, CsvReadOptions::default())
- .await?;
- }
- Ok(ctx)
- }
- #[tokio::test]
- async fn tpch_test_1() -> Result<()> {
- let ctx = create_context(vec![(
- "FILENAME_PLACEHOLDER_0",
- "tests/testdata/tpch/lineitem.csv",
- )])
- .await?;
- let path = "tests/testdata/tpch_substrait_plans/query_1.json";
+ async fn tpch_plan_to_string(query_id: i32) -> Result<String> {
+ let path =
+
format!("tests/testdata/tpch_substrait_plans/query_{query_id:02}_plan.json");
let proto = serde_json::from_reader::<_, Plan>(BufReader::new(
File::open(path).expect("file not found"),
))
.expect("failed to parse json");
+ let ctx = add_plan_schemas_to_ctx(SessionContext::new(), &proto)?;
Review Comment:
Instead of reading the CSV input files to generate schemas, we can generate
them directly from the Substrait plans.
##########
datafusion/substrait/tests/utils.rs:
##########
@@ -97,90 +127,362 @@ pub mod test {
},
};
- let substrait_schema = read
- .base_schema
- .as_ref()
- .expect("No base schema found for NamedTable");
+ let substrait_schema =
+ read.base_schema.as_ref().ok_or(substrait_datafusion_err!(
+ "No base schema found for NamedTable: {}",
+ table_reference
+ ))?;
let empty_extensions = Extensions {
functions: Default::default(),
types: Default::default(),
type_variations: Default::default(),
};
let df_schema =
- from_substrait_named_struct(substrait_schema,
&empty_extensions)
- .expect(
- "Unable to generate DataFusion schema from Substrait
NamedStruct",
- )
+ from_substrait_named_struct(substrait_schema,
&empty_extensions)?
.replace_qualifier(table_reference.clone());
let table = EmptyTable::new(df_schema.inner().clone());
self.schemas.push((table_reference, Arc::new(table)));
+ Ok(())
}
- fn collect_schemas_from_rel(&mut self, rel: &Rel) {
- match rel.rel_type.as_ref().unwrap() {
- RelType::Read(r) => match r.read_type.as_ref().unwrap() {
- // Virtual Tables do not contribute to the schema
- ReadType::VirtualTable(_) => (),
- ReadType::LocalFiles(_) => todo!(),
- ReadType::NamedTable(nt) => self.collect_named_table(r,
nt),
- ReadType::ExtensionTable(_) => todo!(),
- },
- RelType::Filter(f) => self.apply(f.input.as_ref().map(|b|
b.as_ref())),
- RelType::Fetch(f) => self.apply(f.input.as_ref().map(|b|
b.as_ref())),
- RelType::Aggregate(a) => self.apply(a.input.as_ref().map(|b|
b.as_ref())),
- RelType::Sort(s) => self.apply(s.input.as_ref().map(|b|
b.as_ref())),
+ fn collect_schemas_from_rel(&mut self, rel: &Rel) -> Result<()> {
+ let rel_type = rel
+ .rel_type
+ .as_ref()
+ .ok_or(substrait_datafusion_err!("RelRoot must set input"))?;
+ match rel_type {
+ RelType::Read(r) => {
+ let read_type = r
+ .read_type
+ .as_ref()
+ .ok_or(substrait_datafusion_err!("read_type not set on
Read"))?;
+ match read_type {
+ // Virtual Tables do not contribute to the schema
+ ReadType::VirtualTable(_) => (),
+ ReadType::LocalFiles(_) => todo!(),
+ ReadType::NamedTable(nt) =>
self.collect_named_table(r, nt)?,
+ ReadType::ExtensionTable(_) => todo!(),
+ }
+ if let Some(expr) = r.filter.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ };
+ if let Some(expr) = r.best_effort_filter.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ };
+ }
+ RelType::Filter(f) => {
+ self.apply(f.input.as_ref().map(|b| b.as_ref()))?;
+ for expr in f.condition.iter() {
+ self.collect_schemas_from_expr(expr)?;
+ }
+ }
+ RelType::Fetch(f) => {
+ self.apply(f.input.as_ref().map(|b| b.as_ref()))?;
+ }
+ RelType::Aggregate(a) => {
+ self.apply(a.input.as_ref().map(|b| b.as_ref()))?;
+ for grouping in a.groupings.iter() {
+ for expr in grouping.grouping_expressions.iter() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ for measure in a.measures.iter() {
+ if let Some(agg_fn) = measure.measure.as_ref() {
+ for arg in agg_fn.arguments.iter() {
+ self.collect_schemas_from_arg(arg)?
+ }
+ for sort in agg_fn.sorts.iter() {
+ if let Some(expr) = sort.expr.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ }
+ if let Some(expr) = measure.filter.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ }
+ RelType::Sort(s) => {
+ self.apply(s.input.as_ref().map(|b| b.as_ref()))?;
+ for sort_field in s.sorts.iter() {
+ if let Some(expr) = sort_field.expr.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ }
RelType::Join(j) => {
- self.apply(j.left.as_ref().map(|b| b.as_ref()));
- self.apply(j.right.as_ref().map(|b| b.as_ref()));
+ self.apply(j.left.as_ref().map(|b| b.as_ref()))?;
+ self.apply(j.right.as_ref().map(|b| b.as_ref()))?;
+ if let Some(expr) = j.expression.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ }
+ if let Some(expr) = j.post_join_filter.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ }
+ }
+ RelType::Project(p) => {
+ self.apply(p.input.as_ref().map(|b| b.as_ref()))?
}
- RelType::Project(p) => self.apply(p.input.as_ref().map(|b|
b.as_ref())),
RelType::Set(s) => {
for input in s.inputs.iter() {
- self.collect_schemas_from_rel(input);
+ self.collect_schemas_from_rel(input)?;
}
}
RelType::ExtensionSingle(s) => {
- self.apply(s.input.as_ref().map(|b| b.as_ref()))
+ self.apply(s.input.as_ref().map(|b| b.as_ref()))?
}
+
RelType::ExtensionMulti(m) => {
for input in m.inputs.iter() {
- self.collect_schemas_from_rel(input)
+ self.collect_schemas_from_rel(input)?
}
}
RelType::ExtensionLeaf(_) => {}
RelType::Cross(c) => {
- self.apply(c.left.as_ref().map(|b| b.as_ref()));
- self.apply(c.right.as_ref().map(|b| b.as_ref()));
+ self.apply(c.left.as_ref().map(|b| b.as_ref()))?;
+ self.apply(c.right.as_ref().map(|b| b.as_ref()))?;
}
// RelType::Reference(_) => {}
// RelType::Write(_) => {}
// RelType::Ddl(_) => {}
RelType::HashJoin(j) => {
- self.apply(j.left.as_ref().map(|b| b.as_ref()));
- self.apply(j.right.as_ref().map(|b| b.as_ref()));
+ self.apply(j.left.as_ref().map(|b| b.as_ref()))?;
+ self.apply(j.right.as_ref().map(|b| b.as_ref()))?;
+ if let Some(expr) = j.post_join_filter.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ }
}
RelType::MergeJoin(j) => {
- self.apply(j.left.as_ref().map(|b| b.as_ref()));
- self.apply(j.right.as_ref().map(|b| b.as_ref()));
+ self.apply(j.left.as_ref().map(|b| b.as_ref()))?;
+ self.apply(j.right.as_ref().map(|b| b.as_ref()))?;
+ if let Some(expr) = j.post_join_filter.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ }
}
RelType::NestedLoopJoin(j) => {
- self.apply(j.left.as_ref().map(|b| b.as_ref()));
- self.apply(j.right.as_ref().map(|b| b.as_ref()));
+ self.apply(j.left.as_ref().map(|b| b.as_ref()))?;
+ self.apply(j.right.as_ref().map(|b| b.as_ref()))?;
+ if let Some(expr) = j.expression.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ }
+ }
+ RelType::Window(w) => {
+ self.apply(w.input.as_ref().map(|b| b.as_ref()))?;
+ for wf in w.window_functions.iter() {
+ for arg in wf.arguments.iter() {
+ self.collect_schemas_from_arg(arg)?;
+ }
+ }
+ for expr in w.partition_expressions.iter() {
+ self.collect_schemas_from_expr(expr)?;
+ }
+ for sort_field in w.sorts.iter() {
+ if let Some(expr) = sort_field.expr.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ }
+ }
+ }
+ RelType::Exchange(e) => {
+ self.apply(e.input.as_ref().map(|b| b.as_ref()))?;
+ let exchange_kind = e.exchange_kind.as_ref().ok_or(
+ substrait_datafusion_err!("Exhange must set
exchange_kind"),
+ )?;
+ match exchange_kind {
+ ExchangeKind::ScatterByFields(_) => {}
+ ExchangeKind::SingleTarget(st) => {
+ if let Some(expr) = st.expression.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ ExchangeKind::MultiTarget(mt) => {
+ if let Some(expr) = mt.expression.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ ExchangeKind::RoundRobin(_) => {}
+ ExchangeKind::Broadcast(_) => {}
+ }
+ }
+ RelType::Expand(e) => {
+ self.apply(e.input.as_ref().map(|b| b.as_ref()))?;
+ for expand_field in e.fields.iter() {
+ let expand_type =
expand_field.field_type.as_ref().ok_or(
+ substrait_datafusion_err!("ExpandField must set
field_type"),
+ )?;
+ match expand_type {
+ FieldType::SwitchingField(sf) => {
+ for expr in sf.duplicates.iter() {
+ self.collect_schemas_from_expr(expr)?;
+ }
+ }
+ FieldType::ConsistentField(expr) => {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ }
}
- RelType::Window(w) => self.apply(w.input.as_ref().map(|b|
b.as_ref())),
- RelType::Exchange(e) => self.apply(e.input.as_ref().map(|b|
b.as_ref())),
- RelType::Expand(e) => self.apply(e.input.as_ref().map(|b|
b.as_ref())),
_ => todo!(),
}
+ Ok(())
}
- fn apply(&mut self, input: Option<&Rel>) {
+ fn apply(&mut self, input: Option<&Rel>) -> Result<()> {
match input {
- None => {}
+ None => Ok(()),
Some(rel) => self.collect_schemas_from_rel(rel),
}
}
+
+ fn collect_schemas_from_expr(&mut self, e: &Expression) -> Result<()> {
+ let rex_type = e.rex_type.as_ref().ok_or(substrait_datafusion_err!(
+ "rex_type must be set on Expression"
+ ))?;
+ match rex_type {
+ RexType::Literal(_) => {}
+ RexType::Selection(_) => {}
+ RexType::ScalarFunction(sf) => {
+ for arg in sf.arguments.iter() {
+ self.collect_schemas_from_arg(arg)?
+ }
+ }
+ RexType::WindowFunction(wf) => {
+ for arg in wf.arguments.iter() {
+ self.collect_schemas_from_arg(arg)?
+ }
+ for sort_field in wf.sorts.iter() {
+ if let Some(expr) = sort_field.expr.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ for expr in wf.partitions.iter() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ RexType::IfThen(it) => {
+ for if_clause in it.ifs.iter() {
+ if let Some(expr) = if_clause.r#if.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ };
+ if let Some(expr) = if_clause.then.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ };
+ }
+ if let Some(expr) = it.r#else.as_ref() {
+ self.collect_schemas_from_expr(expr)?;
+ };
+ }
+ RexType::SwitchExpression(se) => {
+ if let Some(expr) = se.r#match.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ for if_value in se.ifs.iter() {
+ if let Some(expr) = if_value.then.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ }
+ RexType::SingularOrList(sol) => {
+ if let Some(expr) = sol.value.as_ref() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ for expr in sol.options.iter() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ RexType::MultiOrList(mol) => {
+ for expr in mol.value.iter() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ for record in mol.options.iter() {
+ for expr in record.fields.iter() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ }
+ RexType::Cast(c) => {
+ for expr in c.input.iter() {
+ self.collect_schemas_from_expr(expr)?
+ }
+ }
+ RexType::Subquery(subquery) => {
+ let subquery_type = subquery
+ .subquery_type
+ .as_ref()
+ .ok_or(substrait_datafusion_err!("subquery_type must
be set"))?;
+ match subquery_type {
Review Comment:
A number of the TPCH plans include Subquery expressions. Without this
traversal, the schemas generated from the plans were missing tables.
##########
datafusion/substrait/tests/utils.rs:
##########
@@ -57,28 +82,33 @@ pub mod test {
}
}
- fn collect_schemas(plan: &Plan) -> Vec<(TableReference, Arc<dyn
TableProvider>)> {
+ fn collect_schemas(
+ plan: &Plan,
+ ) -> Result<Vec<(TableReference, Arc<dyn TableProvider>)>> {
Review Comment:
I switched this to be Result based because eventually I would like this to
be part of the core library. Collecting schemas like this is super useful.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]