This is an automated email from the ASF dual-hosted git repository.

xushiyan pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/hudi-rs.git


The following commit(s) were added to refs/heads/main by this push:
     new b828538  refactor: enhance `Filter` and related structs (#221)
b828538 is described below

commit b828538783a1ab5c12bc43c4f860422f61c2ea65
Author: Shiyan Xu <[email protected]>
AuthorDate: Sat Dec 7 23:09:21 2024 -1000

    refactor: enhance `Filter` and related structs (#221)
    
    - Rename `PartitionFilter` to `SchemableFilter` for generic use
    - Fix missing `Operator::NotEq` support
    - Add ergonomic functions for composing `Filter`s, such as `col("foo").eq("a")`
---
 crates/core/src/error.rs           |   2 +-
 crates/core/src/expr/filter.rs     | 124 ++++++++++++++++++++++++++++++++++++-
 crates/core/src/table/partition.rs |  71 ++++-----------------
 crates/datafusion/src/lib.rs       |  25 +++++---
 crates/datafusion/src/util/expr.rs |  59 ++++++------------
 5 files changed, 170 insertions(+), 111 deletions(-)

diff --git a/crates/core/src/error.rs b/crates/core/src/error.rs
index b94e39b..ae7ef2d 100644
--- a/crates/core/src/error.rs
+++ b/crates/core/src/error.rs
@@ -31,7 +31,7 @@ pub enum CoreError {
     Config(#[from] ConfigError),
 
     #[error("Data type error: {0}")]
-    DataType(String),
+    Schema(String),
 
     #[error("File group error: {0}")]
     FileGroup(String),
diff --git a/crates/core/src/expr/filter.rs b/crates/core/src/expr/filter.rs
index cf7ee5d..6be42f9 100644
--- a/crates/core/src/expr/filter.rs
+++ b/crates/core/src/expr/filter.rs
@@ -20,6 +20,9 @@
 use crate::error::CoreError;
 use crate::expr::ExprOperator;
 use crate::Result;
+use arrow_array::{ArrayRef, Scalar, StringArray};
+use arrow_cast::{cast_with_options, CastOptions};
+use arrow_schema::{DataType, Field, Schema};
 use std::str::FromStr;
 
 #[derive(Debug, Clone)]
@@ -29,7 +32,14 @@ pub struct Filter {
     pub field_value: String,
 }
 
-impl Filter {}
+impl Filter {
+    pub fn negate(&self) -> Option<Self> {
+        self.operator.negate().map(|op| Self {
+            operator: op,
+            ..self.clone()
+        })
+    }
+}
 
 impl TryFrom<(&str, &str, &str)> for Filter {
     type Error = CoreError;
@@ -50,3 +60,115 @@ impl TryFrom<(&str, &str, &str)> for Filter {
         })
     }
 }
+
+pub struct FilterField {
+    pub name: String,
+}
+
+impl FilterField {
+    pub fn new(name: impl Into<String>) -> Self {
+        Self { name: name.into() }
+    }
+
+    pub fn name(&self) -> &str {
+        &self.name
+    }
+
+    pub fn eq(&self, value: impl Into<String>) -> Filter {
+        Filter {
+            field_name: self.name.clone(),
+            operator: ExprOperator::Eq,
+            field_value: value.into(),
+        }
+    }
+
+    pub fn ne(&self, value: impl Into<String>) -> Filter {
+        Filter {
+            field_name: self.name.clone(),
+            operator: ExprOperator::Ne,
+            field_value: value.into(),
+        }
+    }
+
+    pub fn lt(&self, value: impl Into<String>) -> Filter {
+        Filter {
+            field_name: self.name.clone(),
+            operator: ExprOperator::Lt,
+            field_value: value.into(),
+        }
+    }
+
+    pub fn lte(&self, value: impl Into<String>) -> Filter {
+        Filter {
+            field_name: self.name.clone(),
+            operator: ExprOperator::Lte,
+            field_value: value.into(),
+        }
+    }
+
+    pub fn gt(&self, value: impl Into<String>) -> Filter {
+        Filter {
+            field_name: self.name.clone(),
+            operator: ExprOperator::Gt,
+            field_value: value.into(),
+        }
+    }
+
+    pub fn gte(&self, value: impl Into<String>) -> Filter {
+        Filter {
+            field_name: self.name.clone(),
+            operator: ExprOperator::Gte,
+            field_value: value.into(),
+        }
+    }
+}
+
+pub fn col(name: impl Into<String>) -> FilterField {
+    FilterField::new(name)
+}
+
+#[derive(Debug, Clone)]
+pub struct SchemableFilter {
+    pub field: Field,
+    pub operator: ExprOperator,
+    pub value: Scalar<ArrayRef>,
+}
+
+impl TryFrom<(Filter, &Schema)> for SchemableFilter {
+    type Error = CoreError;
+
+    fn try_from((filter, schema): (Filter, &Schema)) -> Result<Self, 
Self::Error> {
+        let field_name = filter.field_name.clone();
+        let field: &Field = schema.field_with_name(&field_name).map_err(|e| {
+            CoreError::Schema(format!("Field {} not found in schema: {:?}", 
field_name, e))
+        })?;
+
+        let operator = filter.operator;
+        let value = &[filter.field_value.as_str()];
+        let value = Self::cast_value(value, field.data_type())?;
+
+        let field = field.clone();
+        Ok(SchemableFilter {
+            field,
+            operator,
+            value,
+        })
+    }
+}
+
+impl SchemableFilter {
+    pub fn cast_value(value: &[&str; 1], data_type: &DataType) -> 
Result<Scalar<ArrayRef>> {
+        let cast_options = CastOptions {
+            safe: false,
+            format_options: Default::default(),
+        };
+
+        let value = StringArray::from(Vec::from(value));
+
+        Ok(Scalar::new(
+            cast_with_options(&value, data_type, &cast_options).map_err(|e| {
+                CoreError::Schema(format!("Unable to cast {:?}: {:?}", 
data_type, e))
+            })?,
+        ))
+    }
+}
diff --git a/crates/core/src/table/partition.rs b/crates/core/src/table/partition.rs
index 14f6b30..b781fca 100644
--- a/crates/core/src/table/partition.rs
+++ b/crates/core/src/table/partition.rs
@@ -19,17 +19,14 @@
 use crate::config::table::HudiTableConfig;
 use crate::config::HudiConfigs;
 use crate::error::CoreError::InvalidPartitionPath;
-use crate::expr::filter::Filter;
+use crate::expr::filter::{Filter, SchemableFilter};
 use crate::expr::ExprOperator;
 use crate::Result;
 
-use arrow_array::{ArrayRef, Scalar, StringArray};
-use arrow_cast::{cast_with_options, CastOptions};
+use arrow_array::{ArrayRef, Scalar};
 use arrow_ord::cmp::{eq, gt, gt_eq, lt, lt_eq, neq};
 use arrow_schema::Schema;
-use arrow_schema::{DataType, Field};
 
-use crate::table::CoreError;
 use std::collections::HashMap;
 use std::sync::Arc;
 
@@ -39,7 +36,7 @@ pub struct PartitionPruner {
     schema: Arc<Schema>,
     is_hive_style: bool,
     is_url_encoded: bool,
-    and_filters: Vec<PartitionFilter>,
+    and_filters: Vec<SchemableFilter>,
 }
 
 impl PartitionPruner {
@@ -50,8 +47,8 @@ impl PartitionPruner {
     ) -> Result<Self> {
         let and_filters = and_filters
             .iter()
-            .map(|filter| PartitionFilter::try_from((filter.clone(), 
partition_schema)))
-            .collect::<Result<Vec<PartitionFilter>>>()?;
+            .map(|filter| SchemableFilter::try_from((filter.clone(), 
partition_schema)))
+            .collect::<Result<Vec<SchemableFilter>>>()?;
 
         let schema = Arc::new(partition_schema.clone());
         let is_hive_style: bool = hudi_configs
@@ -151,59 +148,13 @@ impl PartitionPruner {
                 } else {
                     part
                 };
-                let scalar = PartitionFilter::cast_value(&[value], 
field.data_type())?;
+                let scalar = SchemableFilter::cast_value(&[value], 
field.data_type())?;
                 Ok((field.name().to_string(), scalar))
             })
             .collect()
     }
 }
 
-/// A partition filter that represents a filter expression for partition 
pruning.
-#[derive(Debug, Clone)]
-pub struct PartitionFilter {
-    pub field: Field,
-    pub operator: ExprOperator,
-    pub value: Scalar<ArrayRef>,
-}
-
-impl TryFrom<(Filter, &Schema)> for PartitionFilter {
-    type Error = CoreError;
-
-    fn try_from((filter, partition_schema): (Filter, &Schema)) -> Result<Self, 
Self::Error> {
-        let field: &Field = partition_schema
-            .field_with_name(&filter.field_name)
-            .map_err(|_| InvalidPartitionPath("Partition path should be in 
schema.".to_string()))?;
-
-        let operator = filter.operator;
-        let value = &[filter.field_value.as_str()];
-        let value = Self::cast_value(value, field.data_type())?;
-
-        let field = field.clone();
-        Ok(PartitionFilter {
-            field,
-            operator,
-            value,
-        })
-    }
-}
-
-impl PartitionFilter {
-    pub fn cast_value(value: &[&str; 1], data_type: &DataType) -> 
Result<Scalar<ArrayRef>> {
-        let cast_options = CastOptions {
-            safe: false,
-            format_options: Default::default(),
-        };
-
-        let value = StringArray::from(Vec::from(value));
-
-        Ok(Scalar::new(
-            cast_with_options(&value, data_type, &cast_options).map_err(|e| {
-                CoreError::DataType(format!("Unable to cast {:?}: {:?}", 
data_type, e))
-            })?,
-        ))
-    }
-}
-
 #[cfg(test)]
 mod tests {
     use super::*;
@@ -345,7 +296,7 @@ mod tests {
             field_value: "2023-01-01".to_string(),
         };
 
-        let partition_filter = PartitionFilter::try_from((filter, 
&schema)).unwrap();
+        let partition_filter = SchemableFilter::try_from((filter, 
&schema)).unwrap();
         assert_eq!(partition_filter.field.name(), "date");
         assert_eq!(partition_filter.operator, ExprOperator::Eq);
 
@@ -365,12 +316,12 @@ mod tests {
             operator: ExprOperator::Eq,
             field_value: "2023-01-01".to_string(),
         };
-        let result = PartitionFilter::try_from((filter, &schema));
+        let result = SchemableFilter::try_from((filter, &schema));
         assert!(result.is_err());
         assert!(result
             .unwrap_err()
             .to_string()
-            .contains("Partition path should be in schema."));
+            .contains("Field invalid_field not found in schema"));
     }
 
     #[test]
@@ -381,7 +332,7 @@ mod tests {
             operator: ExprOperator::Eq,
             field_value: "not_a_number".to_string(),
         };
-        let result = PartitionFilter::try_from((filter, &schema));
+        let result = SchemableFilter::try_from((filter, &schema));
         assert!(result.is_err());
     }
 
@@ -394,7 +345,7 @@ mod tests {
                 operator: ExprOperator::from_str(op).unwrap(),
                 field_value: "5".to_string(),
             };
-            let partition_filter = PartitionFilter::try_from((filter, 
&schema));
+            let partition_filter = SchemableFilter::try_from((filter, 
&schema));
             let filter = partition_filter.unwrap();
             assert_eq!(filter.field.name(), "count");
             assert_eq!(filter.operator, ExprOperator::from_str(op).unwrap());
diff --git a/crates/datafusion/src/lib.rs b/crates/datafusion/src/lib.rs
index a976a0f..718dc61 100644
--- a/crates/datafusion/src/lib.rs
+++ b/crates/datafusion/src/lib.rs
@@ -127,13 +127,18 @@ impl HudiDataSource {
     fn is_supported_operator(&self, op: &Operator) -> bool {
         matches!(
             op,
-            Operator::Eq | Operator::Gt | Operator::Lt | Operator::GtEq | 
Operator::LtEq
+            Operator::Eq
+                | Operator::NotEq
+                | Operator::Gt
+                | Operator::Lt
+                | Operator::GtEq
+                | Operator::LtEq
         )
     }
 
     fn is_supported_operand(&self, expr: &Expr) -> bool {
         match expr {
-            Expr::Column(col) => 
self.schema().field_with_name(&col.name).is_ok(),
+            Expr::Column(col) => 
self.schema().column_with_name(&col.name).is_some(),
             Expr::Literal(_) => true,
             _ => false,
         }
@@ -546,19 +551,19 @@ mod tests {
                 .await
                 .unwrap();
 
-        let expr1 = Expr::BinaryExpr(BinaryExpr {
+        let expr0 = Expr::BinaryExpr(BinaryExpr {
             left: 
Box::new(Expr::Column(Column::from_name("name".to_string()))),
             op: Operator::Eq,
             right: 
Box::new(Expr::Literal(ScalarValue::Utf8(Some("Alice".to_string())))),
         });
 
-        let expr2 = Expr::BinaryExpr(BinaryExpr {
+        let expr1 = Expr::BinaryExpr(BinaryExpr {
             left: 
Box::new(Expr::Column(Column::from_name("intField".to_string()))),
             op: Operator::Gt,
             right: Box::new(Expr::Literal(ScalarValue::Int32(Some(20000)))),
         });
 
-        let expr3 = Expr::BinaryExpr(BinaryExpr {
+        let expr2 = Expr::BinaryExpr(BinaryExpr {
             left: Box::new(Expr::Column(Column::from_name(
                 "nonexistent_column".to_string(),
             ))),
@@ -566,28 +571,28 @@ mod tests {
             right: Box::new(Expr::Literal(ScalarValue::Int32(Some(1)))),
         });
 
-        let expr4 = Expr::BinaryExpr(BinaryExpr {
+        let expr3 = Expr::BinaryExpr(BinaryExpr {
             left: 
Box::new(Expr::Column(Column::from_name("name".to_string()))),
             op: Operator::NotEq,
             right: 
Box::new(Expr::Literal(ScalarValue::Utf8(Some("Diana".to_string())))),
         });
 
-        let expr5 = Expr::Literal(ScalarValue::Int32(Some(10)));
+        let expr4 = Expr::Literal(ScalarValue::Int32(Some(10)));
 
-        let expr6 = Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr {
+        let expr5 = Expr::Not(Box::new(Expr::BinaryExpr(BinaryExpr {
             left: 
Box::new(Expr::Column(Column::from_name("intField".to_string()))),
             op: Operator::Gt,
             right: Box::new(Expr::Literal(ScalarValue::Int32(Some(20000)))),
         })));
 
-        let filters = vec![&expr1, &expr2, &expr3, &expr4, &expr5, &expr6];
+        let filters = vec![&expr0, &expr1, &expr2, &expr3, &expr4, &expr5];
         let result = 
table_provider.supports_filters_pushdown(&filters).unwrap();
 
         assert_eq!(result.len(), 6);
         assert_eq!(result[0], TableProviderFilterPushDown::Inexact);
         assert_eq!(result[1], TableProviderFilterPushDown::Inexact);
         assert_eq!(result[2], TableProviderFilterPushDown::Unsupported);
-        assert_eq!(result[3], TableProviderFilterPushDown::Unsupported);
+        assert_eq!(result[3], TableProviderFilterPushDown::Inexact);
         assert_eq!(result[4], TableProviderFilterPushDown::Unsupported);
         assert_eq!(result[5], TableProviderFilterPushDown::Inexact);
     }
diff --git a/crates/datafusion/src/util/expr.rs b/crates/datafusion/src/util/expr.rs
index daa21b6..3a8c94b 100644
--- a/crates/datafusion/src/util/expr.rs
+++ b/crates/datafusion/src/util/expr.rs
@@ -19,8 +19,7 @@
 
 use datafusion::logical_expr::Operator;
 use datafusion_expr::{BinaryExpr, Expr};
-use hudi_core::expr::filter::Filter as HudiFilter;
-use hudi_core::expr::ExprOperator;
+use hudi_core::expr::filter::{col, Filter as HudiFilter};
 
 /// Converts DataFusion expressions into Hudi filters.
 ///
@@ -36,25 +35,14 @@ use hudi_core::expr::ExprOperator;
 ///
 /// TODO: Handle other DataFusion [`Expr`]
 pub fn exprs_to_filters(exprs: &[Expr]) -> Vec<HudiFilter> {
-    let mut filters: Vec<HudiFilter> = Vec::new();
-
-    for expr in exprs {
-        match expr {
-            Expr::BinaryExpr(binary_expr) => {
-                if let Some(filter) = binary_expr_to_filter(binary_expr) {
-                    filters.push(filter);
-                }
-            }
-            Expr::Not(not_expr) => {
-                if let Some(filter) = not_expr_to_filter(not_expr) {
-                    filters.push(filter);
-                }
-            }
-            _ => {}
-        }
-    }
-
-    filters
+    exprs
+        .iter()
+        .filter_map(|expr| match expr {
+            Expr::BinaryExpr(binary_expr) => 
binary_expr_to_filter(binary_expr),
+            Expr::Not(not_expr) => not_expr_to_filter(not_expr),
+            _ => None,
+        })
+        .collect()
 }
 
 /// Converts a binary expression [`Expr::BinaryExpr`] into a [`HudiFilter`].
@@ -66,34 +54,27 @@ fn binary_expr_to_filter(binary_expr: &BinaryExpr) -> Option<HudiFilter> {
         _ => return None,
     };
 
-    let field_name = column.name().to_string();
+    let field = col(column.name());
+    let lit_str = literal.to_string();
 
-    let operator = match binary_expr.op {
-        Operator::Eq => ExprOperator::Eq,
-        Operator::NotEq => ExprOperator::Ne,
-        Operator::Lt => ExprOperator::Lt,
-        Operator::LtEq => ExprOperator::Lte,
-        Operator::Gt => ExprOperator::Gt,
-        Operator::GtEq => ExprOperator::Gte,
+    let filter = match binary_expr.op {
+        Operator::Eq => field.eq(lit_str),
+        Operator::NotEq => field.ne(lit_str),
+        Operator::Lt => field.lt(lit_str),
+        Operator::LtEq => field.lte(lit_str),
+        Operator::Gt => field.gt(lit_str),
+        Operator::GtEq => field.gte(lit_str),
         _ => return None,
     };
 
-    let value = literal.to_string();
-
-    Some(HudiFilter {
-        field_name,
-        operator,
-        field_value: value,
-    })
+    Some(filter)
 }
 
 /// Converts a NOT expression (`Expr::Not`) into a `PartitionFilter`.
 fn not_expr_to_filter(not_expr: &Expr) -> Option<HudiFilter> {
     match not_expr {
         Expr::BinaryExpr(ref binary_expr) => {
-            let mut filter = binary_expr_to_filter(binary_expr)?;
-            filter.operator = filter.operator.negate()?;
-            Some(filter)
+            binary_expr_to_filter(binary_expr).map(|filter| filter.negate())?
         }
         _ => None,
     }

Reply via email to