This is an automated email from the ASF dual-hosted git repository.
alamb pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion.git
The following commit(s) were added to refs/heads/master by this push:
new 4005076d8 feat: add optimize rule `rewrite_disjunctive_predicate` (#2858)
4005076d8 is described below
commit 4005076d8e3e4fa07541da62f7a6c9c755029da1
Author: xudong.w <[email protected]>
AuthorDate: Wed Jul 27 04:00:19 2022 +0800
feat: add optimize rule `rewrite_disjunctive_predicate` (#2858)
* feat: add optimize rule: rewrite_disjunctive_predicate
* address comments and add tests
* Update datafusion/optimizer/src/rewrite_disjunctive_predicate.rs
Co-authored-by: Andrew Lamb <[email protected]>
Co-authored-by: Andrew Lamb <[email protected]>
---
datafusion/core/src/execution/context.rs | 2 +
datafusion/core/tests/sql/predicates.rs | 56 ++++
datafusion/optimizer/src/lib.rs | 1 +
.../optimizer/src/rewrite_disjunctive_predicate.rs | 353 +++++++++++++++++++++
4 files changed, 412 insertions(+)
diff --git a/datafusion/core/src/execution/context.rs b/datafusion/core/src/execution/context.rs
index 41964e33a..96705bb0c 100644
--- a/datafusion/core/src/execution/context.rs
+++ b/datafusion/core/src/execution/context.rs
@@ -106,6 +106,7 @@ use
datafusion_optimizer::decorrelate_scalar_subquery::DecorrelateScalarSubquery
use datafusion_optimizer::decorrelate_where_exists::DecorrelateWhereExists;
use datafusion_optimizer::decorrelate_where_in::DecorrelateWhereIn;
use datafusion_optimizer::filter_null_join_keys::FilterNullJoinKeys;
+use
datafusion_optimizer::rewrite_disjunctive_predicate::RewriteDisjunctivePredicate;
use datafusion_sql::{
parser::DFParser,
planner::{ContextProvider, SqlToRel},
@@ -1367,6 +1368,7 @@ impl SessionState {
Arc::new(CommonSubexprEliminate::new()),
Arc::new(EliminateLimit::new()),
Arc::new(ProjectionPushDown::new()),
+ Arc::new(RewriteDisjunctivePredicate::new()),
];
if config.config_options.get_bool(OPT_FILTER_NULL_JOIN_KEYS) {
rules.push(Arc::new(FilterNullJoinKeys::default()));
diff --git a/datafusion/core/tests/sql/predicates.rs b/datafusion/core/tests/sql/predicates.rs
index ea79e2b14..e6cb77d9a 100644
--- a/datafusion/core/tests/sql/predicates.rs
+++ b/datafusion/core/tests/sql/predicates.rs
@@ -386,3 +386,59 @@ async fn csv_in_set_test() -> Result<()> {
assert_batches_sorted_eq!(expected, &actual);
Ok(())
}
+
+#[tokio::test]
+async fn multiple_or_predicates() -> Result<()> {
+ let ctx = SessionContext::new();
+ register_tpch_csv(&ctx, "lineitem").await?;
+ register_tpch_csv(&ctx, "part").await?;
+ let sql = "explain select
+ l_partkey
+ from
+ lineitem,
+ part
+ where
+ (
+ p_partkey = l_partkey
+ and p_brand = 'Brand#12'
+ and l_quantity >= 1 and l_quantity <= 1 + 10
+ and p_size between 1 and 5
+ )
+ or
+ (
+ p_partkey = l_partkey
+ and p_brand = 'Brand#23'
+ and l_quantity >= 10 and l_quantity <= 10 + 10
+ and p_size between 1 and 10
+ )
+ or
+ (
+ p_partkey = l_partkey
+ and p_brand = 'Brand#34'
+ and l_quantity >= 20 and l_quantity <= 20 + 10
+ and p_size between 1 and 15
+ )";
+ let msg = format!("Creating logical plan for '{}'", sql);
+ let plan = ctx.create_logical_plan(sql).expect(&msg);
+ let state = ctx.state();
+ let plan = state.optimize(&plan)?;
+ // Note that we expect `#part.p_partkey = #lineitem.l_partkey` to have been
+ // factored out and appear only once in the following plan
+ let expected =vec![
+ "Explain [plan_type:Utf8, plan:Utf8]",
+ " Projection: #lineitem.l_partkey [l_partkey:Int64]",
+ " Projection: #part.p_partkey = #lineitem.l_partkey AS
BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey,
#lineitem.l_partkey, #lineitem.l_quantity, #part.p_brand, #part.p_size
[BinaryExpr-=Column-lineitem.l_partkeyColumn-part.p_partkey:Boolean;N,
l_partkey:Int64, l_quantity:Float64, p_brand:Utf8, p_size:Int32]",
+ " Filter: #part.p_partkey = #lineitem.l_partkey AND #part.p_brand
= Utf8(\"Brand#12\") AND #lineitem.l_quantity >= Int64(1) AND
#lineitem.l_quantity <= Int64(11) AND #part.p_size BETWEEN Int64(1) AND
Int64(5) OR #part.p_brand = Utf8(\"Brand#23\") AND #lineitem.l_quantity >=
Int64(10) AND #lineitem.l_quantity <= Int64(20) AND #part.p_size BETWEEN
Int64(1) AND Int64(10) OR #part.p_brand = Utf8(\"Brand#34\") AND
#lineitem.l_quantity >= Int64(20) AND #lineitem.l_quantity <= Int6 [...]
+ " CrossJoin: [l_partkey:Int64, l_quantity:Float64,
p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
+ " TableScan: lineitem projection=[l_partkey, l_quantity]
[l_partkey:Int64, l_quantity:Float64]",
+ " TableScan: part projection=[p_partkey, p_brand, p_size]
[p_partkey:Int64, p_brand:Utf8, p_size:Int32]",
+ ];
+ let formatted = plan.display_indent_schema().to_string();
+ let actual: Vec<&str> = formatted.trim().lines().collect();
+ assert_eq!(
+ expected, actual,
+ "\n\nexpected:\n\n{:#?}\nactual:\n\n{:#?}\n\n",
+ expected, actual
+ );
+ Ok(())
+}
diff --git a/datafusion/optimizer/src/lib.rs b/datafusion/optimizer/src/lib.rs
index 588903ad0..6da67b6fc 100644
--- a/datafusion/optimizer/src/lib.rs
+++ b/datafusion/optimizer/src/lib.rs
@@ -33,6 +33,7 @@ pub mod single_distinct_to_groupby;
pub mod subquery_filter_to_join;
pub mod utils;
+pub mod rewrite_disjunctive_predicate;
#[cfg(test)]
pub mod test;
diff --git a/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs
new file mode 100644
index 000000000..b68adef5a
--- /dev/null
+++ b/datafusion/optimizer/src/rewrite_disjunctive_predicate.rs
@@ -0,0 +1,353 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use crate::{OptimizerConfig, OptimizerRule};
+use datafusion_common::Result;
+use datafusion_expr::logical_plan::Filter;
+use datafusion_expr::utils::from_plan;
+use datafusion_expr::Expr::BinaryExpr;
+use datafusion_expr::{Expr, LogicalPlan, Operator};
+use std::sync::Arc;
+
+/// Boolean skeleton of a filter expression.
+///
+/// `AND` and `OR` nodes hold their operands in `args`; any other expression
+/// is kept as an opaque leaf.
+#[derive(Debug, Clone, PartialEq)]
+enum Predicate {
+    /// Conjunction of `args`.
+    And { args: Vec<Predicate> },
+    /// Disjunction of `args`.
+    Or { args: Vec<Predicate> },
+    /// Any expression that is not a binary `AND`/`OR`.
+    Other { expr: Box<Expr> },
+}
+
+/// Converts `expr` into its [`Predicate`] skeleton.
+///
+/// A binary `AND`/`OR` becomes [`Predicate::And`]/[`Predicate::Or`] whose two
+/// operands are converted recursively; every other expression is cloned into
+/// a [`Predicate::Other`] leaf. Errors can only come from the recursive
+/// calls; no branch here creates one.
+fn predicate(expr: &Expr) -> Result<Predicate> {
+    let converted = match expr {
+        BinaryExpr {
+            left,
+            op: Operator::And,
+            right,
+        } => Predicate::And {
+            args: vec![predicate(left)?, predicate(right)?],
+        },
+        BinaryExpr {
+            left,
+            op: Operator::Or,
+            right,
+        } => Predicate::Or {
+            args: vec![predicate(left)?, predicate(right)?],
+        },
+        // Any other binary operator, or a non-binary expression, is a leaf.
+        leaf => Predicate::Other {
+            expr: Box::new(leaf.clone()),
+        },
+    };
+    Ok(converted)
+}
+
+/// Converts a [`Predicate`] back into an `Expr`.
+///
+/// The operands of an `And`/`Or` node are converted recursively and chained
+/// left to right with `Expr::and` / `Expr::or`; an `Other` leaf yields the
+/// expression it holds.
+///
+/// # Panics
+///
+/// Panics if an `And` or `Or` node has fewer than two operands.
+fn normalize_predicate(predicate: Predicate) -> Expr {
+    // Shared by the AND and OR cases: convert every operand, then fold the
+    // results together with `combine`.
+    fn join(args: Vec<Predicate>, combine: fn(Expr, Expr) -> Expr) -> Expr {
+        assert!(args.len() >= 2);
+        args.into_iter()
+            .map(normalize_predicate)
+            .reduce(combine)
+            .expect("had more than one arg")
+    }
+
+    match predicate {
+        Predicate::Other { expr } => *expr,
+        Predicate::And { args } => join(args, Expr::and),
+        Predicate::Or { args } => join(args, Expr::or),
+    }
+}
+
+/// Rewrites `predicate` bottom-up.
+///
+/// * `And`: every operand is rewritten, then nested `And`s are flattened into
+///   a single operand list.
+/// * `Or`: every operand is rewritten, nested `Or`s are flattened, and the
+///   terms common to all operands are factored out by
+///   `delete_duplicate_predicates`.
+/// * `Other`: returned unchanged.
+fn rewrite_predicate(predicate: Predicate) -> Predicate {
+    match predicate {
+        Predicate::And { args } => {
+            // `args` is owned, so the operands are moved rather than cloned.
+            let rewritten_args = args.into_iter().map(rewrite_predicate);
+            Predicate::And {
+                args: flatten_and_predicates(rewritten_args),
+            }
+        }
+        Predicate::Or { args } => {
+            let rewritten_args =
+                flatten_or_predicates(args.into_iter().map(rewrite_predicate));
+            delete_duplicate_predicates(&rewritten_args)
+        }
+        // Reuse the existing box instead of unboxing and re-boxing.
+        Predicate::Other { expr } => Predicate::Other { expr },
+    }
+}
+
+/// Flattens nested `And` nodes into one operand list.
+///
+/// Every `Predicate::And` in `and_predicates` is replaced, recursively, by
+/// its own operands; all other predicates are kept as they are. The relative
+/// order of the operands is preserved.
+fn flatten_and_predicates(
+    and_predicates: impl IntoIterator<Item = Predicate>,
+) -> Vec<Predicate> {
+    let mut flattened_predicates = vec![];
+    for predicate in and_predicates {
+        match predicate {
+            Predicate::And { args } => {
+                // Move the flattened operands in; no per-element clone.
+                flattened_predicates.extend(flatten_and_predicates(args));
+            }
+            other => flattened_predicates.push(other),
+        }
+    }
+    flattened_predicates
+}
+
+/// Flattens nested `Or` nodes into one operand list.
+///
+/// Every `Predicate::Or` in `or_predicates` is replaced, recursively, by its
+/// own operands; all other predicates are kept as they are. The relative
+/// order of the operands is preserved.
+fn flatten_or_predicates(
+    or_predicates: impl IntoIterator<Item = Predicate>,
+) -> Vec<Predicate> {
+    let mut flattened_predicates = vec![];
+    for predicate in or_predicates {
+        match predicate {
+            Predicate::Or { args } => {
+                // Move the flattened operands in; no per-element clone.
+                flattened_predicates.extend(flatten_or_predicates(args));
+            }
+            other => flattened_predicates.push(other),
+        }
+    }
+    flattened_predicates
+}
+
+/// Factors out the terms shared by every operand of an `OR`.
+///
+/// `or_predicates` are the operands of a single `OR`. Terms that occur in all
+/// of them are pulled out: `(A AND B) OR (A AND C)` becomes
+/// `A AND (B OR C)`, and `(A AND B) OR A` becomes `A`. When no term is shared,
+/// the operands are returned unchanged inside a [`Predicate::Or`].
+fn delete_duplicate_predicates(or_predicates: &[Predicate]) -> Predicate {
+    // Pick the candidate terms: the operands of the shortest AND, or the
+    // first non-AND operand if there is one.
+    let mut shortest_exprs: Vec<Predicate> = vec![];
+    let mut shortest_exprs_len = 0;
+    for or_predicate in or_predicates {
+        match or_predicate {
+            Predicate::And { args } => {
+                let args_num = args.len();
+                if shortest_exprs.is_empty() || args_num < shortest_exprs_len {
+                    shortest_exprs = args.clone();
+                    shortest_exprs_len = args_num;
+                }
+            }
+            _ => {
+                // A non-AND operand is a single term, so it is the shortest
+                // possible candidate list.
+                shortest_exprs = vec![or_predicate.clone()];
+                break;
+            }
+        }
+    }
+
+    // Keep each candidate that appears in every OR operand, once.
+    // `Vec::dedup` would only drop *adjacent* duplicates, so membership in
+    // the result is checked explicitly to also catch `[A, B, A]`.
+    let mut exist_exprs: Vec<Predicate> = vec![];
+    for expr in shortest_exprs {
+        let in_every_operand =
+            or_predicates.iter().all(|or_predicate| match or_predicate {
+                Predicate::And { args } => args.contains(&expr),
+                _ => or_predicate == &expr,
+            });
+        if in_every_operand && !exist_exprs.contains(&expr) {
+            exist_exprs.push(expr);
+        }
+    }
+    if exist_exprs.is_empty() {
+        return Predicate::Or {
+            args: or_predicates.to_vec(),
+        };
+    }
+
+    // Rebuild the OR from what is left of each operand once the common terms
+    // are removed. If an operand has nothing left, the common terms alone
+    // imply the whole OR: (A AND B) OR A is optimized to A.
+    let mut new_or_predicates = vec![];
+    for or_predicate in or_predicates {
+        match or_predicate {
+            Predicate::And { args } => {
+                let mut new_args = args.clone();
+                new_args.retain(|expr| !exist_exprs.contains(expr));
+                match new_args.len() {
+                    0 => {
+                        new_or_predicates.clear();
+                        break;
+                    }
+                    // A single remaining term needs no AND wrapper.
+                    1 => new_or_predicates.extend(new_args),
+                    _ => new_or_predicates.push(Predicate::And { args: new_args }),
+                }
+            }
+            _ => {
+                if exist_exprs.contains(or_predicate) {
+                    new_or_predicates.clear();
+                    break;
+                }
+            }
+        }
+    }
+    match new_or_predicates.len() {
+        0 => {}
+        1 => exist_exprs.extend(new_or_predicates),
+        _ => exist_exprs.push(Predicate::Or {
+            args: flatten_or_predicates(new_or_predicates),
+        }),
+    }
+
+    if exist_exprs.len() == 1 {
+        exist_exprs.remove(0)
+    } else {
+        Predicate::And {
+            args: flatten_and_predicates(exist_exprs),
+        }
+    }
+}
+
+/// Optimizer rule that rewrites the predicate of each `Filter` node it
+/// reaches while walking a plan through its inputs, factoring out the terms
+/// shared by all operands of an `OR`.
+#[derive(Default)]
+pub struct RewriteDisjunctivePredicate;
+
+impl RewriteDisjunctivePredicate {
+    /// Creates the rule.
+    pub fn new() -> Self {
+        Self::default()
+    }
+
+    /// Returns a copy of `plan` in which every reachable `Filter` predicate
+    /// has been rewritten.
+    ///
+    /// A `Filter` has its predicate rewritten and its input processed
+    /// recursively. Any other node is rebuilt with `from_plan` from its own
+    /// expressions and its recursively processed inputs.
+    fn rewrite_disjunctive_predicate(
+        &self,
+        plan: &LogicalPlan,
+        _optimizer_config: &OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        if let LogicalPlan::Filter(filter) = plan {
+            let parsed = predicate(&filter.predicate)?;
+            let rewritten_expr = normalize_predicate(rewrite_predicate(parsed));
+            let rewritten_input =
+                self.rewrite_disjunctive_predicate(&filter.input, _optimizer_config)?;
+            return Ok(LogicalPlan::Filter(Filter {
+                predicate: rewritten_expr,
+                input: Arc::new(rewritten_input),
+            }));
+        }
+
+        let expr = plan.expressions();
+        let mut new_inputs = Vec::new();
+        for input in plan.inputs() {
+            new_inputs
+                .push(self.rewrite_disjunctive_predicate(input, _optimizer_config)?);
+        }
+        from_plan(plan, &expr, &new_inputs)
+    }
+}
+
+impl OptimizerRule for RewriteDisjunctivePredicate {
+    /// Name under which this rule is reported.
+    fn name(&self) -> &str {
+        "rewrite_disjunctive_predicate"
+    }
+
+    /// Applies the rewrite to `plan` and returns the resulting plan.
+    fn optimize(
+        &self,
+        plan: &LogicalPlan,
+        optimizer_config: &mut OptimizerConfig,
+    ) -> Result<LogicalPlan> {
+        // The recursive helper only needs shared access to the config.
+        let config: &OptimizerConfig = optimizer_config;
+        self.rewrite_disjunctive_predicate(plan, config)
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use crate::rewrite_disjunctive_predicate::{
+        normalize_predicate, predicate, rewrite_predicate, Predicate,
+    };
+
+    use datafusion_common::{Result, ScalarValue};
+    use datafusion_expr::{and, col, lit, or};
+
+    // Wraps a clone of `expr` in a `Predicate::Other` leaf.
+    fn leaf(expr: &datafusion_expr::Expr) -> Predicate {
+        Predicate::Other {
+            expr: Box::new(expr.clone()),
+        }
+    }
+
+    /// `(a = b AND c > 1) OR (a = b AND d < 2)` must become
+    /// `a = b AND (c > 1 OR d < 2)`.
+    #[test]
+    fn test_rewrite_predicate() -> Result<()> {
+        let equi_expr = col("t1.a").eq(col("t2.b"));
+        let gt_expr = col("t1.c").gt(lit(ScalarValue::Int8(Some(1))));
+        let lt_expr = col("t1.d").lt(lit(ScalarValue::Int8(Some(2))));
+        let expr = or(
+            and(equi_expr.clone(), gt_expr.clone()),
+            and(equi_expr.clone(), lt_expr.clone()),
+        );
+
+        // Step 1: the expression is parsed into the predicate skeleton.
+        let parsed = predicate(&expr)?;
+        let expected_parsed = Predicate::Or {
+            args: vec![
+                Predicate::And {
+                    args: vec![leaf(&equi_expr), leaf(&gt_expr)],
+                },
+                Predicate::And {
+                    args: vec![leaf(&equi_expr), leaf(&lt_expr)],
+                },
+            ],
+        };
+        assert_eq!(parsed, expected_parsed);
+
+        // Step 2: the common equality is factored out of the OR.
+        let rewritten_predicate = rewrite_predicate(parsed);
+        let expected_rewritten = Predicate::And {
+            args: vec![
+                leaf(&equi_expr),
+                Predicate::Or {
+                    args: vec![leaf(&gt_expr), leaf(&lt_expr)],
+                },
+            ],
+        };
+        assert_eq!(rewritten_predicate, expected_rewritten);
+
+        // Step 3: the skeleton is turned back into an expression.
+        let rewritten_expr = normalize_predicate(rewritten_predicate);
+        assert_eq!(rewritten_expr, and(equi_expr, or(gt_expr, lt_expr)));
+        Ok(())
+    }
+}