peter-toth commented on code in PR #11357:
URL: https://github.com/apache/datafusion/pull/11357#discussion_r1674185531
##########
datafusion/optimizer/src/common_subexpr_eliminate.rs:
##########
@@ -954,53 +952,113 @@ impl<'n> ExprIdentifierVisitor<'_, 'n> {
}
unreachable!("Enter mark should paired with node number");
}
+
+ /// Save the current `conditional` status and run `f` with `conditional` set to true.
+ fn conditionally<F: FnMut(&mut Self) -> Result<()>>(
+ &mut self,
+ mut f: F,
+ ) -> Result<()> {
+ let conditional = self.conditional;
+ self.conditional = true;
+ f(self)?;
+ self.conditional = conditional;
+
+ Ok(())
+ }
}
impl<'n> TreeNodeVisitor<'n> for ExprIdentifierVisitor<'_, 'n> {
type Node = Expr;
fn f_down(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
- // If an expression can short circuit its children then don't consider
its
- // children for CSE
(https://github.com/apache/arrow-datafusion/issues/8814).
- // This means that we don't recurse into its children, but handle the
expression
- // as a subtree when we calculate its identifier.
- // TODO: consider surely executed children of "short circuited"s for
CSE
- let is_tree = expr.short_circuits();
- let tnr = if is_tree {
- TreeNodeRecursion::Jump
- } else {
- TreeNodeRecursion::Continue
- };
-
self.id_array.push((0, None));
self.visit_stack
- .push(VisitRecord::EnterMark(self.down_index, is_tree));
+ .push(VisitRecord::EnterMark(self.down_index));
self.down_index += 1;
- Ok(tnr)
+ // If an expression can short-circuit then some of its children might not be
+ // executed so count the occurrence of subexpressions as conditional in all
+ // children.
+ Ok(match expr {
+ // If we are already in a conditionally evaluated subtree then continue
+ // traversal.
+ _ if self.conditional => TreeNodeRecursion::Continue,
+
+ // In case of `ScalarFunction`s we don't know which children are surely
+ // executed so start visiting all children conditionally and stop the
+ // recursion with `TreeNodeRecursion::Jump`.
+ Expr::ScalarFunction(ScalarFunction { func, args })
+ if func.short_circuits() =>
+ {
+ self.conditionally(|visitor| {
+ args.iter().try_for_each(|e| e.visit(visitor).map(|_| ()))
+ })?;
+
+ TreeNodeRecursion::Jump
+ }
+
+ // In case of `And` and `Or` the first child is surely executed, but we
+ // account subexpressions as conditional in the second.
+ Expr::BinaryExpr(BinaryExpr {
+ left,
+ op: Operator::And | Operator::Or,
+ right,
+ }) => {
+ left.visit(self)?;
+ self.conditionally(|visitor| right.visit(visitor).map(|_|
()))?;
+
+ TreeNodeRecursion::Jump
+ }
+
+ // In case of `Case` the optional base expression and the first when
+ // expressions are surely executed, but we account subexpressions as
+ // conditional in the others.
+ Expr::Case(Case {
+ expr,
+ when_then_expr,
+ else_expr,
+ }) => {
+ expr.iter().try_for_each(|e| e.visit(self).map(|_| ()))?;
+ when_then_expr.iter().take(1).try_for_each(|(when, then)| {
+ when.visit(self)?;
+ self.conditionally(|visitor| then.visit(visitor).map(|_|
()))
+ })?;
+ self.conditionally(|visitor| {
+ when_then_expr.iter().skip(1).try_for_each(|(when, then)| {
+ when.visit(visitor)?;
+ then.visit(visitor).map(|_| ())
+ })?;
+ else_expr
+ .iter()
+ .try_for_each(|e| e.visit(visitor).map(|_| ()))
+ })?;
+
+ TreeNodeRecursion::Jump
+ }
+
+ // In case of non-short-circuit expressions continue the traversal.
+ _ => TreeNodeRecursion::Continue,
+ })
}
fn f_up(&mut self, expr: &'n Expr) -> Result<TreeNodeRecursion> {
- let (down_index, is_tree, sub_expr_id, sub_expr_is_valid) =
self.pop_enter_mark();
+ let (down_index, sub_expr_id, sub_expr_is_valid) =
self.pop_enter_mark();
- let (expr_id, is_valid) = if is_tree {
- (
- Identifier::new(expr, true, self.random_state),
- !expr.is_volatile()?,
- )
- } else {
- (
- Identifier::new(expr, false,
self.random_state).combine(sub_expr_id),
- !expr.is_volatile_node() && sub_expr_is_valid,
- )
- };
+ let expr_id =
+ Identifier::new(expr, false,
self.random_state).combine(sub_expr_id);
+ let is_valid = !expr.is_volatile_node() && sub_expr_is_valid;
self.id_array[down_index].0 = self.up_index;
if is_valid && !self.expr_mask.ignores(expr) {
self.id_array[down_index].1 = Some(expr_id);
- let count = self.expr_stats.entry(expr_id).or_insert(0);
- *count += 1;
- if *count > 1 {
+ let (count, conditional_count) =
+ self.expr_stats.entry(expr_id).or_insert((0, 0));
+ if self.conditional {
+ *conditional_count += 1;
+ } else {
+ *count += 1;
+ }
+ if *count > 1 || *count == 1 && *conditional_count > 0 {
Review Comment:
Sure, fixed in
https://github.com/apache/datafusion/pull/11357/commits/b79a9a697566ea8ddadcf36f494b74a79826a526.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]