This is an automated email from the ASF dual-hosted git repository.

lzljs3620320 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-rust.git


The following commit(s) were added to refs/heads/main by this push:
     new 10e190b  feat(datafusion): support partition predicate pushdown (#190)
10e190b is described below

commit 10e190b10a0c8c588aeaccea86eb4a24a538253a
Author: Zach <[email protected]>
AuthorDate: Thu Apr 2 21:54:59 2026 +0800

    feat(datafusion): support partition predicate pushdown (#190)
---
 .../integrations/datafusion/src/filter_pushdown.rs | 536 +++++++++++++++++++++
 crates/integrations/datafusion/src/lib.rs          |   4 +-
 .../datafusion/src/physical_plan/scan.rs           |   5 +
 crates/integrations/datafusion/src/table/mod.rs    | 164 ++++++-
 .../integrations/datafusion/tests/read_tables.rs   | 137 ++++--
 5 files changed, 807 insertions(+), 39 deletions(-)

diff --git a/crates/integrations/datafusion/src/filter_pushdown.rs 
b/crates/integrations/datafusion/src/filter_pushdown.rs
new file mode 100644
index 0000000..91c65d3
--- /dev/null
+++ b/crates/integrations/datafusion/src/filter_pushdown.rs
@@ -0,0 +1,536 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements.  See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership.  The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License.  You may obtain a copy of the License at
+//
+//   http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied.  See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::common::{Column, ScalarValue};
+use datafusion::logical_expr::expr::InList;
+use datafusion::logical_expr::{Between, BinaryExpr, Expr, Operator, 
TableProviderFilterPushDown};
+use paimon::spec::{DataField, DataType, Datum, Predicate, PredicateBuilder};
+
+pub(crate) fn classify_filter_pushdown(
+    filter: &Expr,
+    fields: &[DataField],
+    partition_keys: &[String],
+) -> TableProviderFilterPushDown {
+    let translator = FilterTranslator::new(fields);
+    if translator.translate(filter).is_some() {
+        let partition_translator = 
FilterTranslator::for_allowed_columns(fields, partition_keys);
+        if partition_translator.translate(filter).is_some() {
+            TableProviderFilterPushDown::Exact
+        } else {
+            TableProviderFilterPushDown::Inexact
+        }
+    } else if split_conjunction(filter)
+        .into_iter()
+        .any(|expr| translator.translate(expr).is_some())
+    {
+        TableProviderFilterPushDown::Inexact
+    } else {
+        TableProviderFilterPushDown::Unsupported
+    }
+}
+
+pub(crate) fn build_pushed_predicate(filters: &[Expr], fields: &[DataField]) 
-> Option<Predicate> {
+    let translator = FilterTranslator::new(fields);
+    let pushed: Vec<_> = filters
+        .iter()
+        .flat_map(split_conjunction)
+        .filter_map(|filter| translator.translate(filter))
+        .collect();
+
+    if pushed.is_empty() {
+        None
+    } else {
+        Some(Predicate::and(pushed))
+    }
+}
+
+fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
+    match expr {
+        Expr::BinaryExpr(BinaryExpr {
+            left,
+            op: Operator::And,
+            right,
+        }) => {
+            let mut conjuncts = split_conjunction(left.as_ref());
+            conjuncts.extend(split_conjunction(right.as_ref()));
+            conjuncts
+        }
+        other => vec![other],
+    }
+}
+
+struct FilterTranslator<'a> {
+    fields: &'a [DataField],
+    allowed_columns: Option<&'a [String]>,
+    predicate_builder: PredicateBuilder,
+}
+
+impl<'a> FilterTranslator<'a> {
+    fn new(fields: &'a [DataField]) -> Self {
+        Self {
+            fields,
+            allowed_columns: None,
+            predicate_builder: PredicateBuilder::new(fields),
+        }
+    }
+
+    fn for_allowed_columns(fields: &'a [DataField], allowed_columns: &'a 
[String]) -> Self {
+        Self {
+            fields,
+            allowed_columns: Some(allowed_columns),
+            predicate_builder: PredicateBuilder::new(fields),
+        }
+    }
+
+    fn translate(&self, expr: &Expr) -> Option<Predicate> {
+        match expr {
+            Expr::BinaryExpr(binary) => self.translate_binary(binary),
+            // NOT is intentionally not translated: Predicate::Not uses 
two-valued
+            // logic (!bool), which incorrectly returns true when the inner 
predicate
+            // evaluates NULL to false. Combined with Exact pushdown precision,
+            // DataFusion would remove the residual filter, producing wrong 
results.
+            Expr::Not(_) => None,
+            Expr::IsNull(inner) => {
+                let field = self.resolve_field(inner.as_ref())?;
+                self.predicate_builder.is_null(field.name()).ok()
+            }
+            Expr::IsNotNull(inner) => {
+                let field = self.resolve_field(inner.as_ref())?;
+                self.predicate_builder.is_not_null(field.name()).ok()
+            }
+            Expr::InList(in_list) => self.translate_in_list(in_list),
+            Expr::Between(between) => self.translate_between(between),
+            _ => None,
+        }
+    }
+
+    fn translate_binary(&self, binary: &BinaryExpr) -> Option<Predicate> {
+        match binary.op {
+            Operator::And => Some(Predicate::and(vec![
+                self.translate(binary.left.as_ref())?,
+                self.translate(binary.right.as_ref())?,
+            ])),
+            Operator::Or => Some(Predicate::or(vec![
+                self.translate(binary.left.as_ref())?,
+                self.translate(binary.right.as_ref())?,
+            ])),
+            Operator::Eq
+            | Operator::NotEq
+            | Operator::Lt
+            | Operator::LtEq
+            | Operator::Gt
+            | Operator::GtEq => self.translate_comparison(binary),
+            _ => None,
+        }
+    }
+
+    fn translate_comparison(&self, binary: &BinaryExpr) -> Option<Predicate> {
+        if let Some(predicate) = self.translate_column_literal_comparison(
+            binary.left.as_ref(),
+            binary.op,
+            binary.right.as_ref(),
+        ) {
+            return Some(predicate);
+        }
+
+        let reversed = reverse_comparison_operator(binary.op)?;
+        self.translate_column_literal_comparison(
+            binary.right.as_ref(),
+            reversed,
+            binary.left.as_ref(),
+        )
+    }
+
+    fn translate_column_literal_comparison(
+        &self,
+        column_expr: &Expr,
+        op: Operator,
+        literal_expr: &Expr,
+    ) -> Option<Predicate> {
+        let field = self.resolve_field(column_expr)?;
+        let scalar = extract_scalar_literal(literal_expr)?;
+        let datum = scalar_to_datum(scalar, field.data_type())?;
+
+        match op {
+            Operator::Eq => self.predicate_builder.equal(field.name(), 
datum).ok(),
+            Operator::NotEq => self.predicate_builder.not_equal(field.name(), 
datum).ok(),
+            Operator::Lt => self.predicate_builder.less_than(field.name(), 
datum).ok(),
+            Operator::LtEq => self
+                .predicate_builder
+                .less_or_equal(field.name(), datum)
+                .ok(),
+            Operator::Gt => self
+                .predicate_builder
+                .greater_than(field.name(), datum)
+                .ok(),
+            Operator::GtEq => self
+                .predicate_builder
+                .greater_or_equal(field.name(), datum)
+                .ok(),
+            _ => None,
+        }
+    }
+
+    fn translate_in_list(&self, in_list: &InList) -> Option<Predicate> {
+        let field = self.resolve_field(in_list.expr.as_ref())?;
+        let literals: Option<Vec<_>> = in_list
+            .list
+            .iter()
+            .map(|expr| {
+                let scalar = extract_scalar_literal(expr)?;
+                scalar_to_datum(scalar, field.data_type())
+            })
+            .collect();
+        let literals = literals?;
+
+        if in_list.negated {
+            self.predicate_builder
+                .is_not_in(field.name(), literals)
+                .ok()
+        } else {
+            self.predicate_builder.is_in(field.name(), literals).ok()
+        }
+    }
+
+    fn translate_between(&self, between: &Between) -> Option<Predicate> {
+        let field = self.resolve_field(between.expr.as_ref())?;
+        let low = scalar_to_datum(
+            extract_scalar_literal(between.low.as_ref())?,
+            field.data_type(),
+        )?;
+        let high = scalar_to_datum(
+            extract_scalar_literal(between.high.as_ref())?,
+            field.data_type(),
+        )?;
+
+        let predicate = Predicate::and(vec![
+            self.predicate_builder
+                .greater_or_equal(field.name(), low)
+                .ok()?,
+            self.predicate_builder
+                .less_or_equal(field.name(), high)
+                .ok()?,
+        ]);
+
+        if between.negated {
+            // Same concern as Expr::Not: negation wraps in Predicate::Not
+            // which has incorrect NULL semantics for Exact pushdown.
+            None
+        } else {
+            Some(predicate)
+        }
+    }
+
+    fn resolve_field(&self, expr: &Expr) -> Option<&'a DataField> {
+        let Expr::Column(Column { name, .. }) = expr else {
+            return None;
+        };
+
+        if let Some(allowed_columns) = self.allowed_columns {
+            if !allowed_columns.iter().any(|column| column == name) {
+                return None;
+            }
+        }
+
+        self.fields.iter().find(|field| field.name() == name)
+    }
+}
+
+fn extract_scalar_literal(expr: &Expr) -> Option<&ScalarValue> {
+    match expr {
+        Expr::Literal(scalar, _) if !scalar.is_null() => Some(scalar),
+        _ => None,
+    }
+}
+
+fn reverse_comparison_operator(op: Operator) -> Option<Operator> {
+    match op {
+        Operator::Eq => Some(Operator::Eq),
+        Operator::NotEq => Some(Operator::NotEq),
+        Operator::Lt => Some(Operator::Gt),
+        Operator::LtEq => Some(Operator::GtEq),
+        Operator::Gt => Some(Operator::Lt),
+        Operator::GtEq => Some(Operator::LtEq),
+        _ => None,
+    }
+}
+
+fn scalar_to_datum(scalar: &ScalarValue, data_type: &DataType) -> 
Option<Datum> {
+    match data_type {
+        DataType::Boolean(_) => match scalar {
+            ScalarValue::Boolean(Some(value)) => Some(Datum::Bool(*value)),
+            _ => None,
+        },
+        DataType::TinyInt(_) => scalar_to_i128(scalar)
+            .and_then(|value| i8::try_from(value).ok())
+            .map(Datum::TinyInt),
+        DataType::SmallInt(_) => scalar_to_i128(scalar)
+            .and_then(|value| i16::try_from(value).ok())
+            .map(Datum::SmallInt),
+        DataType::Int(_) => scalar_to_i128(scalar)
+            .and_then(|value| i32::try_from(value).ok())
+            .map(Datum::Int),
+        DataType::BigInt(_) => scalar_to_i128(scalar)
+            .and_then(|value| i64::try_from(value).ok())
+            .map(Datum::Long),
+        DataType::Float(_) => match scalar {
+            ScalarValue::Float32(Some(value)) => Some(Datum::Float(*value)),
+            _ => None,
+        },
+        DataType::Double(_) => match scalar {
+            ScalarValue::Float64(Some(value)) => Some(Datum::Double(*value)),
+            ScalarValue::Float32(Some(value)) => Some(Datum::Double(*value as 
f64)),
+            _ => None,
+        },
+        DataType::Char(_) | DataType::VarChar(_) => match scalar {
+            ScalarValue::Utf8(Some(value))
+            | ScalarValue::Utf8View(Some(value))
+            | ScalarValue::LargeUtf8(Some(value)) => 
Some(Datum::String(value.clone())),
+            _ => None,
+        },
+        DataType::Date(_) => match scalar {
+            ScalarValue::Date32(Some(value)) => Some(Datum::Date(*value)),
+            _ => None,
+        },
+        DataType::Decimal(decimal) => match scalar {
+            ScalarValue::Decimal128(Some(unscaled), precision, scale)
+                if u32::from(*precision) <= decimal.precision() && 
i32::from(*scale) >= 0 =>
+            {
+                let scale = u32::try_from(i32::from(*scale)).ok()?;
+                if scale != decimal.scale() {
+                    return None;
+                }
+                Some(Datum::Decimal {
+                    unscaled: *unscaled,
+                    precision: decimal.precision(),
+                    scale: decimal.scale(),
+                })
+            }
+            _ => None,
+        },
+        DataType::Binary(_) | DataType::VarBinary(_) => match scalar {
+            ScalarValue::Binary(Some(value))
+            | ScalarValue::BinaryView(Some(value))
+            | ScalarValue::LargeBinary(Some(value)) => 
Some(Datum::Bytes(value.clone())),
+            ScalarValue::FixedSizeBinary(_, Some(value)) => 
Some(Datum::Bytes(value.clone())),
+            _ => None,
+        },
+        _ => None,
+    }
+}
+
+fn scalar_to_i128(scalar: &ScalarValue) -> Option<i128> {
+    match scalar {
+        ScalarValue::Int8(Some(value)) => Some(i128::from(*value)),
+        ScalarValue::Int16(Some(value)) => Some(i128::from(*value)),
+        ScalarValue::Int32(Some(value)) => Some(i128::from(*value)),
+        ScalarValue::Int64(Some(value)) => Some(i128::from(*value)),
+        ScalarValue::UInt8(Some(value)) => Some(i128::from(*value)),
+        ScalarValue::UInt16(Some(value)) => Some(i128::from(*value)),
+        ScalarValue::UInt32(Some(value)) => Some(i128::from(*value)),
+        ScalarValue::UInt64(Some(value)) => Some(i128::from(*value)),
+        _ => None,
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::common::Column;
+    use datafusion::logical_expr::{expr::InList, lit, 
TableProviderFilterPushDown};
+    use paimon::spec::{IntType, VarCharType};
+
+    fn test_fields() -> Vec<DataField> {
+        vec![
+            DataField::new(0, "id".to_string(), DataType::Int(IntType::new())),
+            DataField::new(
+                1,
+                "dt".to_string(),
+                DataType::VarChar(VarCharType::string_type()),
+            ),
+            DataField::new(2, "hr".to_string(), DataType::Int(IntType::new())),
+        ]
+    }
+
+    fn partition_keys() -> Vec<String> {
+        vec!["dt".to_string(), "hr".to_string()]
+    }
+
+    #[test]
+    fn test_translate_partition_equality_filter() {
+        let fields = test_fields();
+        let filter = 
Expr::Column(Column::from_name("dt")).eq(lit("2024-01-01"));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("partition 
filter should translate");
+
+        assert_eq!(predicate.to_string(), "dt = '2024-01-01'");
+    }
+
+    #[test]
+    fn test_classify_partition_filter_as_exact() {
+        let fields = test_fields();
+        let filter = 
Expr::Column(Column::from_name("dt")).eq(lit("2024-01-01"));
+
+        assert_eq!(
+            classify_filter_pushdown(&filter, &fields, &partition_keys()),
+            TableProviderFilterPushDown::Exact
+        );
+    }
+
+    #[test]
+    fn test_translate_reversed_partition_comparison() {
+        let fields = test_fields();
+        let filter = lit(10).lt(Expr::Column(Column::from_name("hr")));
+
+        let predicate = build_pushed_predicate(&[filter], &fields)
+            .expect("reversed comparison should translate");
+
+        assert_eq!(predicate.to_string(), "hr > 10");
+    }
+
+    #[test]
+    fn test_translate_partition_in_list() {
+        let fields = test_fields();
+        let filter = Expr::InList(InList::new(
+            Box::new(Expr::Column(Column::from_name("dt"))),
+            vec![lit("2024-01-01"), lit("2024-01-02")],
+            false,
+        ));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("in-list filter 
should translate");
+
+        assert_eq!(predicate.to_string(), "dt IN ('2024-01-01', 
'2024-01-02')");
+    }
+
+    #[test]
+    fn test_translate_mixed_or_filter() {
+        let fields = test_fields();
+        let filter = Expr::Column(Column::from_name("dt"))
+            .eq(lit("2024-01-01"))
+            .or(Expr::Column(Column::from_name("id")).gt(lit(10)));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("mixed OR filter 
should translate");
+
+        assert_eq!(predicate.to_string(), "(dt = '2024-01-01' OR id > 10)");
+    }
+
+    #[test]
+    fn test_translate_non_partition_filter() {
+        let fields = test_fields();
+        let filter = Expr::Column(Column::from_name("id")).gt(lit(10));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("data filter 
should translate");
+
+        assert_eq!(predicate.to_string(), "id > 10");
+    }
+
+    #[test]
+    fn test_classify_non_partition_filter_as_inexact() {
+        let fields = test_fields();
+        let filter = Expr::Column(Column::from_name("id")).gt(lit(10));
+
+        assert_eq!(
+            classify_filter_pushdown(&filter, &fields, &partition_keys()),
+            TableProviderFilterPushDown::Inexact
+        );
+    }
+
+    #[test]
+    fn test_translate_mixed_and_filter() {
+        let fields = test_fields();
+        let filter = Expr::Column(Column::from_name("dt"))
+            .eq(lit("2024-01-01"))
+            .and(Expr::Column(Column::from_name("id")).gt(lit(10)));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("mixed filter 
should translate");
+
+        assert_eq!(predicate.to_string(), "(dt = '2024-01-01' AND id > 10)");
+    }
+
+    #[test]
+    fn test_classify_mixed_and_filter_as_inexact() {
+        let fields = test_fields();
+        let filter = Expr::Column(Column::from_name("dt"))
+            .eq(lit("2024-01-01"))
+            .and(Expr::Column(Column::from_name("id")).gt(lit(10)));
+
+        assert_eq!(
+            classify_filter_pushdown(&filter, &fields, &partition_keys()),
+            TableProviderFilterPushDown::Inexact
+        );
+    }
+
+    #[test]
+    fn test_translate_not_is_not_supported() {
+        let fields = test_fields();
+        let filter = Expr::Not(Box::new(
+            Expr::Column(Column::from_name("dt")).eq(lit("2024-01-01")),
+        ));
+
+        assert!(
+            build_pushed_predicate(&[filter], &fields).is_none(),
+            "NOT expressions should not translate due to NULL semantics"
+        );
+    }
+
+    #[test]
+    fn test_classify_not_filter_as_unsupported() {
+        let fields = test_fields();
+        let filter = Expr::Not(Box::new(
+            Expr::Column(Column::from_name("dt")).eq(lit("2024-01-01")),
+        ));
+
+        assert_eq!(
+            classify_filter_pushdown(&filter, &fields, &partition_keys()),
+            TableProviderFilterPushDown::Unsupported
+        );
+    }
+
+    #[test]
+    fn test_translate_negated_between_is_not_supported() {
+        let fields = test_fields();
+        let filter = Expr::Between(Between::new(
+            Box::new(Expr::Column(Column::from_name("hr"))),
+            true, // negated
+            Box::new(lit(1)),
+            Box::new(lit(20)),
+        ));
+
+        assert!(
+            build_pushed_predicate(&[filter], &fields).is_none(),
+            "Negated BETWEEN should not translate due to NULL semantics"
+        );
+    }
+
+    #[test]
+    fn test_translate_boolean_literal_is_not_supported() {
+        let fields = test_fields();
+
+        for value in [true, false] {
+            let filter = Expr::Literal(ScalarValue::Boolean(Some(value)), 
None);
+            assert!(
+                build_pushed_predicate(&[filter], &fields).is_none(),
+                "Boolean literal ({value}) is not a partition predicate and 
must not be translated"
+            );
+        }
+    }
+}
diff --git a/crates/integrations/datafusion/src/lib.rs 
b/crates/integrations/datafusion/src/lib.rs
index e9f8955..40639e2 100644
--- a/crates/integrations/datafusion/src/lib.rs
+++ b/crates/integrations/datafusion/src/lib.rs
@@ -33,10 +33,12 @@
 //! let df = ctx.sql("SELECT * FROM my_table").await?;
 //! ```
 //!
-//! This version does not support write or predicate pushdown.
+//! This version supports partition predicate pushdown by extracting
+//! translatable partition-only conjuncts from DataFusion filters.
 
 mod catalog;
 mod error;
+mod filter_pushdown;
 mod physical_plan;
 mod schema;
 mod table;
diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs 
b/crates/integrations/datafusion/src/physical_plan/scan.rs
index dd27612..52f5357 100644
--- a/crates/integrations/datafusion/src/physical_plan/scan.rs
+++ b/crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -72,6 +72,11 @@ impl PaimonTableScan {
     pub fn table(&self) -> &Table {
         &self.table
     }
+
+    #[cfg(test)]
+    pub(crate) fn planned_partitions(&self) -> &[Arc<[DataSplit]>] {
+        &self.planned_partitions
+    }
 }
 
 impl ExecutionPlan for PaimonTableScan {
diff --git a/crates/integrations/datafusion/src/table/mod.rs 
b/crates/integrations/datafusion/src/table/mod.rs
index 2e0a49e..6ea4803 100644
--- a/crates/integrations/datafusion/src/table/mod.rs
+++ b/crates/integrations/datafusion/src/table/mod.rs
@@ -25,18 +25,19 @@ use datafusion::arrow::datatypes::{Field, Schema, SchemaRef 
as ArrowSchemaRef};
 use datafusion::catalog::Session;
 use datafusion::datasource::{TableProvider, TableType};
 use datafusion::error::Result as DFResult;
-use datafusion::logical_expr::Expr;
+use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
 use datafusion::physical_plan::ExecutionPlan;
 use paimon::table::Table;
 
 use crate::error::to_datafusion_error;
+use crate::filter_pushdown::{build_pushed_predicate, classify_filter_pushdown};
 use crate::physical_plan::PaimonTableScan;
 use crate::schema::paimon_schema_to_arrow;
 
 /// Read-only table provider for a Paimon table.
 ///
-/// Supports full table scan and column projection. Predicate pushdown and 
writes
-/// are not yet supported.
+/// Supports full table scan, column projection, and partition predicate 
pushdown.
+/// Data-level filtering remains a residual DataFusion filter.
 #[derive(Debug, Clone)]
 pub struct PaimonTableProvider {
     table: Table,
@@ -81,11 +82,24 @@ impl TableProvider for PaimonTableProvider {
         TableType::Base
     }
 
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
+        let fields = self.table.schema().fields();
+        let partition_keys = self.table.schema().partition_keys();
+
+        Ok(filters
+            .iter()
+            .map(|filter| classify_filter_pushdown(filter, fields, 
partition_keys))
+            .collect())
+    }
+
     async fn scan(
         &self,
         state: &dyn Session,
         projection: Option<&Vec<usize>>,
-        _filters: &[Expr],
+        filters: &[Expr],
         _limit: Option<usize>,
     ) -> DFResult<Arc<dyn ExecutionPlan>> {
         // Convert projection indices to column names and compute projected 
schema
@@ -101,7 +115,10 @@ impl TableProvider for PaimonTableProvider {
         };
 
         // Plan splits eagerly so we know partition count upfront.
-        let read_builder = self.table.new_read_builder();
+        let mut read_builder = self.table.new_read_builder();
+        if let Some(filter) = build_pushed_predicate(filters, 
self.table.schema().fields()) {
+            read_builder.with_filter(filter);
+        }
         let scan = read_builder.new_scan();
         let plan = scan.plan().await.map_err(to_datafusion_error)?;
 
@@ -133,6 +150,16 @@ impl TableProvider for PaimonTableProvider {
 #[cfg(test)]
 mod tests {
     use super::*;
+    use std::collections::BTreeSet;
+    use std::sync::Arc;
+
+    use datafusion::datasource::TableProvider;
+    use datafusion::logical_expr::{col, lit, Expr};
+    use datafusion::prelude::{SessionConfig, SessionContext};
+    use paimon::catalog::Identifier;
+    use paimon::{Catalog, CatalogOptions, DataSplit, FileSystemCatalog, 
Options};
+
+    use crate::physical_plan::PaimonTableScan;
 
     #[test]
     fn test_bucket_round_robin_distributes_evenly() {
@@ -151,4 +178,131 @@ mod tests {
         let result = bucket_round_robin(vec![1, 2, 3], 1);
         assert_eq!(result, vec![vec![1, 2, 3]]);
     }
+
+    fn get_test_warehouse() -> String {
+        std::env::var("PAIMON_TEST_WAREHOUSE")
+            .unwrap_or_else(|_| "/tmp/paimon-warehouse".to_string())
+    }
+
+    fn create_catalog() -> FileSystemCatalog {
+        let warehouse = get_test_warehouse();
+        let mut options = Options::new();
+        options.set(CatalogOptions::WAREHOUSE, warehouse);
+        FileSystemCatalog::new(options).expect("Failed to create catalog")
+    }
+
+    async fn create_provider(table_name: &str) -> PaimonTableProvider {
+        let catalog = create_catalog();
+        let identifier = Identifier::new("default", table_name);
+        let table = catalog
+            .get_table(&identifier)
+            .await
+            .expect("Failed to get table");
+
+        PaimonTableProvider::try_new(table).expect("Failed to create table 
provider")
+    }
+
+    async fn plan_partitions(
+        provider: &PaimonTableProvider,
+        filters: Vec<Expr>,
+    ) -> Vec<Arc<[DataSplit]>> {
+        let config = SessionConfig::new().with_target_partitions(8);
+        let ctx = SessionContext::new_with_config(config);
+        let state = ctx.state();
+        let plan = provider
+            .scan(&state, None, &filters, None)
+            .await
+            .expect("scan() should succeed");
+        let scan = plan
+            .as_any()
+            .downcast_ref::<PaimonTableScan>()
+            .expect("Expected PaimonTableScan");
+
+        scan.planned_partitions().to_vec()
+    }
+
+    fn extract_dt_partition_set(planned_partitions: &[Arc<[DataSplit]>]) -> 
BTreeSet<String> {
+        planned_partitions
+            .iter()
+            .flat_map(|splits| splits.iter())
+            .map(|split| {
+                split
+                    .partition()
+                    .get_string(0)
+                    .expect("Failed to decode dt")
+                    .to_string()
+            })
+            .collect()
+    }
+
+    fn extract_dt_hr_partition_set(
+        planned_partitions: &[Arc<[DataSplit]>],
+    ) -> BTreeSet<(String, i32)> {
+        planned_partitions
+            .iter()
+            .flat_map(|splits| splits.iter())
+            .map(|split| {
+                let partition = split.partition();
+                (
+                    partition
+                        .get_string(0)
+                        .expect("Failed to decode dt")
+                        .to_string(),
+                    partition.get_int(1).expect("Failed to decode hr"),
+                )
+            })
+            .collect()
+    }
+
+    #[tokio::test]
+    async fn test_scan_partition_filter_plans_matching_partition_set() {
+        let provider = create_provider("partitioned_log_table").await;
+        let planned_partitions =
+            plan_partitions(&provider, 
vec![col("dt").eq(lit("2024-01-01"))]).await;
+
+        assert_eq!(
+            extract_dt_partition_set(&planned_partitions),
+            BTreeSet::from(["2024-01-01".to_string()]),
+        );
+    }
+
+    #[tokio::test]
+    async fn test_scan_mixed_and_filter_keeps_partition_pruning() {
+        let provider = create_provider("partitioned_log_table").await;
+        let planned_partitions = plan_partitions(
+            &provider,
+            vec![col("dt").eq(lit("2024-01-01")).and(col("id").gt(lit(1)))],
+        )
+        .await;
+
+        assert_eq!(
+            extract_dt_partition_set(&planned_partitions),
+            BTreeSet::from(["2024-01-01".to_string()]),
+        );
+    }
+
+    #[tokio::test]
+    async fn test_scan_multi_partition_filter_plans_exact_partition_set() {
+        let provider = create_provider("multi_partitioned_log_table").await;
+
+        let dt_only_partitions =
+            plan_partitions(&provider, 
vec![col("dt").eq(lit("2024-01-01"))]).await;
+        let dt_hr_partitions = plan_partitions(
+            &provider,
+            vec![col("dt").eq(lit("2024-01-01")).and(col("hr").eq(lit(10)))],
+        )
+        .await;
+
+        assert_eq!(
+            extract_dt_hr_partition_set(&dt_only_partitions),
+            BTreeSet::from([
+                ("2024-01-01".to_string(), 10),
+                ("2024-01-01".to_string(), 20),
+            ]),
+        );
+        assert_eq!(
+            extract_dt_hr_partition_set(&dt_hr_partitions),
+            BTreeSet::from([("2024-01-01".to_string(), 10)]),
+        );
+    }
 }
diff --git a/crates/integrations/datafusion/tests/read_tables.rs 
b/crates/integrations/datafusion/tests/read_tables.rs
index be43ded..68ed58f 100644
--- a/crates/integrations/datafusion/tests/read_tables.rs
+++ b/crates/integrations/datafusion/tests/read_tables.rs
@@ -19,7 +19,9 @@ use std::sync::Arc;
 
 use datafusion::arrow::array::{Int32Array, StringArray};
 use datafusion::catalog::CatalogProvider;
-use datafusion::prelude::SessionContext;
+use datafusion::datasource::TableProvider;
+use datafusion::logical_expr::{col, lit, TableProviderFilterPushDown};
+use datafusion::prelude::{SessionConfig, SessionContext};
 use paimon::catalog::Identifier;
 use paimon::{Catalog, CatalogOptions, FileSystemCatalog, Options};
 use paimon_datafusion::{PaimonCatalogProvider, PaimonTableProvider};
@@ -36,6 +38,15 @@ fn create_catalog() -> FileSystemCatalog {
 }
 
 async fn create_context(table_name: &str) -> SessionContext {
+    let provider = create_provider(table_name).await;
+    let ctx = SessionContext::new();
+    ctx.register_table(table_name, Arc::new(provider))
+        .expect("Failed to register table");
+
+    ctx
+}
+
+async fn create_provider(table_name: &str) -> PaimonTableProvider {
     let catalog = create_catalog();
     let identifier = Identifier::new("default", table_name);
     let table = catalog
@@ -43,12 +54,7 @@ async fn create_context(table_name: &str) -> SessionContext {
         .await
         .expect("Failed to get table");
 
-    let provider = PaimonTableProvider::try_new(table).expect("Failed to 
create table provider");
-    let ctx = SessionContext::new();
-    ctx.register_table(table_name, Arc::new(provider))
-        .expect("Failed to register table");
-
-    ctx
+    PaimonTableProvider::try_new(table).expect("Failed to create table 
provider")
 }
 
 async fn read_rows(table_name: &str) -> Vec<(i32, String)> {
@@ -61,8 +67,25 @@ async fn read_rows(table_name: &str) -> Vec<(i32, String)> {
         "Expected at least one batch from table {table_name}"
     );
 
-    let mut actual_rows = Vec::new();
-    for batch in &batches {
+    let mut actual_rows = extract_id_name_rows(&batches);
+    actual_rows.sort_by_key(|(id, _)| *id);
+    actual_rows
+}
+
+async fn collect_query(
+    table_name: &str,
+    sql: &str,
+) -> 
datafusion::error::Result<Vec<datafusion::arrow::record_batch::RecordBatch>> {
+    let ctx = create_context(table_name).await;
+
+    ctx.sql(sql).await?.collect().await
+}
+
+fn extract_id_name_rows(
+    batches: &[datafusion::arrow::record_batch::RecordBatch],
+) -> Vec<(i32, String)> {
+    let mut rows = Vec::new();
+    for batch in batches {
         let id_array = batch
             .column_by_name("id")
             .and_then(|column| column.as_any().downcast_ref::<Int32Array>())
@@ -73,24 +96,13 @@ async fn read_rows(table_name: &str) -> Vec<(i32, String)> {
             .expect("Expected StringArray for name column");
 
         for row_index in 0..batch.num_rows() {
-            actual_rows.push((
+            rows.push((
                 id_array.value(row_index),
                 name_array.value(row_index).to_string(),
             ));
         }
     }
-
-    actual_rows.sort_by_key(|(id, _)| *id);
-    actual_rows
-}
-
-async fn collect_query(
-    table_name: &str,
-    sql: &str,
-) -> datafusion::error::Result<Vec<datafusion::arrow::record_batch::RecordBatch>> {
-    let ctx = create_context(table_name).await;
-
-    ctx.sql(sql).await?.collect().await
+    rows
 }
 
 #[tokio::test]
@@ -164,22 +176,33 @@ async fn test_projection_via_datafusion() {
     );
 }
 
+#[tokio::test]
+async fn test_supports_partition_filters_pushdown() {
+    let provider = create_provider("multi_partitioned_log_table").await;
+    let partition_filter = col("dt").eq(lit("2024-01-01"));
+    let mixed_and_filter = col("dt").eq(lit("2024-01-01")).and(col("id").gt(lit(1)));
+    let data_filter = col("id").gt(lit(1));
+
+    let supports = provider
+        .supports_filters_pushdown(&[&partition_filter, &mixed_and_filter, &data_filter])
+        .expect("supports_filters_pushdown should succeed");
+
+    assert_eq!(
+        supports,
+        vec![
+            TableProviderFilterPushDown::Exact,
+            TableProviderFilterPushDown::Inexact,
+            TableProviderFilterPushDown::Inexact,
+        ]
+    );
+}
+
 /// Verifies that `PaimonTableProvider::scan()` produces more than one
 /// execution partition for a multi-partition table, and that the reported
 /// partition count is still capped by `target_partitions`.
 #[tokio::test]
 async fn test_scan_partition_count_respects_session_config() {
-    use datafusion::datasource::TableProvider;
-    use datafusion::prelude::SessionConfig;
-
-    let catalog = create_catalog();
-    let identifier = Identifier::new("default", "partitioned_log_table");
-    let table = catalog
-        .get_table(&identifier)
-        .await
-        .expect("Failed to get table");
-
-    let provider = PaimonTableProvider::try_new(table).expect("Failed to create table provider");
+    let provider = create_provider("partitioned_log_table").await;
 
     // With generous target_partitions, the plan should expose more than one partition.
     let config = SessionConfig::new().with_target_partitions(8);
@@ -215,6 +238,54 @@ async fn test_scan_partition_count_respects_session_config() {
     );
 }
 
+#[tokio::test]
+async fn test_partition_filter_query_via_datafusion() {
+    let batches = collect_query(
+        "partitioned_log_table",
+        "SELECT id, name FROM partitioned_log_table WHERE dt = '2024-01-01'",
+    )
+    .await
+    .expect("Partition filter query should succeed");
+
+    let mut actual_rows = extract_id_name_rows(&batches);
+    actual_rows.sort_by_key(|(id, _)| *id);
+    assert_eq!(
+        actual_rows,
+        vec![(1, "alice".to_string()), (2, "bob".to_string())]
+    );
+}
+
+#[tokio::test]
+async fn test_multi_partition_filter_query_via_datafusion() {
+    let batches = collect_query(
+        "multi_partitioned_log_table",
+        "SELECT id, name FROM multi_partitioned_log_table WHERE dt = '2024-01-01' AND hr = 10",
+    )
+    .await
+    .expect("Multi-partition filter query should succeed");
+
+    let mut actual_rows = extract_id_name_rows(&batches);
+    actual_rows.sort_by_key(|(id, _)| *id);
+    assert_eq!(
+        actual_rows,
+        vec![(1, "alice".to_string()), (2, "bob".to_string())]
+    );
+}
+
+#[tokio::test]
+async fn test_mixed_and_filter_keeps_residual_datafusion_filter() {
+    let batches = collect_query(
+        "partitioned_log_table",
+        "SELECT id, name FROM partitioned_log_table WHERE dt = '2024-01-01' AND id > 1",
+    )
+    .await
+    .expect("Mixed filter query should succeed");
+
+    let actual_rows = extract_id_name_rows(&batches);
+
+    assert_eq!(actual_rows, vec![(2, "bob".to_string())]);
+}
+
 // ======================= Catalog Provider Tests =======================
 #[tokio::test]
 async fn test_query_via_catalog_provider() {


Reply via email to