This is an automated email from the ASF dual-hosted git repository.
lzljs3620320 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/paimon-rust.git
The following commit(s) were added to refs/heads/main by this push:
new 10e190b feat(datafusion): support partition predicate pushdown (#190)
10e190b is described below
commit 10e190b10a0c8c588aeaccea86eb4a24a538253a
Author: Zach <[email protected]>
AuthorDate: Thu Apr 2 21:54:59 2026 +0800
feat(datafusion): support partition predicate pushdown (#190)
---
.../integrations/datafusion/src/filter_pushdown.rs | 536 +++++++++++++++++++++
crates/integrations/datafusion/src/lib.rs | 4 +-
.../datafusion/src/physical_plan/scan.rs | 5 +
crates/integrations/datafusion/src/table/mod.rs | 164 ++++++-
.../integrations/datafusion/tests/read_tables.rs | 137 ++++--
5 files changed, 807 insertions(+), 39 deletions(-)
diff --git a/crates/integrations/datafusion/src/filter_pushdown.rs
b/crates/integrations/datafusion/src/filter_pushdown.rs
new file mode 100644
index 0000000..91c65d3
--- /dev/null
+++ b/crates/integrations/datafusion/src/filter_pushdown.rs
@@ -0,0 +1,536 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use datafusion::common::{Column, ScalarValue};
+use datafusion::logical_expr::expr::InList;
+use datafusion::logical_expr::{Between, BinaryExpr, Expr, Operator,
TableProviderFilterPushDown};
+use paimon::spec::{DataField, DataType, Datum, Predicate, PredicateBuilder};
+
+/// Decides how precisely the provider can honour `filter` when it is pushed down.
+///
+/// * `Exact` - the whole expression translates to a Paimon predicate while
+///   resolving only columns listed in `partition_keys`.
+/// * `Inexact` - the whole expression translates but touches other columns, or
+///   only some of its top-level `AND` conjuncts translate.
+/// * `Unsupported` - no part of the expression translates.
+pub(crate) fn classify_filter_pushdown(
+    filter: &Expr,
+    fields: &[DataField],
+    partition_keys: &[String],
+) -> TableProviderFilterPushDown {
+    let any_column = FilterTranslator::new(fields);
+
+    // Guard: the expression as a whole is not translatable. It can still be
+    // pushed inexactly if at least one top-level conjunct is.
+    if any_column.translate(filter).is_none() {
+        let partially_translatable = split_conjunction(filter)
+            .into_iter()
+            .any(|conjunct| any_column.translate(conjunct).is_some());
+
+        return if partially_translatable {
+            TableProviderFilterPushDown::Inexact
+        } else {
+            TableProviderFilterPushDown::Unsupported
+        };
+    }
+
+    // Fully translatable: exact only if it also translates when column
+    // resolution is restricted to the partition keys.
+    let partition_only = FilterTranslator::for_allowed_columns(fields, partition_keys);
+    match partition_only.translate(filter) {
+        Some(_) => TableProviderFilterPushDown::Exact,
+        None => TableProviderFilterPushDown::Inexact,
+    }
+}
+
+/// Builds the single Paimon predicate that is pushed into the scan for `filters`.
+///
+/// Every filter is split into its top-level `AND` conjuncts. Conjuncts that
+/// cannot be translated are dropped and the remaining ones are combined with
+/// `Predicate::and`. Returns `None` when nothing could be translated.
+pub(crate) fn build_pushed_predicate(filters: &[Expr], fields: &[DataField]) -> Option<Predicate> {
+    let translator = FilterTranslator::new(fields);
+    let translated: Vec<Predicate> = filters
+        .iter()
+        .flat_map(split_conjunction)
+        .filter_map(|conjunct| translator.translate(conjunct))
+        .collect();
+
+    (!translated.is_empty()).then(|| Predicate::and(translated))
+}
+
+/// Flattens nested `AND` expressions into their individual conjuncts.
+///
+/// Conjuncts are returned in left-to-right order; an expression that is not an
+/// `AND` is returned as the only element.
+fn split_conjunction(expr: &Expr) -> Vec<&Expr> {
+    let mut conjuncts = Vec::new();
+    // Explicit work stack instead of recursion. The right operand is pushed
+    // first so that the left operand is popped (and emitted) first.
+    let mut pending = vec![expr];
+
+    while let Some(current) = pending.pop() {
+        if let Expr::BinaryExpr(BinaryExpr {
+            left,
+            op: Operator::And,
+            right,
+        }) = current
+        {
+            pending.push(right.as_ref());
+            pending.push(left.as_ref());
+        } else {
+            conjuncts.push(current);
+        }
+    }
+
+    conjuncts
+}
+
+/// Translates DataFusion filter expressions into Paimon predicates.
+///
+/// A translator built with `for_allowed_columns` refuses to resolve any column
+/// outside the allowed list, so an expression touching such a column fails to
+/// translate.
+struct FilterTranslator<'a> {
+    /// Table fields used to look up a column by name and read its data type.
+    fields: &'a [DataField],
+    /// When `Some`, the only column names that may be resolved.
+    allowed_columns: Option<&'a [String]>,
+    /// Builder for the leaf predicates, constructed from `fields`.
+    predicate_builder: PredicateBuilder,
+}
+
+impl<'a> FilterTranslator<'a> {
+    /// Creates a translator that may resolve every field in `fields`.
+    fn new(fields: &'a [DataField]) -> Self {
+        Self::with_column_filter(fields, None)
+    }
+
+    /// Creates a translator restricted to the columns named in `allowed_columns`.
+    fn for_allowed_columns(fields: &'a [DataField], allowed_columns: &'a [String]) -> Self {
+        Self::with_column_filter(fields, Some(allowed_columns))
+    }
+
+    /// Shared constructor behind `new` and `for_allowed_columns`.
+    fn with_column_filter(
+        fields: &'a [DataField],
+        allowed_columns: Option<&'a [String]>,
+    ) -> Self {
+        Self {
+            fields,
+            allowed_columns,
+            predicate_builder: PredicateBuilder::new(fields),
+        }
+    }
+
+    /// Translates `expr`, returning `None` for anything that is not supported.
+    fn translate(&self, expr: &Expr) -> Option<Predicate> {
+        match expr {
+            Expr::BinaryExpr(binary) => self.translate_binary(binary),
+            // NOT is intentionally not translated: Predicate::Not uses two-valued
+            // logic (!bool), which incorrectly returns true when the inner predicate
+            // evaluates NULL to false. Combined with Exact pushdown precision,
+            // DataFusion would remove the residual filter, producing wrong results.
+            Expr::Not(_) => None,
+            Expr::IsNull(inner) => {
+                let field = self.resolve_field(inner.as_ref())?;
+                self.predicate_builder.is_null(field.name()).ok()
+            }
+            Expr::IsNotNull(inner) => {
+                let field = self.resolve_field(inner.as_ref())?;
+                self.predicate_builder.is_not_null(field.name()).ok()
+            }
+            Expr::InList(in_list) => self.translate_in_list(in_list),
+            Expr::Between(between) => self.translate_between(between),
+            _ => None,
+        }
+    }
+
+    /// Translates `AND`, `OR` and the six comparison operators.
+    ///
+    /// `AND`/`OR` translate only when both operands translate.
+    fn translate_binary(&self, binary: &BinaryExpr) -> Option<Predicate> {
+        match binary.op {
+            Operator::And | Operator::Or => {
+                let operands = vec![
+                    self.translate(binary.left.as_ref())?,
+                    self.translate(binary.right.as_ref())?,
+                ];
+                if matches!(binary.op, Operator::And) {
+                    Some(Predicate::and(operands))
+                } else {
+                    Some(Predicate::or(operands))
+                }
+            }
+            Operator::Eq
+            | Operator::NotEq
+            | Operator::Lt
+            | Operator::LtEq
+            | Operator::Gt
+            | Operator::GtEq => self.translate_comparison(binary),
+            _ => None,
+        }
+    }
+
+    /// Translates a comparison written either as `column <op> literal` or as
+    /// `literal <op> column`; the latter is handled by mirroring the operator.
+    fn translate_comparison(&self, binary: &BinaryExpr) -> Option<Predicate> {
+        let (left, right) = (binary.left.as_ref(), binary.right.as_ref());
+
+        self.translate_column_literal_comparison(left, binary.op, right)
+            .or_else(|| {
+                let mirrored = reverse_comparison_operator(binary.op)?;
+                self.translate_column_literal_comparison(right, mirrored, left)
+            })
+    }
+
+    /// Builds the leaf predicate for `column_expr <op> literal_expr`.
+    ///
+    /// Returns `None` when `column_expr` is not a resolvable column, when
+    /// `literal_expr` is not a non-NULL literal convertible to the column's
+    /// type, when `op` is not a comparison, or when the builder reports an error.
+    fn translate_column_literal_comparison(
+        &self,
+        column_expr: &Expr,
+        op: Operator,
+        literal_expr: &Expr,
+    ) -> Option<Predicate> {
+        let field = self.resolve_field(column_expr)?;
+        let datum = scalar_to_datum(extract_scalar_literal(literal_expr)?, field.data_type())?;
+        let (builder, name) = (&self.predicate_builder, field.name());
+
+        match op {
+            Operator::Eq => builder.equal(name, datum).ok(),
+            Operator::NotEq => builder.not_equal(name, datum).ok(),
+            Operator::Lt => builder.less_than(name, datum).ok(),
+            Operator::LtEq => builder.less_or_equal(name, datum).ok(),
+            Operator::Gt => builder.greater_than(name, datum).ok(),
+            Operator::GtEq => builder.greater_or_equal(name, datum).ok(),
+            _ => None,
+        }
+    }
+
+    /// Translates `column [NOT] IN (literal, ...)`.
+    fn translate_in_list(&self, in_list: &InList) -> Option<Predicate> {
+        let field = self.resolve_field(in_list.expr.as_ref())?;
+        // A single element that is not a convertible, non-NULL literal makes
+        // the whole list untranslatable.
+        let values = in_list
+            .list
+            .iter()
+            .map(|item| scalar_to_datum(extract_scalar_literal(item)?, field.data_type()))
+            .collect::<Option<Vec<_>>>()?;
+
+        let built = if in_list.negated {
+            self.predicate_builder.is_not_in(field.name(), values)
+        } else {
+            self.predicate_builder.is_in(field.name(), values)
+        };
+        built.ok()
+    }
+
+    /// Translates `column BETWEEN low AND high` into `column >= low AND column <= high`.
+    fn translate_between(&self, between: &Between) -> Option<Predicate> {
+        // Same concern as Expr::Not: negation would wrap in Predicate::Not,
+        // which has incorrect NULL semantics for Exact pushdown. Reject it
+        // before building any bound.
+        if between.negated {
+            return None;
+        }
+
+        let field = self.resolve_field(between.expr.as_ref())?;
+        let low = scalar_to_datum(
+            extract_scalar_literal(between.low.as_ref())?,
+            field.data_type(),
+        )?;
+        let high = scalar_to_datum(
+            extract_scalar_literal(between.high.as_ref())?,
+            field.data_type(),
+        )?;
+
+        let lower_bound = self
+            .predicate_builder
+            .greater_or_equal(field.name(), low)
+            .ok()?;
+        let upper_bound = self
+            .predicate_builder
+            .less_or_equal(field.name(), high)
+            .ok()?;
+
+        Some(Predicate::and(vec![lower_bound, upper_bound]))
+    }
+
+    /// Looks up the field referenced by `expr`.
+    ///
+    /// Returns `None` when `expr` is not a plain column, when the column is
+    /// excluded by `allowed_columns`, or when no field has that name. The
+    /// column's relation qualifier is ignored.
+    fn resolve_field(&self, expr: &Expr) -> Option<&'a DataField> {
+        let Expr::Column(Column { name, .. }) = expr else {
+            return None;
+        };
+
+        if let Some(allowed) = self.allowed_columns {
+            if !allowed.contains(name) {
+                return None;
+            }
+        }
+
+        self.fields.iter().find(|field| field.name() == name)
+    }
+}
+
+/// Returns the scalar behind a literal expression, or `None` when `expr` is
+/// not a literal or is a NULL literal.
+fn extract_scalar_literal(expr: &Expr) -> Option<&ScalarValue> {
+    let Expr::Literal(scalar, _) = expr else {
+        return None;
+    };
+
+    (!scalar.is_null()).then_some(scalar)
+}
+
+/// Returns the operator to use after swapping the two operands of a comparison.
+///
+/// `=` and `!=` are symmetric and map to themselves; the ordering operators are
+/// mirrored (`<` becomes `>`, `<=` becomes `>=`, and vice versa). Any other
+/// operator yields `None`.
+fn reverse_comparison_operator(op: Operator) -> Option<Operator> {
+    let mirrored = match op {
+        Operator::Eq | Operator::NotEq => op,
+        Operator::Lt => Operator::Gt,
+        Operator::LtEq => Operator::GtEq,
+        Operator::Gt => Operator::Lt,
+        Operator::GtEq => Operator::LtEq,
+        _ => return None,
+    };
+
+    Some(mirrored)
+}
+
+/// Converts a DataFusion literal into the Paimon `Datum` matching `data_type`.
+///
+/// Returns `None` when the literal's variant is not accepted for the column
+/// type, when an integer literal does not fit the column's integer width, when
+/// a decimal literal has a larger precision or a different scale than the
+/// column, or when the column type is not handled here.
+fn scalar_to_datum(scalar: &ScalarValue, data_type: &DataType) -> Option<Datum> {
+    match data_type {
+        DataType::Boolean(_) => match scalar {
+            ScalarValue::Boolean(Some(value)) => Some(Datum::Bool(*value)),
+            _ => None,
+        },
+        // Integer literals of any width are accepted as long as the value fits.
+        DataType::TinyInt(_) => narrow_integer::<i8>(scalar).map(Datum::TinyInt),
+        DataType::SmallInt(_) => narrow_integer::<i16>(scalar).map(Datum::SmallInt),
+        DataType::Int(_) => narrow_integer::<i32>(scalar).map(Datum::Int),
+        DataType::BigInt(_) => narrow_integer::<i64>(scalar).map(Datum::Long),
+        DataType::Float(_) => match scalar {
+            ScalarValue::Float32(Some(value)) => Some(Datum::Float(*value)),
+            _ => None,
+        },
+        DataType::Double(_) => match scalar {
+            ScalarValue::Float64(Some(value)) => Some(Datum::Double(*value)),
+            // Widening f32 to f64 is lossless.
+            ScalarValue::Float32(Some(value)) => Some(Datum::Double(f64::from(*value))),
+            _ => None,
+        },
+        DataType::Char(_) | DataType::VarChar(_) => match scalar {
+            ScalarValue::Utf8(Some(value))
+            | ScalarValue::Utf8View(Some(value))
+            | ScalarValue::LargeUtf8(Some(value)) => Some(Datum::String(value.clone())),
+            _ => None,
+        },
+        DataType::Date(_) => match scalar {
+            ScalarValue::Date32(Some(value)) => Some(Datum::Date(*value)),
+            _ => None,
+        },
+        DataType::Decimal(decimal) => {
+            let ScalarValue::Decimal128(Some(unscaled), precision, scale) = scalar else {
+                return None;
+            };
+            // A negative literal scale fails this conversion and is rejected.
+            let literal_scale = u32::try_from(*scale).ok()?;
+            if u32::from(*precision) > decimal.precision() || literal_scale != decimal.scale() {
+                return None;
+            }
+
+            // The datum carries the column's precision and scale, not the literal's.
+            Some(Datum::Decimal {
+                unscaled: *unscaled,
+                precision: decimal.precision(),
+                scale: decimal.scale(),
+            })
+        }
+        DataType::Binary(_) | DataType::VarBinary(_) => match scalar {
+            ScalarValue::Binary(Some(value))
+            | ScalarValue::BinaryView(Some(value))
+            | ScalarValue::LargeBinary(Some(value)) => Some(Datum::Bytes(value.clone())),
+            ScalarValue::FixedSizeBinary(_, Some(value)) => Some(Datum::Bytes(value.clone())),
+            _ => None,
+        },
+        _ => None,
+    }
+}
+
+/// Reads an integer literal and narrows it to `T`, or `None` if it is not an
+/// integer literal or the value does not fit in `T`.
+fn narrow_integer<T: TryFrom<i128>>(scalar: &ScalarValue) -> Option<T> {
+    scalar_to_i128(scalar).and_then(|wide| T::try_from(wide).ok())
+}
+
+/// Widens any non-NULL signed or unsigned integer literal (8 to 64 bits) to
+/// `i128`; every other scalar yields `None`.
+fn scalar_to_i128(scalar: &ScalarValue) -> Option<i128> {
+    let widened = match scalar {
+        ScalarValue::Int8(Some(signed)) => i128::from(*signed),
+        ScalarValue::Int16(Some(signed)) => i128::from(*signed),
+        ScalarValue::Int32(Some(signed)) => i128::from(*signed),
+        ScalarValue::Int64(Some(signed)) => i128::from(*signed),
+        ScalarValue::UInt8(Some(unsigned)) => i128::from(*unsigned),
+        ScalarValue::UInt16(Some(unsigned)) => i128::from(*unsigned),
+        ScalarValue::UInt32(Some(unsigned)) => i128::from(*unsigned),
+        ScalarValue::UInt64(Some(unsigned)) => i128::from(*unsigned),
+        _ => return None,
+    };
+
+    Some(widened)
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use datafusion::common::Column;
+    use datafusion::logical_expr::{expr::InList, lit, TableProviderFilterPushDown};
+    use paimon::spec::{IntType, VarCharType};
+
+    /// Schema used by every test: `id` (data column), `dt` and `hr` (partition columns).
+    fn test_fields() -> Vec<DataField> {
+        vec![
+            DataField::new(0, "id".to_string(), DataType::Int(IntType::new())),
+            DataField::new(
+                1,
+                "dt".to_string(),
+                DataType::VarChar(VarCharType::string_type()),
+            ),
+            DataField::new(2, "hr".to_string(), DataType::Int(IntType::new())),
+        ]
+    }
+
+    fn partition_keys() -> Vec<String> {
+        vec!["dt".to_string(), "hr".to_string()]
+    }
+
+    /// Shorthand for an unqualified column reference.
+    fn column(name: &str) -> Expr {
+        Expr::Column(Column::from_name(name))
+    }
+
+    #[test]
+    fn test_translate_partition_equality_filter() {
+        let fields = test_fields();
+        let filter = column("dt").eq(lit("2024-01-01"));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("partition filter should translate");
+
+        assert_eq!(predicate.to_string(), "dt = '2024-01-01'");
+    }
+
+    #[test]
+    fn test_classify_partition_filter_as_exact() {
+        let fields = test_fields();
+        let filter = column("dt").eq(lit("2024-01-01"));
+
+        assert_eq!(
+            classify_filter_pushdown(&filter, &fields, &partition_keys()),
+            TableProviderFilterPushDown::Exact
+        );
+    }
+
+    #[test]
+    fn test_translate_reversed_partition_comparison() {
+        let fields = test_fields();
+        let filter = lit(10).lt(column("hr"));
+
+        let predicate = build_pushed_predicate(&[filter], &fields)
+            .expect("reversed comparison should translate");
+
+        assert_eq!(predicate.to_string(), "hr > 10");
+    }
+
+    #[test]
+    fn test_translate_partition_in_list() {
+        let fields = test_fields();
+        let filter = Expr::InList(InList::new(
+            Box::new(column("dt")),
+            vec![lit("2024-01-01"), lit("2024-01-02")],
+            false,
+        ));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("in-list filter should translate");
+
+        assert_eq!(predicate.to_string(), "dt IN ('2024-01-01', '2024-01-02')");
+    }
+
+    #[test]
+    fn test_translate_mixed_or_filter() {
+        let fields = test_fields();
+        let filter = column("dt")
+            .eq(lit("2024-01-01"))
+            .or(column("id").gt(lit(10)));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("mixed OR filter should translate");
+
+        assert_eq!(predicate.to_string(), "(dt = '2024-01-01' OR id > 10)");
+    }
+
+    #[test]
+    fn test_translate_non_partition_filter() {
+        let fields = test_fields();
+        let filter = column("id").gt(lit(10));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("data filter should translate");
+
+        assert_eq!(predicate.to_string(), "id > 10");
+    }
+
+    #[test]
+    fn test_classify_non_partition_filter_as_inexact() {
+        let fields = test_fields();
+        let filter = column("id").gt(lit(10));
+
+        assert_eq!(
+            classify_filter_pushdown(&filter, &fields, &partition_keys()),
+            TableProviderFilterPushDown::Inexact
+        );
+    }
+
+    #[test]
+    fn test_translate_mixed_and_filter() {
+        let fields = test_fields();
+        let filter = column("dt")
+            .eq(lit("2024-01-01"))
+            .and(column("id").gt(lit(10)));
+
+        let predicate =
+            build_pushed_predicate(&[filter], &fields).expect("mixed filter should translate");
+
+        assert_eq!(predicate.to_string(), "(dt = '2024-01-01' AND id > 10)");
+    }
+
+    #[test]
+    fn test_classify_mixed_and_filter_as_inexact() {
+        let fields = test_fields();
+        let filter = column("dt")
+            .eq(lit("2024-01-01"))
+            .and(column("id").gt(lit(10)));
+
+        assert_eq!(
+            classify_filter_pushdown(&filter, &fields, &partition_keys()),
+            TableProviderFilterPushDown::Inexact
+        );
+    }
+
+    #[test]
+    fn test_translate_not_is_not_supported() {
+        let fields = test_fields();
+        let filter = Expr::Not(Box::new(column("dt").eq(lit("2024-01-01"))));
+
+        assert!(
+            build_pushed_predicate(&[filter], &fields).is_none(),
+            "NOT expressions should not translate due to NULL semantics"
+        );
+    }
+
+    #[test]
+    fn test_classify_not_filter_as_unsupported() {
+        let fields = test_fields();
+        let filter = Expr::Not(Box::new(column("dt").eq(lit("2024-01-01"))));
+
+        assert_eq!(
+            classify_filter_pushdown(&filter, &fields, &partition_keys()),
+            TableProviderFilterPushDown::Unsupported
+        );
+    }
+
+    #[test]
+    fn test_translate_negated_between_is_not_supported() {
+        let fields = test_fields();
+        let filter = Expr::Between(Between::new(
+            Box::new(column("hr")),
+            true, // negated
+            Box::new(lit(1)),
+            Box::new(lit(20)),
+        ));
+
+        assert!(
+            build_pushed_predicate(&[filter], &fields).is_none(),
+            "Negated BETWEEN should not translate due to NULL semantics"
+        );
+    }
+
+    #[test]
+    fn test_translate_boolean_literal_is_not_supported() {
+        let fields = test_fields();
+
+        for value in [true, false] {
+            let filter = Expr::Literal(ScalarValue::Boolean(Some(value)), None);
+            assert!(
+                build_pushed_predicate(&[filter], &fields).is_none(),
+                "Boolean literal ({value}) is not a partition predicate and must not be translated"
+            );
+        }
+    }
+}
diff --git a/crates/integrations/datafusion/src/lib.rs
b/crates/integrations/datafusion/src/lib.rs
index e9f8955..40639e2 100644
--- a/crates/integrations/datafusion/src/lib.rs
+++ b/crates/integrations/datafusion/src/lib.rs
@@ -33,10 +33,12 @@
//! let df = ctx.sql("SELECT * FROM my_table").await?;
//! ```
//!
-//! This version does not support write or predicate pushdown.
+//! This version supports partition predicate pushdown by extracting
+//! translatable partition-only conjuncts from DataFusion filters.
mod catalog;
mod error;
+mod filter_pushdown;
mod physical_plan;
mod schema;
mod table;
diff --git a/crates/integrations/datafusion/src/physical_plan/scan.rs
b/crates/integrations/datafusion/src/physical_plan/scan.rs
index dd27612..52f5357 100644
--- a/crates/integrations/datafusion/src/physical_plan/scan.rs
+++ b/crates/integrations/datafusion/src/physical_plan/scan.rs
@@ -72,6 +72,11 @@ impl PaimonTableScan {
pub fn table(&self) -> &Table {
&self.table
}
+
+ #[cfg(test)]
+ pub(crate) fn planned_partitions(&self) -> &[Arc<[DataSplit]>] {
+ &self.planned_partitions
+ }
}
impl ExecutionPlan for PaimonTableScan {
diff --git a/crates/integrations/datafusion/src/table/mod.rs
b/crates/integrations/datafusion/src/table/mod.rs
index 2e0a49e..6ea4803 100644
--- a/crates/integrations/datafusion/src/table/mod.rs
+++ b/crates/integrations/datafusion/src/table/mod.rs
@@ -25,18 +25,19 @@ use datafusion::arrow::datatypes::{Field, Schema, SchemaRef
as ArrowSchemaRef};
use datafusion::catalog::Session;
use datafusion::datasource::{TableProvider, TableType};
use datafusion::error::Result as DFResult;
-use datafusion::logical_expr::Expr;
+use datafusion::logical_expr::{Expr, TableProviderFilterPushDown};
use datafusion::physical_plan::ExecutionPlan;
use paimon::table::Table;
use crate::error::to_datafusion_error;
+use crate::filter_pushdown::{build_pushed_predicate, classify_filter_pushdown};
use crate::physical_plan::PaimonTableScan;
use crate::schema::paimon_schema_to_arrow;
/// Read-only table provider for a Paimon table.
///
-/// Supports full table scan and column projection. Predicate pushdown and writes
-/// are not yet supported.
+/// Supports full table scan, column projection, and partition predicate pushdown.
+/// Data-level filtering remains a residual DataFusion filter.
#[derive(Debug, Clone)]
pub struct PaimonTableProvider {
table: Table,
@@ -81,11 +82,24 @@ impl TableProvider for PaimonTableProvider {
TableType::Base
}
+    /// Reports, for each filter, how precisely it can be pushed into the scan.
+    ///
+    /// Every filter is classified by `classify_filter_pushdown` against the
+    /// table schema's fields and partition keys; the result has one entry per
+    /// input filter, in the same order.
+    fn supports_filters_pushdown(
+        &self,
+        filters: &[&Expr],
+    ) -> DFResult<Vec<TableProviderFilterPushDown>> {
+        let schema = self.table.schema();
+        let fields = schema.fields();
+        let partition_keys = schema.partition_keys();
+
+        let pushdown: Vec<TableProviderFilterPushDown> = filters
+            .iter()
+            .map(|filter| classify_filter_pushdown(filter, fields, partition_keys))
+            .collect();
+
+        Ok(pushdown)
+    }
+
async fn scan(
&self,
state: &dyn Session,
projection: Option<&Vec<usize>>,
- _filters: &[Expr],
+ filters: &[Expr],
_limit: Option<usize>,
) -> DFResult<Arc<dyn ExecutionPlan>> {
// Convert projection indices to column names and compute projected schema
@@ -101,7 +115,10 @@ impl TableProvider for PaimonTableProvider {
};
// Plan splits eagerly so we know partition count upfront.
- let read_builder = self.table.new_read_builder();
+ let mut read_builder = self.table.new_read_builder();
+ if let Some(filter) = build_pushed_predicate(filters,
self.table.schema().fields()) {
+ read_builder.with_filter(filter);
+ }
let scan = read_builder.new_scan();
let plan = scan.plan().await.map_err(to_datafusion_error)?;
@@ -133,6 +150,16 @@ impl TableProvider for PaimonTableProvider {
#[cfg(test)]
mod tests {
use super::*;
+ use std::collections::BTreeSet;
+ use std::sync::Arc;
+
+ use datafusion::datasource::TableProvider;
+ use datafusion::logical_expr::{col, lit, Expr};
+ use datafusion::prelude::{SessionConfig, SessionContext};
+ use paimon::catalog::Identifier;
+ use paimon::{Catalog, CatalogOptions, DataSplit, FileSystemCatalog,
Options};
+
+ use crate::physical_plan::PaimonTableScan;
#[test]
fn test_bucket_round_robin_distributes_evenly() {
@@ -151,4 +178,131 @@ mod tests {
let result = bucket_round_robin(vec![1, 2, 3], 1);
assert_eq!(result, vec![vec![1, 2, 3]]);
}
+
+    /// Warehouse location for these tests: the value of `PAIMON_TEST_WAREHOUSE`
+    /// when it can be read, otherwise `/tmp/paimon-warehouse`.
+    fn get_test_warehouse() -> String {
+        match std::env::var("PAIMON_TEST_WAREHOUSE") {
+            Ok(path) => path,
+            Err(_) => "/tmp/paimon-warehouse".to_string(),
+        }
+    }
+
+    /// Opens a filesystem catalog rooted at the test warehouse.
+    fn create_catalog() -> FileSystemCatalog {
+        let mut catalog_options = Options::new();
+        catalog_options.set(CatalogOptions::WAREHOUSE, get_test_warehouse());
+
+        FileSystemCatalog::new(catalog_options).expect("Failed to create catalog")
+    }
+
+    /// Loads `default.<table_name>` from the test catalog and wraps it in a
+    /// `PaimonTableProvider`. Panics if the table cannot be loaded or wrapped.
+    async fn create_provider(table_name: &str) -> PaimonTableProvider {
+        let identifier = Identifier::new("default", table_name);
+        let table = create_catalog()
+            .get_table(&identifier)
+            .await
+            .expect("Failed to get table");
+
+        PaimonTableProvider::try_new(table).expect("Failed to create table provider")
+    }
+
+    /// Runs `provider.scan()` with `filters` (no projection, no limit) under a
+    /// session configured for 8 target partitions and returns the splits the
+    /// resulting `PaimonTableScan` planned.
+    async fn plan_partitions(
+        provider: &PaimonTableProvider,
+        filters: Vec<Expr>,
+    ) -> Vec<Arc<[DataSplit]>> {
+        let session_config = SessionConfig::new().with_target_partitions(8);
+        let state = SessionContext::new_with_config(session_config).state();
+
+        let physical_plan = provider
+            .scan(&state, None, &filters, None)
+            .await
+            .expect("scan() should succeed");
+        let paimon_scan = physical_plan
+            .as_any()
+            .downcast_ref::<PaimonTableScan>()
+            .expect("Expected PaimonTableScan");
+
+        paimon_scan.planned_partitions().to_vec()
+    }
+
+    /// Collects the distinct `dt` values (partition field 0) of every planned split.
+    fn extract_dt_partition_set(planned_partitions: &[Arc<[DataSplit]>]) -> BTreeSet<String> {
+        let mut dt_values = BTreeSet::new();
+        for splits in planned_partitions {
+            for split in splits.iter() {
+                let dt = split
+                    .partition()
+                    .get_string(0)
+                    .expect("Failed to decode dt")
+                    .to_string();
+                dt_values.insert(dt);
+            }
+        }
+        dt_values
+    }
+
+    /// Collects the distinct `(dt, hr)` pairs (partition fields 0 and 1) of
+    /// every planned split.
+    fn extract_dt_hr_partition_set(
+        planned_partitions: &[Arc<[DataSplit]>],
+    ) -> BTreeSet<(String, i32)> {
+        let mut pairs = BTreeSet::new();
+        for splits in planned_partitions {
+            for split in splits.iter() {
+                let partition = split.partition();
+                let dt = partition
+                    .get_string(0)
+                    .expect("Failed to decode dt")
+                    .to_string();
+                let hr = partition.get_int(1).expect("Failed to decode hr");
+                pairs.insert((dt, hr));
+            }
+        }
+        pairs
+    }
+
+    /// A `dt` equality filter must leave only splits of that `dt` partition.
+    #[tokio::test]
+    async fn test_scan_partition_filter_plans_matching_partition_set() {
+        let provider = create_provider("partitioned_log_table").await;
+        let filter = col("dt").eq(lit("2024-01-01"));
+
+        let planned_partitions = plan_partitions(&provider, vec![filter]).await;
+        let expected = BTreeSet::from(["2024-01-01".to_string()]);
+
+        assert_eq!(extract_dt_partition_set(&planned_partitions), expected);
+    }
+
+    /// Adding a data-column conjunct must not disable pruning on `dt`.
+    #[tokio::test]
+    async fn test_scan_mixed_and_filter_keeps_partition_pruning() {
+        let provider = create_provider("partitioned_log_table").await;
+        let mixed_filter = col("dt").eq(lit("2024-01-01")).and(col("id").gt(lit(1)));
+
+        let planned_partitions = plan_partitions(&provider, vec![mixed_filter]).await;
+        let expected = BTreeSet::from(["2024-01-01".to_string()]);
+
+        assert_eq!(extract_dt_partition_set(&planned_partitions), expected);
+    }
+
+    /// Filtering on `dt` alone keeps both `hr` partitions of that day; adding
+    /// `hr = 10` narrows the plan to the single matching partition.
+    #[tokio::test]
+    async fn test_scan_multi_partition_filter_plans_exact_partition_set() {
+        let provider = create_provider("multi_partitioned_log_table").await;
+
+        let dt_filter = col("dt").eq(lit("2024-01-01"));
+        let dt_only_partitions = plan_partitions(&provider, vec![dt_filter]).await;
+
+        let dt_hr_filter = col("dt").eq(lit("2024-01-01")).and(col("hr").eq(lit(10)));
+        let dt_hr_partitions = plan_partitions(&provider, vec![dt_hr_filter]).await;
+
+        assert_eq!(
+            extract_dt_hr_partition_set(&dt_only_partitions),
+            BTreeSet::from([
+                ("2024-01-01".to_string(), 10),
+                ("2024-01-01".to_string(), 20),
+            ]),
+        );
+        assert_eq!(
+            extract_dt_hr_partition_set(&dt_hr_partitions),
+            BTreeSet::from([("2024-01-01".to_string(), 10)]),
+        );
+    }
}
diff --git a/crates/integrations/datafusion/tests/read_tables.rs
b/crates/integrations/datafusion/tests/read_tables.rs
index be43ded..68ed58f 100644
--- a/crates/integrations/datafusion/tests/read_tables.rs
+++ b/crates/integrations/datafusion/tests/read_tables.rs
@@ -19,7 +19,9 @@ use std::sync::Arc;
use datafusion::arrow::array::{Int32Array, StringArray};
use datafusion::catalog::CatalogProvider;
-use datafusion::prelude::SessionContext;
+use datafusion::datasource::TableProvider;
+use datafusion::logical_expr::{col, lit, TableProviderFilterPushDown};
+use datafusion::prelude::{SessionConfig, SessionContext};
use paimon::catalog::Identifier;
use paimon::{Catalog, CatalogOptions, FileSystemCatalog, Options};
use paimon_datafusion::{PaimonCatalogProvider, PaimonTableProvider};
@@ -36,6 +38,15 @@ fn create_catalog() -> FileSystemCatalog {
}
async fn create_context(table_name: &str) -> SessionContext {
+ let provider = create_provider(table_name).await;
+ let ctx = SessionContext::new();
+ ctx.register_table(table_name, Arc::new(provider))
+ .expect("Failed to register table");
+
+ ctx
+}
+
+async fn create_provider(table_name: &str) -> PaimonTableProvider {
let catalog = create_catalog();
let identifier = Identifier::new("default", table_name);
let table = catalog
@@ -43,12 +54,7 @@ async fn create_context(table_name: &str) -> SessionContext {
.await
.expect("Failed to get table");
- let provider = PaimonTableProvider::try_new(table).expect("Failed to
create table provider");
- let ctx = SessionContext::new();
- ctx.register_table(table_name, Arc::new(provider))
- .expect("Failed to register table");
-
- ctx
+ PaimonTableProvider::try_new(table).expect("Failed to create table
provider")
}
async fn read_rows(table_name: &str) -> Vec<(i32, String)> {
@@ -61,8 +67,25 @@ async fn read_rows(table_name: &str) -> Vec<(i32, String)> {
"Expected at least one batch from table {table_name}"
);
- let mut actual_rows = Vec::new();
- for batch in &batches {
+ let mut actual_rows = extract_id_name_rows(&batches);
+ actual_rows.sort_by_key(|(id, _)| *id);
+ actual_rows
+}
+
+/// Registers `table_name` in a fresh session, runs `sql` against it and
+/// collects every result batch. Planning and execution errors are returned.
+async fn collect_query(
+    table_name: &str,
+    sql: &str,
+) -> datafusion::error::Result<Vec<datafusion::arrow::record_batch::RecordBatch>> {
+    let ctx = create_context(table_name).await;
+    let dataframe = ctx.sql(sql).await?;
+
+    dataframe.collect().await
+}
+
+fn extract_id_name_rows(
+ batches: &[datafusion::arrow::record_batch::RecordBatch],
+) -> Vec<(i32, String)> {
+ let mut rows = Vec::new();
+ for batch in batches {
let id_array = batch
.column_by_name("id")
.and_then(|column| column.as_any().downcast_ref::<Int32Array>())
@@ -73,24 +96,13 @@ async fn read_rows(table_name: &str) -> Vec<(i32, String)> {
.expect("Expected StringArray for name column");
for row_index in 0..batch.num_rows() {
- actual_rows.push((
+ rows.push((
id_array.value(row_index),
name_array.value(row_index).to_string(),
));
}
}
-
- actual_rows.sort_by_key(|(id, _)| *id);
- actual_rows
-}
-
-async fn collect_query(
- table_name: &str,
- sql: &str,
-) ->
datafusion::error::Result<Vec<datafusion::arrow::record_batch::RecordBatch>> {
- let ctx = create_context(table_name).await;
-
- ctx.sql(sql).await?.collect().await
+ rows
}
#[tokio::test]
@@ -164,22 +176,33 @@ async fn test_projection_via_datafusion() {
);
}
+/// A partition-only filter is reported as `Exact`; a filter that also (or only)
+/// references the data column `id` is reported as `Inexact`.
+#[tokio::test]
+async fn test_supports_partition_filters_pushdown() {
+    let provider = create_provider("multi_partitioned_log_table").await;
+
+    // One filter per expected classification, in the same order as `expected`.
+    let filters = [
+        col("dt").eq(lit("2024-01-01")),
+        col("dt").eq(lit("2024-01-01")).and(col("id").gt(lit(1))),
+        col("id").gt(lit(1)),
+    ];
+    let filter_refs: Vec<_> = filters.iter().collect();
+    let expected = vec![
+        TableProviderFilterPushDown::Exact,
+        TableProviderFilterPushDown::Inexact,
+        TableProviderFilterPushDown::Inexact,
+    ];
+
+    let supports = provider
+        .supports_filters_pushdown(&filter_refs)
+        .expect("supports_filters_pushdown should succeed");
+
+    assert_eq!(supports, expected);
+}
+
/// Verifies that `PaimonTableProvider::scan()` produces more than one
/// execution partition for a multi-partition table, and that the reported
/// partition count is still capped by `target_partitions`.
#[tokio::test]
async fn test_scan_partition_count_respects_session_config() {
- use datafusion::datasource::TableProvider;
- use datafusion::prelude::SessionConfig;
-
- let catalog = create_catalog();
- let identifier = Identifier::new("default", "partitioned_log_table");
- let table = catalog
- .get_table(&identifier)
- .await
- .expect("Failed to get table");
-
- let provider = PaimonTableProvider::try_new(table).expect("Failed to
create table provider");
+ let provider = create_provider("partitioned_log_table").await;
// With generous target_partitions, the plan should expose more than one partition.
let config = SessionConfig::new().with_target_partitions(8);
@@ -215,6 +238,54 @@ async fn
test_scan_partition_count_respects_session_config() {
);
}
+/// End-to-end: a SQL filter on the partition column returns only that
+/// partition's rows.
+#[tokio::test]
+async fn test_partition_filter_query_via_datafusion() {
+    let sql = "SELECT id, name FROM partitioned_log_table WHERE dt = '2024-01-01'";
+    let batches = collect_query("partitioned_log_table", sql)
+        .await
+        .expect("Partition filter query should succeed");
+
+    // Batch order is not asserted, so compare after sorting by id.
+    let mut actual_rows = extract_id_name_rows(&batches);
+    actual_rows.sort_by_key(|(id, _)| *id);
+
+    let expected_rows = vec![(1, "alice".to_string()), (2, "bob".to_string())];
+    assert_eq!(actual_rows, expected_rows);
+}
+
+/// End-to-end: a SQL filter on both partition columns returns only the rows of
+/// the matching `(dt, hr)` partition.
+#[tokio::test]
+async fn test_multi_partition_filter_query_via_datafusion() {
+    let batches = collect_query(
+        "multi_partitioned_log_table",
+        "SELECT id, name FROM multi_partitioned_log_table WHERE dt = '2024-01-01' AND hr = 10",
+    )
+    .await
+    .expect("Multi-partition filter query should succeed");
+
+    // Batch order is not asserted, so compare after sorting by id.
+    let mut actual_rows = extract_id_name_rows(&batches);
+    actual_rows.sort_by_key(|(id, _)| *id);
+
+    let expected_rows = vec![(1, "alice".to_string()), (2, "bob".to_string())];
+    assert_eq!(actual_rows, expected_rows);
+}
+
+/// End-to-end: with a partition conjunct and a data conjunct, the `id > 1`
+/// condition must still be applied to the returned rows.
+#[tokio::test]
+async fn test_mixed_and_filter_keeps_residual_datafusion_filter() {
+    let batches = collect_query(
+        "partitioned_log_table",
+        "SELECT id, name FROM partitioned_log_table WHERE dt = '2024-01-01' AND id > 1",
+    )
+    .await
+    .expect("Mixed filter query should succeed");
+
+    let actual_rows = extract_id_name_rows(&batches);
+    let expected_rows = vec![(2, "bob".to_string())];
+
+    assert_eq!(actual_rows, expected_rows);
+}
+
// ======================= Catalog Provider Tests =======================
#[tokio::test]
async fn test_query_via_catalog_provider() {