This is an automated email from the ASF dual-hosted git repository.

alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git


The following commit(s) were added to refs/heads/master by this push:
     new 4005076d8 feat: add optimize rule `rewrite_disjunctive_predicate` 
(#2858)
4005076d8 is described below

commit 4005076d8e3e4fa07541da62f7a6c9c755029da1
Author: xudong.w <[email protected]>
AuthorDate: Wed Jul 27 04:00:19 2022 +0800

    feat: add optimize rule `rewrite_disjunctive_predicate` (#2858)
    
    * feat: add optimize rule: rewrite_disjunctive_predicate
    
    * address comments and add tests
    
    * Update datafusion/optimizer/src/rewrite_disjunctive_predicate.rs
    
    Co-authored-by: Andrew Lamb <[email protected]>
    
    Co-authored-by: Andrew Lamb <[email protected]>
---
 datafusion/core/src/execution/context.rs           |   2 +
 datafusion/core/tests/sql/predicates.rs            |  56 ++++
 datafusion/optimizer/src/lib.rs                    |   1 +
 .../optimizer/src/rewrite_disjunctive_predicate.rs | 353 +++++++++++++++++++++
 4 files changed, 412 insertions(+)

diff --git a/datafusion/core/src/execution/context.rs 
b/datafusion/core/src/execution/context.rs
index 41964e33a..96705bb0c 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -106,6 +106,7 @@ use 
datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
 use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
 use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
 use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
+use 
datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
 use datafusion_sql::{
     parser::DFParser,
     planner::{ContextProvider, SqlToRel},
@@ -1367,6 +1368,7 @@ impl SessionState {
             Arc::new(CommonSubexprEliminate::new()),
             Arc::new(EliminateLimit::new()),
             Arc::new(ProjectionPushDown::new()),
+            Arc::new(RewriteDisjunctivePredicate::new()),
         ];
         if config.config_options.get_bool(OPT_FILTER_NULL_JOIN_KEYS) {
             rules.push(Arc::new(FilterNullJoinKeys::default()));
diff --git a/datafusion/core/tests/sql/predicates.rs 
b/datafusion/core/tests/sql/predicates.rs
index ea79e2b14..e6cb77d9a 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -386,3 +386,59 @@ async fn csv_in_set_test() -> Result<()> {
     assert_batches_sorted_eq!(expected, &actual);
     Ok(())
 }
+
+#[tokio::test]
+async fn multiple_or_predicates() -> Result<()> {
+    let ctx = SessionContext::new();
+    register_tpch_csv(&ctx, "lineitem").await?;
+    register_tpch_csv(&ctx, "part").await?;
+    let sql = "explain select
+    l_partkey
+    from
+    lineitem,
+    part
+    where
+    (
+                p_partkey = l_partkey
+            and p_brand = 'Brand#12'
+            and l_quantity >= 1 and l_quantity <= 1 + 10
+            and p_size between 1 and 5
+        )
+    or
+    (
+                p_partkey = l_partkey
+            and p_brand = 'Brand#23'
+            and l_quantity >= 10 and l_quantity <= 10 + 10
+            and p_size between 1 and 10
+        )
+    or
+    (
+                p_partkey = l_partkey
+            and p_brand = 'Brand#34'
+            and l_quantity >= 20 and l_quantity <= 20 + 10
+            and p_size between 1 and 15
+        )";
+    let msg = format!("Creating logical plan for '{}'", sql);
+    let plan = ctx.create_logical_plan(sql).expect(&msg);
+    let state = ctx.state();
+    let plan = state.optimize(&plan)?;
+    // Note that we expect `#part.p_partkey = #lineitem.l_partkey` to have been
+    // factored out and appear only once in the following plan
+    let expected =vec![
+        "Explain [plan_type:Utf8, plan:Utf8]",
+        "  Projection: #lineitem.l_partkey [l_partkey:Int64]",
+        "    Projection: #part.p_partkey = #lineitem.l_partkey AS 
BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey, 
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size 
[BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey:Boolean;N, 
l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
+        "      Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand 
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND 
#lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND 
Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >= 
Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN 
Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND 
#lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int6 [...]
+        "        CrossJoin: [l_partkey:Int64, l_quantity:Float64, 
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
+        "          TableScan: lineitem projection=[l_partkey, l_quantity] 
[l_partkey:Int64, l_quantity:Float64]",
+        "          TableScan: part projection=[p_partkey, p_brand, p_size] 
[p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
+    ];
+    let formatted = plan.display_indent_schema().to_string();
+    let actual: Vec<&str> = formatted.trim().lines().collect();
+    assert_eq!(
+        expected, actual,
+        "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+        expected, actual
+    );
+    Ok(())
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 588903ad0..6da67b6fc 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
 pub mod subquery_filter_to_join;
 pub mod utils;
 
+pub mod rewrite_disjunctive_predicate;
 #[cfg(test)]
 pub mod test;
 
diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs 
b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs
new file mode 100644
index 000000000..b68adef5a
--- /dev/null
+++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs
@@ -0,0 +1,353 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{OptimizerConfig, OptimizerRule};
+use datafusion_common::Result;
+use datafusion_expr::logical_plan::Filter;
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::Expr::BinaryExpr;
+use datafusion_expr::{Expr, LogicalPlan, Operator};
+use std::sync::Arc;
+
+/// Boolean skeleton of a filter expression.
+///
+/// `AND` / `OR` nodes keep their operands in a list so that operands shared by
+/// several branches can be compared and factored out; every other expression
+/// is stored as an opaque leaf.
+#[derive(Debug, Clone, PartialEq)]
+enum Predicate {
+    /// Conjunction of `args`.
+    And { args: Vec<Predicate> },
+    /// Disjunction of `args`.
+    Or { args: Vec<Predicate> },
+    /// Any expression that is not a binary `AND` / `OR`.
+    Other { expr: Box<Expr> },
+}
+
+/// Converts `expr` into a [`Predicate`] tree.
+///
+/// Binary `AND` / `OR` expressions become two-argument `And` / `Or` nodes whose
+/// operands are converted recursively; every other expression (including
+/// binary expressions with any other operator) is cloned into an `Other` leaf.
+fn predicate(expr: &Expr) -> Result<Predicate> {
+    match expr {
+        BinaryExpr {
+            left,
+            op: Operator::And,
+            right,
+        } => Ok(Predicate::And {
+            args: vec![predicate(left)?, predicate(right)?],
+        }),
+        BinaryExpr {
+            left,
+            op: Operator::Or,
+            right,
+        } => Ok(Predicate::Or {
+            args: vec![predicate(left)?, predicate(right)?],
+        }),
+        // Everything else is treated as an indivisible leaf.
+        leaf => Ok(Predicate::Other {
+            expr: Box::new(leaf.clone()),
+        }),
+    }
+}
+
+/// Converts a [`Predicate`] tree back into an [`Expr`].
+///
+/// The operands of an `And` / `Or` node are converted recursively and folded
+/// left-to-right with `Expr::and` / `Expr::or`; a leaf yields its stored
+/// expression.
+///
+/// # Panics
+///
+/// Panics if an `And` or `Or` node has fewer than two operands.
+fn normalize_predicate(predicate: Predicate) -> Expr {
+    // Pick the operand list and the combinator; leaves return immediately.
+    let (args, combine) = match predicate {
+        Predicate::Other { expr } => return *expr,
+        Predicate::And { args } => (args, Expr::and as fn(Expr, Expr) -> Expr),
+        Predicate::Or { args } => (args, Expr::or as fn(Expr, Expr) -> Expr),
+    };
+    assert!(args.len() >= 2);
+    args.into_iter()
+        .map(normalize_predicate)
+        .reduce(combine)
+        .expect("had more than one arg")
+}
+
+/// Rewrites a [`Predicate`] tree bottom-up.
+///
+/// * `And`: each operand is rewritten, then nested `And`s are flattened into
+///   the parent.
+/// * `Or`: each operand is rewritten, nested `Or`s are flattened, and the
+///   operands common to every branch are factored out by
+///   `delete_duplicate_predicates`.
+/// * `Other`: returned unchanged.
+///
+/// The tree is consumed, so operands are moved into the recursive calls
+/// instead of being cloned at every level.
+fn rewrite_predicate(predicate: Predicate) -> Predicate {
+    match predicate {
+        Predicate::And { args } => Predicate::And {
+            args: flatten_and_predicates(args.into_iter().map(rewrite_predicate)),
+        },
+        Predicate::Or { args } => {
+            let rewritten_args =
+                flatten_or_predicates(args.into_iter().map(rewrite_predicate));
+            delete_duplicate_predicates(&rewritten_args)
+        }
+        // A leaf has nothing to rewrite; hand it back without re-boxing.
+        leaf @ Predicate::Other { .. } => leaf,
+    }
+}
+
+/// Flattens nested conjunctions: every `And` found in `and_predicates` is
+/// replaced, recursively, by its own operands, so `[a, And[b, And[c, d]]]`
+/// becomes `[a, b, c, d]`. Other predicates are kept as-is and the relative
+/// order is preserved.
+fn flatten_and_predicates(
+    and_predicates: impl IntoIterator<Item = Predicate>,
+) -> Vec<Predicate> {
+    let mut flattened_predicates = vec![];
+    for predicate in and_predicates {
+        match predicate {
+            Predicate::And { args } => {
+                // Move the already-owned operands in rather than cloning them.
+                flattened_predicates.extend(flatten_and_predicates(args));
+            }
+            other => flattened_predicates.push(other),
+        }
+    }
+    flattened_predicates
+}
+
+/// Flattens nested disjunctions: every `Or` found in `or_predicates` is
+/// replaced, recursively, by its own operands, so `[a, Or[b, Or[c, d]]]`
+/// becomes `[a, b, c, d]`. Other predicates are kept as-is and the relative
+/// order is preserved.
+fn flatten_or_predicates(
+    or_predicates: impl IntoIterator<Item = Predicate>,
+) -> Vec<Predicate> {
+    let mut flattened_predicates = vec![];
+    for predicate in or_predicates {
+        match predicate {
+            Predicate::Or { args } => {
+                // Move the already-owned operands in rather than cloning them.
+                flattened_predicates.extend(flatten_or_predicates(args));
+            }
+            other => flattened_predicates.push(other),
+        }
+    }
+    flattened_predicates
+}
+
+/// Factors out the predicates shared by every branch of a disjunction.
+///
+/// `or_predicates` are the operands of an `OR`. A predicate is "common" when
+/// it appears in every operand: as one of the arguments of an `And` operand,
+/// or as the operand itself otherwise. The result is:
+///
+/// * `Or { args: or_predicates }` unchanged when nothing is common;
+/// * otherwise the common predicates `AND`ed with an `OR` of what remains of
+///   each branch, e.g. `(A AND B) OR (A AND C)` becomes `A AND (B OR C)`;
+/// * only the common predicates when some branch consists of nothing else,
+///   e.g. `(A AND B) OR A` becomes `A`.
+fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate {
+    // Pick the candidate factors: the arguments of the shortest AND operand
+    // (the first one on ties), or the first non-AND operand if there is one.
+    let mut shortest_exprs: Vec<Predicate> = vec![];
+    for or_predicate in or_predicates {
+        match or_predicate {
+            Predicate::And { args } => {
+                if shortest_exprs.is_empty() || args.len() < shortest_exprs.len() {
+                    shortest_exprs = args.clone();
+                }
+            }
+            _ => {
+                // A non-AND operand is a single expression, so it is the
+                // shortest possible candidate.
+                shortest_exprs = vec![or_predicate.clone()];
+                break;
+            }
+        }
+    }
+
+    // Drop repeated candidates while keeping first-occurrence order.
+    // `Vec::dedup` would only remove *adjacent* repeats, so `A AND B AND A`
+    // would keep both `A`s and the factored result would repeat a factor.
+    let mut candidates: Vec<Predicate> = Vec::with_capacity(shortest_exprs.len());
+    for expr in shortest_exprs {
+        if !candidates.contains(&expr) {
+            candidates.push(expr);
+        }
+    }
+
+    // Keep only the candidates that occur in every OR operand.
+    let mut exist_exprs: Vec<Predicate> = candidates
+        .into_iter()
+        .filter(|expr| {
+            or_predicates.iter().all(|or_predicate| match or_predicate {
+                Predicate::And { args } => args.contains(expr),
+                _ => or_predicate == expr,
+            })
+        })
+        .collect();
+    if exist_exprs.is_empty() {
+        return Predicate::Or {
+            args: or_predicates.to_vec(),
+        };
+    }
+
+    // Rebuild the OR from what is left of each operand once the common
+    // predicates are removed. If any operand has nothing left, the whole OR is
+    // implied by the common predicates: (A AND B) OR A is optimized to A.
+    let mut new_or_predicates = vec![];
+    for or_predicate in or_predicates {
+        match or_predicate {
+            Predicate::And { args } => {
+                let remaining: Vec<Predicate> = args
+                    .iter()
+                    .filter(|expr| !exist_exprs.contains(expr))
+                    .cloned()
+                    .collect();
+                match remaining.len() {
+                    0 => {
+                        new_or_predicates.clear();
+                        break;
+                    }
+                    // A single leftover needs no AND wrapper.
+                    1 => new_or_predicates.extend(remaining),
+                    _ => new_or_predicates.push(Predicate::And { args: remaining }),
+                }
+            }
+            _ => {
+                if exist_exprs.contains(or_predicate) {
+                    new_or_predicates.clear();
+                    break;
+                }
+            }
+        }
+    }
+
+    // Append the residual OR (if any) to the common predicates.
+    match new_or_predicates.len() {
+        0 => {}
+        1 => exist_exprs.extend(new_or_predicates),
+        _ => exist_exprs.push(Predicate::Or {
+            args: flatten_or_predicates(new_or_predicates),
+        }),
+    }
+
+    if exist_exprs.len() == 1 {
+        exist_exprs.remove(0)
+    } else {
+        Predicate::And {
+            args: flatten_and_predicates(exist_exprs),
+        }
+    }
+}
+
+/// Optimizer rule that rewrites the predicate of every `Filter` node so that
+/// sub-expressions common to all branches of an `OR` are factored out, e.g.
+/// `(a AND b) OR (a AND c)` becomes `a AND (b OR c)`.
+#[derive(Default)]
+pub struct RewriteDisjunctivePredicate;
+
+impl RewriteDisjunctivePredicate {
+    /// Creates the rule.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Rebuilds `plan` recursively, rewriting the predicate of each `Filter`
+    /// node encountered; every other node is re-created from its own
+    /// expressions and its rewritten inputs.
+    fn rewrite_disjunctive_predicate(
+        &self,
+        plan: &LogicalPlan,
+        optimizer_config: &OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        if let LogicalPlan::Filter(filter) = plan {
+            // Expr -> Predicate tree -> factored tree -> Expr.
+            let rewritten_expr =
+                normalize_predicate(rewrite_predicate(predicate(&filter.predicate)?));
+            let new_input =
+                self.rewrite_disjunctive_predicate(&filter.input, optimizer_config)?;
+            return Ok(LogicalPlan::Filter(Filter {
+                predicate: rewritten_expr,
+                input: Arc::new(new_input),
+            }));
+        }
+
+        // Not a filter: keep the node's expressions, recurse into its inputs.
+        let expr = plan.expressions();
+        let mut new_inputs = Vec::new();
+        for input in plan.inputs() {
+            new_inputs.push(self.rewrite_disjunctive_predicate(input, optimizer_config)?);
+        }
+        from_plan(plan, &expr, &new_inputs)
+    }
+}
+
+impl OptimizerRule for RewriteDisjunctivePredicate {
+    /// Applies the rewrite to `plan` and all of its inputs.
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        // The rewrite only needs shared access to the configuration.
+        let config: &OptimizerConfig = optimizer_config;
+        self.rewrite_disjunctive_predicate(plan, config)
+    }
+
+    /// Name under which this rule is reported.
+    fn name(&self) -> &str {
+        "rewrite_disjunctive_predicate"
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::{normalize_predicate, predicate, rewrite_predicate, Predicate};
+
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{and, col, lit, or, Expr};
+
+    /// Wraps a copy of `expr` in a `Predicate::Other` leaf.
+    fn leaf(expr: &Expr) -> Predicate {
+        Predicate::Other {
+            expr: Box::new(expr.clone()),
+        }
+    }
+
+    /// `(equi AND gt) OR (equi AND lt)` must be rewritten to
+    /// `equi AND (gt OR lt)`.
+    #[test]
+    fn test_rewrite_predicate() -> Result<()> {
+        let equi_expr = col("t1.a").eq(col("t2.b"));
+        let gt_expr = col("t1.c").gt(lit(ScalarValue::Int8(Some(1))));
+        let lt_expr = col("t1.d").lt(lit(ScalarValue::Int8(Some(2))));
+        let expr = or(
+            and(equi_expr.clone(), gt_expr.clone()),
+            and(equi_expr.clone(), lt_expr.clone()),
+        );
+
+        // Step 1: the expression is parsed into a two-branch OR of ANDs.
+        let parsed = predicate(&expr)?;
+        let expected_parsed = Predicate::Or {
+            args: vec![
+                Predicate::And {
+                    args: vec![leaf(&equi_expr), leaf(&gt_expr)],
+                },
+                Predicate::And {
+                    args: vec![leaf(&equi_expr), leaf(&lt_expr)],
+                },
+            ],
+        };
+        assert_eq!(parsed, expected_parsed);
+
+        // Step 2: the shared equality is pulled out of both branches.
+        let rewritten_predicate = rewrite_predicate(parsed);
+        let expected_rewritten = Predicate::And {
+            args: vec![
+                leaf(&equi_expr),
+                Predicate::Or {
+                    args: vec![leaf(&gt_expr), leaf(&lt_expr)],
+                },
+            ],
+        };
+        assert_eq!(rewritten_predicate, expected_rewritten);
+
+        // Step 3: the tree converts back into the expected expression.
+        let rewritten_expr = normalize_predicate(rewritten_predicate);
+        assert_eq!(rewritten_expr, and(equi_expr, or(gt_expr, lt_expr)));
+        Ok(())
+    }
+}

Reply via email to