This is an automated email from the ASF dual-hosted git repository.
agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git
The following commit(s) were added to refs/heads/master by this push:
new eb01db1 [DataFrame] - Add intersect and except bindings for DataFrame (#36)
eb01db1 is described below
commit eb01db1da28e5e314017b39a84ac8bf356805f7d
Author: Francis Du <[email protected]>
AuthorDate: Thu Sep 8 21:34:03 2022 +0800
[DataFrame] - Add intersect and except bindings for DataFrame (#36)
* fix: conflicting
* fix: python linter
* fix: flake W503 issue
---
datafusion/tests/test_dataframe.py | 56 ++++++++++++++++++++++++++++++++++++++
src/dataframe.rs | 12 ++++++++
2 files changed, 68 insertions(+)
diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index c9544ab..bbbdddd 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -256,3 +256,59 @@ def test_repartition(df):
def test_repartition_by_hash(df):
df.repartition_by_hash(column("a"), num=2)
+
+
def test_intersect():
    """intersect() keeps only the rows present in both DataFrames."""
    ctx = SessionContext()

    def frame(a_values, b_values):
        # DataFrame built from a single record batch with columns "a" and "b".
        batch = pa.RecordBatch.from_arrays(
            [pa.array(a_values), pa.array(b_values)],
            names=["a", "b"],
        )
        return ctx.create_dataframe([[batch]])

    def ascending_a():
        # Fresh sort expression each time so both sides are ordered identically.
        return column("a").sort(ascending=True)

    left = frame([1, 2, 3], [4, 5, 6])
    right = frame([3, 4, 5], [6, 7, 8])
    # Only the row (3, 6) appears on both sides.
    expected = frame([3], [6]).sort(ascending_a())

    actual = left.intersect(right).sort(ascending_a())

    assert expected.collect() == actual.collect()
+
+
def test_except_all():
    """except_all() keeps the rows of the left DataFrame absent from the right."""
    ctx = SessionContext()

    def frame(a_values, b_values):
        # DataFrame built from a single record batch with columns "a" and "b".
        batch = pa.RecordBatch.from_arrays(
            [pa.array(a_values), pa.array(b_values)],
            names=["a", "b"],
        )
        return ctx.create_dataframe([[batch]])

    def ascending_a():
        # Fresh sort expression each time so both sides are ordered identically.
        return column("a").sort(ascending=True)

    left = frame([1, 2, 3], [4, 5, 6])
    right = frame([3, 4, 5], [6, 7, 8])
    # (3, 6) also appears in `right`, so only (1, 4) and (2, 5) remain.
    expected = frame([1, 2], [4, 5]).sort(ascending_a())

    actual = left.except_all(right).sort(ascending_a())

    assert expected.collect() == actual.collect()
diff --git a/src/dataframe.rs b/src/dataframe.rs
index f6cb4f1..e491c3d 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -192,4 +192,16 @@ impl PyDataFrame {
let new_df = self.df.repartition(Partitioning::Hash(expr, num))?;
Ok(Self::new(new_df))
}
+
+ /// Calculate the intersection of two `DataFrame`s. The two `DataFrame`s
must have exactly the same schema
+ fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
+ let new_df = self.df.intersect(py_df.df)?;
+ Ok(Self::new(new_df))
+ }
+
+ /// Calculate the exception of two `DataFrame`s. The two `DataFrame`s
must have exactly the same schema
+ fn except_all(&self, py_df: PyDataFrame) -> PyResult<Self> {
+ let new_df = self.df.except(py_df.df)?;
+ Ok(Self::new(new_df))
+ }
}