This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/master by this push:
new e0cbf48 Implement PyArrow Dataset TableProvider (#9)
e0cbf48 is described below
commit e0cbf48516d79ba28fc07f81d1f6d73e85416796
Author: Kyle Brooks <[email protected]>
AuthorDate: Tue Jul 26 07:26:23 2022 -0400
Implement PyArrow Dataset TableProvider (#9)
* Implement PyArrow Dataset TableProvider and register_dataset context
functions.
* Add dataset filter test.
* Change match on booleans to if else.
* Update Dataset TableProvider for updates in DataFusion 10.0.0 trait.
* Fixes to build with DataFusion 10.0.0.
* Improved DatasetExec physical plan printing.
Added nested filter test.
---
Cargo.lock | 2 +
Cargo.toml | 2 +
datafusion/tests/test_context.py | 70 ++++++++++
datafusion/tests/test_sql.py | 12 ++
src/context.rs | 13 ++
src/dataset.rs | 118 +++++++++++++++++
src/dataset_exec.rs | 270 +++++++++++++++++++++++++++++++++++++++
src/errors.rs | 16 ++-
src/lib.rs | 3 +
src/pyarrow_filter_expression.rs | 190 +++++++++++++++++++++++++++
10 files changed, 695 insertions(+), 1 deletion(-)
diff --git a/Cargo.lock b/Cargo.lock
index 1108211..bdbf3cc 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -395,9 +395,11 @@ dependencies = [
name = "datafusion-python"
version = "0.6.0"
dependencies = [
+ "async-trait",
"datafusion",
"datafusion-common",
"datafusion-expr",
+ "futures",
"mimalloc",
"pyo3",
"rand 0.7.3",
diff --git a/Cargo.toml b/Cargo.toml
index fd20b4d..05a21e0 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -39,6 +39,8 @@ datafusion-expr = { version = "^10.0.0" }
datafusion-common = { version = "^10.0.0", features = ["pyarrow"] }
uuid = { version = "0.8", features = ["v4"] }
mimalloc = { version = "*", optional = true, default-features = false }
+async-trait = "0.1"
+futures = "0.3"
[lib]
name = "datafusion_python"
diff --git a/datafusion/tests/test_context.py b/datafusion/tests/test_context.py
index 4d4a38c..1e1e771 100644
--- a/datafusion/tests/test_context.py
+++ b/datafusion/tests/test_context.py
@@ -16,6 +16,9 @@
# under the License.
import pyarrow as pa
+import pyarrow.dataset as ds
+
+from datafusion import column, literal
def test_register_record_batches(ctx):
@@ -72,3 +75,70 @@ def test_deregister_table(ctx, database):
ctx.deregister_table("csv")
assert public.names() == {"csv1", "csv2"}
+
+def test_register_dataset(ctx):
+ # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+ names=["a", "b"],
+ )
+ dataset = ds.dataset([batch])
+ ctx.register_dataset("t", dataset)
+
+ assert ctx.tables() == {"t"}
+
+ result = ctx.sql("SELECT a+b, a-b FROM t").collect()
+
+ assert result[0].column(0) == pa.array([5, 7, 9])
+ assert result[0].column(1) == pa.array([-3, -3, -3])
+
+def test_dataset_filter(ctx, capfd):
+ # create a RecordBatch and register it as a pyarrow.dataset.Dataset
+ batch = pa.RecordBatch.from_arrays(
+ [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+ names=["a", "b"],
+ )
+ dataset = ds.dataset([batch])
+ ctx.register_dataset("t", dataset)
+
+ assert ctx.tables() == {"t"}
+ df = ctx.sql("SELECT a+b, a-b FROM t WHERE a BETWEEN 2 and 3 AND b > 5")
+
+ # Make sure the filter was pushed down in Physical Plan
+ df.explain()
+ captured = capfd.readouterr()
+ assert "filter_expr=(((2 <= a) and (a <= 3)) and (b > 5))" in captured.out
+
+ result = df.collect()
+
+ assert result[0].column(0) == pa.array([9])
+ assert result[0].column(1) == pa.array([-3])
+
+
+def test_dataset_filter_nested_data(ctx):
+ # create Arrow StructArrays to test nested data types
+ data = pa.StructArray.from_arrays(
+ [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+ names=["a", "b"],
+ )
+ batch = pa.RecordBatch.from_arrays(
+ [data],
+ names=["nested_data"],
+ )
+ dataset = ds.dataset([batch])
+ ctx.register_dataset("t", dataset)
+
+ assert ctx.tables() == {"t"}
+
+ df = ctx.table("t")
+
+ # This filter will not be pushed down to DatasetExec since it isn't
supported
+ df = df.select(
+ column("nested_data")["a"] + column("nested_data")["b"],
+ column("nested_data")["a"] - column("nested_data")["b"],
+ ).filter(column("nested_data")["b"] > literal(5))
+
+ result = df.collect()
+
+ assert result[0].column(0) == pa.array([9])
+ assert result[0].column(1) == pa.array([-3])
diff --git a/datafusion/tests/test_sql.py b/datafusion/tests/test_sql.py
index af3b38a..ffbfc2c 100644
--- a/datafusion/tests/test_sql.py
+++ b/datafusion/tests/test_sql.py
@@ -17,6 +17,7 @@
import numpy as np
import pyarrow as pa
+import pyarrow.dataset as ds
import pytest
from datafusion import udf
@@ -121,6 +122,17 @@ def test_register_parquet_partitioned(ctx, tmp_path):
rd = result.to_pydict()
assert dict(zip(rd["grp"], rd["cnt"])) == {"a": 3, "b": 1}
+def test_register_dataset(ctx, tmp_path):
+ path = helpers.write_parquet(tmp_path / "a.parquet", helpers.data())
+ dataset = ds.dataset(path, format="parquet")
+
+ ctx.register_dataset("t", dataset)
+ assert ctx.tables() == {"t"}
+
+ result = ctx.sql("SELECT COUNT(a) AS cnt FROM t").collect()
+ result = pa.Table.from_batches(result)
+ assert result.to_pydict() == {"cnt": [100]}
+
def test_execute(ctx, tmp_path):
data = [1, 1, 2, 2, 3, 11, 12]
diff --git a/src/context.rs b/src/context.rs
index 213703f..d2c17ad 100644
--- a/src/context.rs
+++ b/src/context.rs
@@ -25,12 +25,14 @@ use pyo3::prelude::*;
use datafusion::arrow::datatypes::Schema;
use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::datasource::datasource::TableProvider;
use datafusion::datasource::MemTable;
use datafusion::execution::context::SessionContext;
use datafusion::prelude::{CsvReadOptions, ParquetReadOptions};
use crate::catalog::{PyCatalog, PyTable};
use crate::dataframe::PyDataFrame;
+use crate::dataset::Dataset;
use crate::errors::DataFusionError;
use crate::udf::PyScalarUDF;
use crate::utils::wait_for_future;
@@ -173,6 +175,17 @@ impl PySessionContext {
Ok(())
}
+ // Registers a PyArrow.Dataset
+ fn register_dataset(&self, name: &str, dataset: &PyAny, py: Python) ->
PyResult<()> {
+ let table: Arc<dyn TableProvider> = Arc::new(Dataset::new(dataset,
py)?);
+
+ self.ctx
+ .register_table(name, table)
+ .map_err(DataFusionError::from)?;
+
+ Ok(())
+ }
+
fn register_udf(&mut self, udf: PyScalarUDF) -> PyResult<()> {
self.ctx.register_udf(udf.function);
Ok(())
diff --git a/src/dataset.rs b/src/dataset.rs
new file mode 100644
index 0000000..6272bc8
--- /dev/null
+++ b/src/dataset.rs
@@ -0,0 +1,118 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+use pyo3::exceptions::PyValueError;
+/// Implements a Datafusion TableProvider that delegates to a PyArrow Dataset
+/// This allows us to use PyArrow Datasets as Datafusion tables while pushing
down projections and filters
+use pyo3::prelude::*;
+use pyo3::types::PyType;
+
+use std::any::Any;
+use std::sync::Arc;
+
+use async_trait::async_trait;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::datasource::datasource::TableProviderFilterPushDown;
+use datafusion::datasource::{TableProvider, TableType};
+use datafusion::error::{DataFusionError, Result as DFResult};
+use datafusion::execution::context::SessionState;
+use datafusion::logical_plan::*;
+use datafusion::physical_plan::ExecutionPlan;
+
+use crate::dataset_exec::DatasetExec;
+use crate::pyarrow_filter_expression::PyArrowFilterExpression;
+
+// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion
TableProvider around it
+#[derive(Debug, Clone)]
+pub(crate) struct Dataset {
+ dataset: PyObject,
+}
+
+impl Dataset {
+    // Wraps an existing pyarrow.dataset.Dataset instance
+ pub fn new(dataset: &PyAny, py: Python) -> PyResult<Self> {
+ // Ensure that we were passed an instance of pyarrow.dataset.Dataset
+ let ds = PyModule::import(py, "pyarrow.dataset")?;
+ let ds_type: &PyType = ds.getattr("Dataset")?.downcast()?;
+ if dataset.is_instance(ds_type)? {
+ Ok(Dataset {
+ dataset: dataset.into(),
+ })
+ } else {
+ Err(PyValueError::new_err(
+ "dataset argument must be a pyarrow.dataset.Dataset object",
+ ))
+ }
+ }
+}
+
+#[async_trait]
+impl TableProvider for Dataset {
+ /// Returns the table provider as [`Any`](std::any::Any) so that it can be
+ /// downcast to a specific implementation.
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ /// Get a reference to the schema for this table
+ fn schema(&self) -> SchemaRef {
+ Python::with_gil(|py| {
+ let dataset = self.dataset.as_ref(py);
+            // This can panic, but since we checked that self.dataset is a
pyarrow.dataset.Dataset it should never panic
+ Arc::new(dataset.getattr("schema").unwrap().extract().unwrap())
+ })
+ }
+
+ /// Get the type of this table for metadata/catalog purposes.
+ fn table_type(&self) -> TableType {
+ TableType::Base
+ }
+
+ /// Create an ExecutionPlan that will scan the table.
+    /// The table provider will usually be responsible for grouping
+ /// the source data into partitions that can be efficiently
+ /// parallelized or distributed.
+ async fn scan(
+ &self,
+ _ctx: &SessionState,
+ projection: &Option<Vec<usize>>,
+ filters: &[Expr],
+ // limit can be used to reduce the amount scanned
+ // from the datasource as a performance optimization.
+ // If set, it contains the amount of rows needed by the `LogicalPlan`,
+ // The datasource should return *at least* this number of rows if
available.
+ _limit: Option<usize>,
+ ) -> DFResult<Arc<dyn ExecutionPlan>> {
+ Python::with_gil(|py| {
+ let plan: Arc<dyn ExecutionPlan> = Arc::new(
+ DatasetExec::new(py, self.dataset.as_ref(py),
projection.clone(), filters)
+ .map_err(|err| DataFusionError::External(Box::new(err)))?,
+ );
+ Ok(plan)
+ })
+ }
+
+ /// Tests whether the table provider can make use of a filter expression
+ /// to optimise data retrieval.
+ fn supports_filter_pushdown(&self, filter: &Expr) ->
DFResult<TableProviderFilterPushDown> {
+ match PyArrowFilterExpression::try_from(filter) {
+ Ok(_) => Ok(TableProviderFilterPushDown::Exact),
+ _ => Ok(TableProviderFilterPushDown::Unsupported),
+ }
+ }
+}
diff --git a/src/dataset_exec.rs b/src/dataset_exec.rs
new file mode 100644
index 0000000..a3925ad
--- /dev/null
+++ b/src/dataset_exec.rs
@@ -0,0 +1,270 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Implements a Datafusion physical ExecutionPlan that delegates to a PyArrow
Dataset
+/// This actually performs the projection, filtering and scanning of a Dataset
+use pyo3::prelude::*;
+use pyo3::types::{PyDict, PyIterator, PyList};
+
+use std::any::Any;
+use std::sync::Arc;
+
+use futures::stream;
+
+use datafusion::arrow::datatypes::SchemaRef;
+use datafusion::arrow::error::ArrowError;
+use datafusion::arrow::error::Result as ArrowResult;
+use datafusion::arrow::record_batch::RecordBatch;
+use datafusion::error::{DataFusionError as InnerDataFusionError, Result as
DFResult};
+use datafusion::execution::context::TaskContext;
+use datafusion::logical_plan::{combine_filters, Expr};
+use datafusion::physical_expr::PhysicalSortExpr;
+use datafusion::physical_plan::stream::RecordBatchStreamAdapter;
+use datafusion::physical_plan::{
+ DisplayFormatType, ExecutionPlan, Partitioning, SendableRecordBatchStream,
Statistics,
+};
+
+use crate::errors::DataFusionError;
+use crate::pyarrow_filter_expression::PyArrowFilterExpression;
+
+struct PyArrowBatchesAdapter {
+ batches: Py<PyIterator>,
+}
+
+impl Iterator for PyArrowBatchesAdapter {
+ type Item = ArrowResult<RecordBatch>;
+
+ fn next(&mut self) -> Option<Self::Item> {
+ Python::with_gil(|py| {
+ let mut batches: &PyIterator = self.batches.as_ref(py);
+ Some(
+ batches
+ .next()?
+ .and_then(|batch| batch.extract())
+ .map_err(|err| ArrowError::ExternalError(Box::new(err))),
+ )
+ })
+ }
+}
+
+// Wraps a pyarrow.dataset.Dataset class and implements a Datafusion
ExecutionPlan around it
+#[derive(Debug, Clone)]
+pub(crate) struct DatasetExec {
+ dataset: PyObject,
+ schema: SchemaRef,
+ fragments: Py<PyList>,
+ columns: Option<Vec<String>>,
+ filter_expr: Option<PyObject>,
+ projected_statistics: Statistics,
+}
+
+impl DatasetExec {
+ pub fn new(
+ py: Python,
+ dataset: &PyAny,
+ projection: Option<Vec<usize>>,
+ filters: &[Expr],
+ ) -> Result<Self, DataFusionError> {
+ let columns: Option<Result<Vec<String>, DataFusionError>> =
projection.map(|p| {
+ p.iter()
+ .map(|index| {
+ let name: String = dataset
+ .getattr("schema")?
+ .call_method1("field", (*index,))?
+ .getattr("name")?
+ .extract()?;
+ Ok(name)
+ })
+ .collect()
+ });
+ let columns: Option<Vec<String>> = columns.transpose()?;
+ let filter_expr: Option<PyObject> = combine_filters(filters)
+ .map(|filters| {
+ PyArrowFilterExpression::try_from(&filters)
+ .map(|filter_expr| filter_expr.inner().clone_ref(py))
+ })
+ .transpose()?;
+
+ let kwargs = PyDict::new(py);
+
+ kwargs.set_item("columns", columns.clone())?;
+ kwargs.set_item(
+ "filter",
+ filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
+ )?;
+
+ let scanner = dataset.call_method("scanner", (), Some(kwargs))?;
+
+ let schema = Arc::new(scanner.getattr("projected_schema")?.extract()?);
+
+ let builtins = Python::import(py, "builtins")?;
+ let pylist = builtins.getattr("list")?;
+
+ // Get the fragments or partitions of the dataset
+ let fragments_iterator: &PyAny = dataset.call_method1(
+ "get_fragments",
+ (filter_expr.as_ref().map(|expr| expr.clone_ref(py)),),
+ )?;
+
+ let fragments: &PyList = pylist
+ .call1((fragments_iterator,))?
+ .downcast()
+ .map_err(PyErr::from)?;
+
+ Ok(DatasetExec {
+ dataset: dataset.into(),
+ schema,
+ fragments: fragments.into(),
+ columns,
+ filter_expr,
+ projected_statistics: Default::default(),
+ })
+ }
+}
+
+impl ExecutionPlan for DatasetExec {
+ /// Return a reference to Any that can be used for downcasting
+ fn as_any(&self) -> &dyn Any {
+ self
+ }
+
+ /// Get the schema for this execution plan
+ fn schema(&self) -> SchemaRef {
+ self.schema.clone()
+ }
+
+ /// Get the output partitioning of this plan
+ fn output_partitioning(&self) -> Partitioning {
+ Python::with_gil(|py| {
+ let fragments = self.fragments.as_ref(py);
+ Partitioning::UnknownPartitioning(fragments.len())
+ })
+ }
+
+ fn relies_on_input_order(&self) -> bool {
+ false
+ }
+
+ fn output_ordering(&self) -> Option<&[PhysicalSortExpr]> {
+ None
+ }
+
+ fn children(&self) -> Vec<Arc<dyn ExecutionPlan>> {
+ // this is a leaf node and has no children
+ vec![]
+ }
+
+ fn with_new_children(
+ self: Arc<Self>,
+ _: Vec<Arc<dyn ExecutionPlan>>,
+ ) -> DFResult<Arc<dyn ExecutionPlan>> {
+ Ok(self)
+ }
+
+ fn execute(
+ &self,
+ partition: usize,
+ context: Arc<TaskContext>,
+ ) -> DFResult<SendableRecordBatchStream> {
+ let batch_size = context.session_config().batch_size();
+ Python::with_gil(|py| {
+ let dataset = self.dataset.as_ref(py);
+ let fragments = self.fragments.as_ref(py);
+ let fragment = fragments
+ .get_item(partition)
+ .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+
+ // We need to pass the dataset schema to unify the fragment and
dataset schema per PyArrow docs
+ let dataset_schema = dataset
+ .getattr("schema")
+ .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+ let kwargs = PyDict::new(py);
+ kwargs
+ .set_item("columns", self.columns.clone())
+ .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+ kwargs
+ .set_item(
+ "filter",
+ self.filter_expr.as_ref().map(|expr| expr.clone_ref(py)),
+ )
+ .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+ kwargs
+ .set_item("batch_size", batch_size)
+ .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+ let scanner = fragment
+ .call_method("scanner", (dataset_schema,), Some(kwargs))
+ .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+ let schema: SchemaRef = Arc::new(
+ scanner
+ .getattr("projected_schema")
+ .and_then(|schema| schema.extract())
+ .map_err(|err|
InnerDataFusionError::External(Box::new(err)))?,
+ );
+ let record_batches: &PyIterator = scanner
+ .call_method0("to_batches")
+ .map_err(|err| InnerDataFusionError::External(Box::new(err)))?
+ .iter()
+ .map_err(|err| InnerDataFusionError::External(Box::new(err)))?;
+
+ let record_batches = PyArrowBatchesAdapter {
+ batches: record_batches.into(),
+ };
+
+ let record_batch_stream = stream::iter(record_batches);
+ let record_batch_stream: SendableRecordBatchStream =
+ Box::pin(RecordBatchStreamAdapter::new(schema,
record_batch_stream));
+ Ok(record_batch_stream)
+ })
+ }
+
+ fn fmt_as(&self, t: DisplayFormatType, f: &mut std::fmt::Formatter) ->
std::fmt::Result {
+ Python::with_gil(|py| {
+ let number_of_fragments = self.fragments.as_ref(py).len();
+ match t {
+ DisplayFormatType::Default => {
+ let projected_columns: Vec<String> = self
+ .schema
+ .fields()
+ .iter()
+ .map(|x| x.name().to_owned())
+ .collect();
+ if let Some(filter_expr) = &self.filter_expr {
+ let filter_expr =
filter_expr.as_ref(py).str().or(Err(std::fmt::Error))?;
+ write!(
+ f,
+ "DatasetExec: number_of_fragments={},
filter_expr={}, projection=[{}]",
+ number_of_fragments,
+ filter_expr,
+ projected_columns.join(", "),
+ )
+ } else {
+ write!(
+ f,
+ "DatasetExec: number_of_fragments={},
projection=[{}]",
+ number_of_fragments,
+ projected_columns.join(", "),
+ )
+ }
+ }
+ }
+ })
+ }
+
+ fn statistics(&self) -> Statistics {
+ self.projected_statistics.clone()
+ }
+}
diff --git a/src/errors.rs b/src/errors.rs
index 655ed84..29d3e8f 100644
--- a/src/errors.rs
+++ b/src/errors.rs
@@ -16,6 +16,7 @@
// under the License.
use core::fmt;
+use std::error::Error;
use datafusion::arrow::error::ArrowError;
use datafusion::error::DataFusionError as InnerDataFusionError;
@@ -26,6 +27,7 @@ pub enum DataFusionError {
ExecutionError(InnerDataFusionError),
ArrowError(ArrowError),
Common(String),
+ PythonError(PyErr),
}
impl fmt::Display for DataFusionError {
@@ -33,6 +35,7 @@ impl fmt::Display for DataFusionError {
match self {
DataFusionError::ExecutionError(e) => write!(f, "DataFusion error:
{:?}", e),
DataFusionError::ArrowError(e) => write!(f, "Arrow error: {:?}",
e),
+ DataFusionError::PythonError(e) => write!(f, "Python error {:?}",
e),
DataFusionError::Common(e) => write!(f, "{}", e),
}
}
@@ -50,8 +53,19 @@ impl From<InnerDataFusionError> for DataFusionError {
}
}
+impl From<PyErr> for DataFusionError {
+ fn from(err: PyErr) -> DataFusionError {
+ DataFusionError::PythonError(err)
+ }
+}
+
impl From<DataFusionError> for PyErr {
fn from(err: DataFusionError) -> PyErr {
- PyException::new_err(err.to_string())
+ match err {
+ DataFusionError::PythonError(py_err) => py_err,
+ _ => PyException::new_err(err.to_string()),
+ }
}
}
+
+impl Error for DataFusionError {}
diff --git a/src/lib.rs b/src/lib.rs
index 25b63e8..c6ab58e 100644
--- a/src/lib.rs
+++ b/src/lib.rs
@@ -22,9 +22,12 @@ use pyo3::prelude::*;
pub mod catalog;
mod context;
mod dataframe;
+mod dataset;
+mod dataset_exec;
pub mod errors;
mod expression;
mod functions;
+mod pyarrow_filter_expression;
mod udaf;
mod udf;
pub mod utils;
diff --git a/src/pyarrow_filter_expression.rs b/src/pyarrow_filter_expression.rs
new file mode 100644
index 0000000..3807553
--- /dev/null
+++ b/src/pyarrow_filter_expression.rs
@@ -0,0 +1,190 @@
+// Licensed to the Apache Software Foundation (ASF) under one
+// or more contributor license agreements. See the NOTICE file
+// distributed with this work for additional information
+// regarding copyright ownership. The ASF licenses this file
+// to you under the Apache License, Version 2.0 (the
+// "License"); you may not use this file except in compliance
+// with the License. You may obtain a copy of the License at
+//
+// http://www.apache.org/licenses/LICENSE-2.0
+//
+// Unless required by applicable law or agreed to in writing,
+// software distributed under the License is distributed on an
+// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+// KIND, either express or implied. See the License for the
+// specific language governing permissions and limitations
+// under the License.
+
+/// Converts a Datafusion logical plan expression (Expr) into a PyArrow
compute expression
+use pyo3::prelude::*;
+
+use std::convert::TryFrom;
+use std::result::Result;
+
+use datafusion::logical_plan::*;
+use datafusion_common::ScalarValue;
+
+use crate::errors::DataFusionError;
+
+#[derive(Debug, Clone)]
+#[repr(transparent)]
+pub(crate) struct PyArrowFilterExpression(PyObject);
+
+fn operator_to_py<'py>(
+ operator: &Operator,
+ op: &'py PyModule,
+) -> Result<&'py PyAny, DataFusionError> {
+ let py_op: &PyAny = match operator {
+ Operator::Eq => op.getattr("eq")?,
+ Operator::NotEq => op.getattr("ne")?,
+ Operator::Lt => op.getattr("lt")?,
+ Operator::LtEq => op.getattr("le")?,
+ Operator::Gt => op.getattr("gt")?,
+ Operator::GtEq => op.getattr("ge")?,
+ Operator::And => op.getattr("and_")?,
+ Operator::Or => op.getattr("or_")?,
+ _ => {
+ return Err(DataFusionError::Common(format!(
+ "Unsupported operator {:?}",
+ operator
+ )))
+ }
+ };
+ Ok(py_op)
+}
+
+fn extract_scalar_list(exprs: &[Expr], py: Python) -> Result<Vec<PyObject>,
DataFusionError> {
+ let ret: Result<Vec<PyObject>, DataFusionError> = exprs
+ .iter()
+ .map(|expr| match expr {
+ Expr::Literal(v) => match v {
+ ScalarValue::Boolean(Some(b)) => Ok(b.into_py(py)),
+ ScalarValue::Int8(Some(i)) => Ok(i.into_py(py)),
+ ScalarValue::Int16(Some(i)) => Ok(i.into_py(py)),
+ ScalarValue::Int32(Some(i)) => Ok(i.into_py(py)),
+ ScalarValue::Int64(Some(i)) => Ok(i.into_py(py)),
+ ScalarValue::UInt8(Some(i)) => Ok(i.into_py(py)),
+ ScalarValue::UInt16(Some(i)) => Ok(i.into_py(py)),
+ ScalarValue::UInt32(Some(i)) => Ok(i.into_py(py)),
+ ScalarValue::UInt64(Some(i)) => Ok(i.into_py(py)),
+ ScalarValue::Float32(Some(f)) => Ok(f.into_py(py)),
+ ScalarValue::Float64(Some(f)) => Ok(f.into_py(py)),
+ ScalarValue::Utf8(Some(s)) => Ok(s.into_py(py)),
+ _ => Err(DataFusionError::Common(format!(
+ "PyArrow can't handle ScalarValue: {:?}",
+ v
+ ))),
+ },
+ _ => Err(DataFusionError::Common(format!(
+ "Only a list of Literals are supported got {:?}",
+ expr
+ ))),
+ })
+ .collect();
+ ret
+}
+
+impl PyArrowFilterExpression {
+ pub fn inner(&self) -> &PyObject {
+ &self.0
+ }
+}
+
+impl TryFrom<&Expr> for PyArrowFilterExpression {
+ type Error = DataFusionError;
+
+    // Converts a Datafusion filter Expr into a PyArrow compute expression that
can be evaluated by Python
+ // Note that pyarrow.compute.{field,scalar} are put into Python globals()
when evaluated
+ // isin, is_null, and is_valid (~is_null) are methods of
pyarrow.dataset.Expression
+ //
https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Expression.html#pyarrow-dataset-expression
+ fn try_from(expr: &Expr) -> Result<Self, Self::Error> {
+ Python::with_gil(|py| {
+ let pc = Python::import(py, "pyarrow.compute")?;
+ let op_module = Python::import(py, "operator")?;
+ let pc_expr: Result<&PyAny, DataFusionError> = match expr {
+ Expr::Column(Column { name, .. }) =>
Ok(pc.getattr("field")?.call1((name,))?),
+ Expr::Literal(v) => match v {
+ ScalarValue::Boolean(Some(b)) =>
Ok(pc.getattr("scalar")?.call1((*b,))?),
+ ScalarValue::Int8(Some(i)) =>
Ok(pc.getattr("scalar")?.call1((*i,))?),
+ ScalarValue::Int16(Some(i)) =>
Ok(pc.getattr("scalar")?.call1((*i,))?),
+ ScalarValue::Int32(Some(i)) =>
Ok(pc.getattr("scalar")?.call1((*i,))?),
+ ScalarValue::Int64(Some(i)) =>
Ok(pc.getattr("scalar")?.call1((*i,))?),
+ ScalarValue::UInt8(Some(i)) =>
Ok(pc.getattr("scalar")?.call1((*i,))?),
+ ScalarValue::UInt16(Some(i)) =>
Ok(pc.getattr("scalar")?.call1((*i,))?),
+ ScalarValue::UInt32(Some(i)) =>
Ok(pc.getattr("scalar")?.call1((*i,))?),
+ ScalarValue::UInt64(Some(i)) =>
Ok(pc.getattr("scalar")?.call1((*i,))?),
+ ScalarValue::Float32(Some(f)) =>
Ok(pc.getattr("scalar")?.call1((*f,))?),
+ ScalarValue::Float64(Some(f)) =>
Ok(pc.getattr("scalar")?.call1((*f,))?),
+ ScalarValue::Utf8(Some(s)) =>
Ok(pc.getattr("scalar")?.call1((s,))?),
+ _ => Err(DataFusionError::Common(format!(
+ "PyArrow can't handle ScalarValue: {:?}",
+ v
+ ))),
+ },
+ Expr::BinaryExpr { left, op, right } => {
+ let operator = operator_to_py(op, op_module)?;
+ let left =
PyArrowFilterExpression::try_from(left.as_ref())?.0;
+ let right =
PyArrowFilterExpression::try_from(right.as_ref())?.0;
+ Ok(operator.call1((left, right))?)
+ }
+ Expr::Not(expr) => {
+ let operator = op_module.getattr("invert")?;
+ let py_expr =
PyArrowFilterExpression::try_from(expr.as_ref())?.0;
+ Ok(operator.call1((py_expr,))?)
+ }
+ Expr::IsNotNull(expr) => {
+ let py_expr =
PyArrowFilterExpression::try_from(expr.as_ref())?
+ .0
+ .into_ref(py);
+ Ok(py_expr.call_method0("is_valid")?)
+ }
+ Expr::IsNull(expr) => {
+ let expr =
PyArrowFilterExpression::try_from(expr.as_ref())?
+ .0
+ .into_ref(py);
+ Ok(expr.call_method1("is_null", (expr,))?)
+ }
+ Expr::Between {
+ expr,
+ negated,
+ low,
+ high,
+ } => {
+ let expr =
PyArrowFilterExpression::try_from(expr.as_ref())?.0;
+ let low =
PyArrowFilterExpression::try_from(low.as_ref())?.0;
+ let high =
PyArrowFilterExpression::try_from(high.as_ref())?.0;
+ let and = op_module.getattr("and_")?;
+ let le = op_module.getattr("le")?;
+ let invert = op_module.getattr("invert")?;
+
+ // scalar <= field() returns a boolean expression so we
need to use and to combine these
+ let ret = and.call1((
+ le.call1((low, expr.clone_ref(py)))?,
+ le.call1((expr, high))?,
+ ))?;
+
+ Ok(if *negated { invert.call1((ret,))? } else { ret })
+ }
+ Expr::InList {
+ expr,
+ list,
+ negated,
+ } => {
+ let expr =
PyArrowFilterExpression::try_from(expr.as_ref())?
+ .0
+ .into_ref(py);
+ let scalars = extract_scalar_list(list, py)?;
+ let ret = expr.call_method1("isin", (scalars,))?;
+ let invert = op_module.getattr("invert")?;
+
+ Ok(if *negated { invert.call1((ret,))? } else { ret })
+ }
+ _ => Err(DataFusionError::Common(format!(
+ "Unsupported Datafusion expression {:?}",
+ expr
+ ))),
+ };
+ Ok(PyArrowFilterExpression(pc_expr?.into()))
+ })
+ }
+}