duongcongtoai commented on code in PR #16016: URL: https://github.com/apache/datafusion/pull/16016#discussion_r2110867569
########## datafusion/optimizer/src/decorrelate_general.rs: ########## @@ -0,0 +1,1137 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +//! [`DependentJoinRewriter`] converts correlated subqueries to `DependentJoin` + +use std::ops::Deref; +use std::sync::Arc; + +use crate::{ApplyOrder, OptimizerConfig, OptimizerRule}; + +use arrow::datatypes::DataType; +use datafusion_common::alias::AliasGenerator; +use datafusion_common::tree_node::{ + Transformed, TreeNode, TreeNodeRecursion, TreeNodeRewriter, +}; +use datafusion_common::{internal_err, Column, HashMap, Result}; +use datafusion_expr::{col, lit, Expr, LogicalPlan, LogicalPlanBuilder}; + +use indexmap::IndexMap; +use itertools::Itertools; + +pub struct DependentJoinRewriter { + // each logical plan traversal will assign it an integer id + current_id: usize, + subquery_depth: usize, + // each newly visited `LogicalPlan` is inserted inside this map for tracking + nodes: IndexMap<usize, Node>, + // all the node ids from root to the current node + // this is mutated during traversal + stack: Vec<usize>, + // track for each column, the nodes/logical plans that reference it within the tree + all_outer_ref_columns: IndexMap<Column, Vec<ColumnAccess>>, + alias_generator: 
Arc<AliasGenerator>, +} + +#[derive(Debug, Hash, PartialEq, PartialOrd, Eq, Clone)] +struct ColumnAccess { + // node ids from root to the node that is referencing the column + stack: Vec<usize>, + // the node referencing the column + node_id: usize, + col: Column, + data_type: DataType, + subquery_depth: usize, +} + +impl DependentJoinRewriter { + // lowest common ancestor from stack + // given a tree of + // n1 + // | + // n2 filter where outer.column = exists(subquery) + // ---------------------- + // | \ + // | n5: subquery + // | | + // n3 scan table outer n6 filter outer.column=inner.column + // | + // n7 scan table inner + // this function is called with 2 args a:[1,2,3] and [1,2,5,6,7] + // it then returns the id of the dependent join node (2) + // and the id of the subquery node (5) + fn dependent_join_and_subquery_node_ids( + stack_with_table_provider: &[usize], + stack_with_subquery: &[usize], + ) -> (usize, usize) { + let mut lowest_common_ancestor = 0; + let mut subquery_node_id = 0; + + let min_len = stack_with_table_provider + .len() + .min(stack_with_subquery.len()); + + for i in 0..min_len { + let right_id = stack_with_subquery[i]; + let left_id = stack_with_table_provider[i]; + + if right_id == left_id { + // common parent + lowest_common_ancestor = right_id; + subquery_node_id = stack_with_subquery[i + 1]; + } else { + break; + } + } + + (lowest_common_ancestor, subquery_node_id) + } + + // because the column providers are visited after column-accessor + // (function visit_with_subqueries always visit the subquery before visiting the other children) + // we can always infer the LCA inside this function, by getting the deepest common parent + fn conclude_lowest_dependent_join_node_if_any( + &mut self, + child_id: usize, + col: &Column, + ) { + if let Some(accesses) = self.all_outer_ref_columns.get(col) { + for access in accesses.iter() { + let mut cur_stack = self.stack.clone(); + + cur_stack.push(child_id); + let (dependent_join_node_id, 
subquery_node_id) = + Self::dependent_join_and_subquery_node_ids(&cur_stack, &access.stack); + let node = self.nodes.get_mut(&dependent_join_node_id).unwrap(); + let accesses = node + .columns_accesses_by_subquery_id + .entry(subquery_node_id) + .or_default(); + accesses.push(ColumnAccess { + col: col.clone(), + node_id: access.node_id, + stack: access.stack.clone(), + data_type: access.data_type.clone(), + subquery_depth: access.subquery_depth, + }); + } + } + } + + fn mark_outer_column_access( + &mut self, + child_id: usize, + data_type: &DataType, + col: &Column, + ) { + // iter from bottom to top, the goal is to mark the dependent node + // the current child's access + self.all_outer_ref_columns + .entry(col.clone()) + .or_default() + .push(ColumnAccess { + stack: self.stack.clone(), + node_id: child_id, + col: col.clone(), + data_type: data_type.clone(), + subquery_depth: self.subquery_depth, + }); + } + fn rewrite_subqueries_into_dependent_joins( + &mut self, + plan: LogicalPlan, + ) -> Result<Transformed<LogicalPlan>> { + plan.rewrite_with_subqueries(self) + } +} + +impl DependentJoinRewriter { + fn new(alias_generator: Arc<AliasGenerator>) -> Self { + DependentJoinRewriter { + alias_generator, + current_id: 0, + nodes: IndexMap::new(), + stack: vec![], + all_outer_ref_columns: IndexMap::new(), + subquery_depth: 0, + } + } +} + +#[derive(Debug, Clone)] +struct Node { + plan: LogicalPlan, + + // This field is only meaningful if the node is dependent join node. + // It tracks which descendent nodes still accessing the outer columns provided by its + // left child + // The key of this map is node_id of the children subqueries. 
+ // The insertion order matters here, and thus we use IndexMap + columns_accesses_by_subquery_id: IndexMap<usize, Vec<ColumnAccess>>, + + is_dependent_join_node: bool, + + // note that for dependent join nodes, there can be more than 1 + // subquery children at a time, but always 1 outer-column-providing-child + // which is at the last element + subquery_type: SubqueryType, +} +#[derive(Debug, Clone)] +enum SubqueryType { + None, + In, + Exists, + Scalar, + LateralJoin, +} + +impl SubqueryType { + fn prefix(&self) -> String { + match self { + SubqueryType::None => "", + SubqueryType::In => "__in_sq", + SubqueryType::Exists => "__exists_sq", + SubqueryType::Scalar => "__scalar_sq", + SubqueryType::LateralJoin => "__lateral_sq", + } + .to_string() + } +} +fn unwrap_subquery_input_from_expr(expr: &Expr) -> Arc<LogicalPlan> { + match expr { + Expr::ScalarSubquery(sq) => Arc::clone(&sq.subquery), + Expr::Exists(exists) => Arc::clone(&exists.subquery.subquery), + Expr::InSubquery(in_sq) => Arc::clone(&in_sq.subquery.subquery), + _ => unreachable!(), + } +} + +// if current expr contains any subquery expr +// this function must not be recursive +fn contains_subquery(expr: &Expr) -> bool { + expr.exists(|expr| { + Ok(matches!( + expr, + Expr::ScalarSubquery(_) | Expr::InSubquery(_) | Expr::Exists(_) + )) + }) + .expect("Inner is always Ok") +} + +/// The rewriting happens up-down, where the parent nodes are downward-visited +/// before its children (subqueries children are visited first). 
+/// This behavior guarantees that, at any moment, if we observe a `LogicalPlan` +/// that provides the data for columns, we can assume that all subqueries that reference +/// its data were already visited, and we can conclude the information of the `DependentJoin` +/// needed for the decorrelation: +/// - The subquery expr +/// - The correlated columns on the LHS referenced from the RHS (and its recursing subqueries if any) +/// +/// If the original node contains multiple subqueries at the same time, +/// two nested `DependentJoin` plans are generated (with equal depth). +/// +/// For illustration, given this query +/// ```sql +/// SELECT ID FROM T1 WHERE EXISTS(SELECT * FROM T2 WHERE T2.ID=T1.ID) OR EXISTS(SELECT * FROM T2 WHERE T2.VALUE=T1.ID); +/// ``` +/// +/// The traversal happens in the following sequence +/// +/// ```text +/// ↓1 +/// ↑12 +/// ┌────────────┐ +/// │ FILTER │<--- DependentJoin rewrite +/// │ │ happens here +/// └────┬────┬──┘ +/// ↓2 ↓6 ↓10 +/// ↑5 ↑9 ↑11 <---Here we already have enough information +/// │ | | of which node is accessing which column +/// │ | | provided by "Table Scan t1" node +/// │ | | +/// ┌─────┘ │ └─────┐ +/// │ │ │ +/// ┌───▼───┐ ┌──▼───┐ ┌───▼───────┐ +/// │SUBQ1 │ │SUBQ2 │ │TABLE SCAN │ +/// └──┬────┘ └──┬───┘ │ t1 │ +/// ↓3 ↓7 └───────────┘ +/// ↑4 ↑8 +/// ┌──▼────┐ ┌──▼────┐ +/// │SCAN t2│ │SCAN t2│ +/// └───────┘ └───────┘ +/// ``` +impl TreeNodeRewriter for DependentJoinRewriter { + type Node = LogicalPlan; + + fn f_down(&mut self, node: LogicalPlan) -> Result<Transformed<LogicalPlan>> { + let new_id = self.current_id; + self.current_id += 1; + let mut is_dependent_join_node = false; + let mut subquery_type = SubqueryType::None; + // for each node, find which column it is accessing, which column it is providing + // Set of columns the current node accesses + match &node { + LogicalPlan::Filter(f) => { + if contains_subquery(&f.predicate) { + is_dependent_join_node = true; + } + + f.predicate + .apply(|expr| { + 
if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + }) + .expect("traversal is infallible"); + } + // TODO: maybe there are more logical plans that provide columns + // aside from TableScan + LogicalPlan::TableScan(tbl_scan) => { + tbl_scan.projected_schema.columns().iter().for_each(|col| { + self.conclude_lowest_dependent_join_node_if_any(new_id, col); + }); + } + // Similar to TableScan, this node may provide column names which + // are referenced inside some subqueries + LogicalPlan::SubqueryAlias(alias) => { + alias.schema.columns().iter().for_each(|col| { + self.conclude_lowest_dependent_join_node_if_any(new_id, col); + }); + } + // TODO: this is untested + LogicalPlan::Projection(proj) => { + for expr in &proj.expr { + if contains_subquery(expr) { + is_dependent_join_node = true; + break; + } + expr.apply(|expr| { + if let Expr::OuterReferenceColumn(data_type, col) = expr { + self.mark_outer_column_access(new_id, data_type, col); + } + Ok(TreeNodeRecursion::Continue) + })?; + } + } Review Comment: Right, I didn't test this logic, so I didn't catch this issue :+1: -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org