berkaysynnada commented on code in PR #13560: URL: https://github.com/apache/datafusion/pull/13560#discussion_r1868812172
########## datafusion/core/src/physical_optimizer/sort_pushdown.rs: ########## @@ -606,6 +610,118 @@ fn handle_custom_pushdown( } } +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result<Option<Vec<Option<LexRequirement>>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() + || plan.children().is_empty() Review Comment: I guess you don't need to add this check; a hash join cannot exist without two children. ########## datafusion/core/src/physical_optimizer/sort_pushdown.rs: ########## @@ -606,6 +610,118 @@ fn handle_custom_pushdown( } } +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result<Option<Vec<Option<LexRequirement>>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() + || plan.children().is_empty() + || !matches!( + plan.join_type(), + JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti + ) + { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let all_indices: HashSet<usize> = parent_required + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .iter() + .map(|col| col.index()) + .collect::<HashSet<_>>() + }) + .collect(); Review Comment: ```suggestion let all_indices = parent_required .iter() .flat_map(|order| { collect_columns(&order.expr) .into_iter() .map(|col| col.index()) }) .collect::<HashSet<_>>(); ``` ########## 
datafusion/core/src/physical_optimizer/sort_pushdown.rs: ########## @@ -606,6 +610,118 @@ fn handle_custom_pushdown( } } +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result<Option<Vec<Option<LexRequirement>>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() + || plan.children().is_empty() + || !matches!( + plan.join_type(), Review Comment: A more idiomatic way of doing this is to use `maintains_input_order()` of HashJoin. ########## datafusion/core/src/physical_optimizer/sort_pushdown.rs: ########## @@ -606,6 +610,118 @@ fn handle_custom_pushdown( } } +// For hash join we only maintain the input order for the right child +// for join type: Inner, Right, RightSemi, RightAnti +fn handle_hash_join( + plan: &HashJoinExec, + parent_required: &LexRequirement, +) -> Result<Option<Vec<Option<LexRequirement>>>> { + // If there's no requirement from the parent or the plan has no children + // or the join type is not Inner, Right, RightSemi, RightAnti, return early + if parent_required.is_empty() + || plan.children().is_empty() + || !matches!( + plan.join_type(), + JoinType::Inner | JoinType::Right | JoinType::RightSemi | JoinType::RightAnti + ) + { + return Ok(None); + } + + // Collect all unique column indices used in the parent-required sorting expression + let all_indices: HashSet<usize> = parent_required + .iter() + .flat_map(|order| { + collect_columns(&order.expr) + .iter() + .map(|col| col.index()) + .collect::<HashSet<_>>() + }) + .collect(); + + let column_indices = build_join_column_index(plan); + let projected_indices: Vec<_> = if let Some(projection) = &plan.projection { + projection.iter().map(|&i| &column_indices[i]).collect() + } else { + 
column_indices.iter().collect() + }; + let len_of_left_fields = projected_indices + .iter() + .filter(|ci| ci.side == JoinSide::Left) + .count(); + + let all_from_right_child = all_indices.iter().all(|i| *i >= len_of_left_fields); + + // If all columns are from the right child, update the parent requirements + if all_from_right_child { + // Transform the parent-required expression for the child schema by adjusting columns + let updated_parent_req = parent_required + .iter() + .map(|req| { + let child_schema = plan.children()[1].schema(); + let updated_columns = Arc::clone(&req.expr) + .transform_up(|expr| { + if let Some(col) = expr.as_any().downcast_ref::<Column>() { + let index = projected_indices[col.index()].index; + Ok(Transformed::yes(Arc::new(Column::new( + child_schema.field(index).name(), + index, + )))) + } else { + Ok(Transformed::no(expr)) + } + })? + .data; + Ok(PhysicalSortRequirement::new(updated_columns, req.options)) + }) + .collect::<Result<Vec<_>>>()?; + + // Populating with the updated requirements for children that maintain order + Ok(Some(vec![ + None, + Some(LexRequirement::new(updated_parent_req)), + ])) + } else { + Ok(None) + } +} + +// this function is used to build the column index for the hash join +// push down sort requirements to the right child +fn build_join_column_index(plan: &HashJoinExec) -> Vec<ColumnIndex> { + let left = plan.left().schema(); + let right = plan.right().schema(); + + let left_fields = || { + left.fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { + index, + side: JoinSide::Left, + }) + }; + + let right_fields = || { + right + .fields() + .iter() + .enumerate() + .map(|(index, _)| ColumnIndex { + index, + side: JoinSide::Right, + }) + }; + + match plan.join_type() { + JoinType::Inner | JoinType::Right => { + left_fields().chain(right_fields()).collect() + } + JoinType::RightSemi | JoinType::RightAnti => right_fields().collect(), + _ => unreachable!("unexpected join type: {}", plan.join_type()), 
Review Comment: ```suggestion let map_fields = |schema: SchemaRef, side: JoinSide| { schema .fields() .iter() .enumerate() .map(|(index, _)| ColumnIndex { index, side }) .collect::<Vec<_>>() }; match plan.join_type() { JoinType::Inner | JoinType::Right => { map_fields(plan.left().schema(), JoinSide::Left) .into_iter() .chain(map_fields(plan.right().schema(), JoinSide::Right)) .collect::<Vec<_>>() } JoinType::RightSemi | JoinType::RightAnti => { map_fields(plan.right().schema(), JoinSide::Right) } _ => unreachable!("unexpected join type: {}", plan.join_type()), } ``` -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: github-unsubscr...@datafusion.apache.org For additional commands, e-mail: github-h...@datafusion.apache.org