This is an automated email from the ASF dual-hosted git repository.

agrove pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/arrow-datafusion-python.git


The following commit(s) were added to refs/heads/master by this push:
     new eb01db1  [DataFrame] - Add intersect and except bindings for DataFrame (#36)
eb01db1 is described below

commit eb01db1da28e5e314017b39a84ac8bf356805f7d
Author: Francis Du <[email protected]>
AuthorDate: Thu Sep 8 21:34:03 2022 +0800

    [DataFrame] - Add intersect and except bindings for DataFrame (#36)
    
    * fix: conflicting
    
    * fix: python linter
    
    * fix: flake W503 issue
---
 datafusion/tests/test_dataframe.py | 56 ++++++++++++++++++++++++++++++++++++++
 src/dataframe.rs                   | 12 ++++++++
 2 files changed, 68 insertions(+)

diff --git a/datafusion/tests/test_dataframe.py b/datafusion/tests/test_dataframe.py
index c9544ab..bbbdddd 100644
--- a/datafusion/tests/test_dataframe.py
+++ b/datafusion/tests/test_dataframe.py
@@ -256,3 +256,59 @@ def test_repartition(df):
 
 def test_repartition_by_hash(df):
     df.repartition_by_hash(column("a"), num=2)
+
+
+def test_intersect():
+    ctx = SessionContext()
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    df_a = ctx.create_dataframe([[batch]])
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([3, 4, 5]), pa.array([6, 7, 8])],
+        names=["a", "b"],
+    )
+    df_b = ctx.create_dataframe([[batch]])
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([3]), pa.array([6])],
+        names=["a", "b"],
+    )
+    df_c = ctx.create_dataframe([[batch]]).sort(
+        column("a").sort(ascending=True)
+    )
+
+    df_a_i_b = df_a.intersect(df_b).sort(column("a").sort(ascending=True))
+
+    assert df_c.collect() == df_a_i_b.collect()
+
+
+def test_except_all():
+    ctx = SessionContext()
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2, 3]), pa.array([4, 5, 6])],
+        names=["a", "b"],
+    )
+    df_a = ctx.create_dataframe([[batch]])
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([3, 4, 5]), pa.array([6, 7, 8])],
+        names=["a", "b"],
+    )
+    df_b = ctx.create_dataframe([[batch]])
+
+    batch = pa.RecordBatch.from_arrays(
+        [pa.array([1, 2]), pa.array([4, 5])],
+        names=["a", "b"],
+    )
+    df_c = ctx.create_dataframe([[batch]]).sort(
+        column("a").sort(ascending=True)
+    )
+
+    df_a_e_b = df_a.except_all(df_b).sort(column("a").sort(ascending=True))
+
+    assert df_c.collect() == df_a_e_b.collect()
diff --git a/src/dataframe.rs b/src/dataframe.rs
index f6cb4f1..e491c3d 100644
--- a/src/dataframe.rs
+++ b/src/dataframe.rs
@@ -192,4 +192,16 @@ impl PyDataFrame {
         let new_df = self.df.repartition(Partitioning::Hash(expr, num))?;
         Ok(Self::new(new_df))
     }
+
+    /// Return the rows common to `self` and `py_df`; both must have exactly the same schema.
+    fn intersect(&self, py_df: PyDataFrame) -> PyResult<Self> {
+        let new_df = self.df.intersect(py_df.df)?;
+        Ok(Self::new(new_df))
+    }
+
+    /// Set difference `self` - `py_df` (SQL `EXCEPT`). Both must have exactly the same schema.
+    fn except_all(&self, py_df: PyDataFrame) -> PyResult<Self> {
+        let new_df = self.df.except(py_df.df)?;
+        Ok(Self::new(new_df))
+    }
 }

Reply via email to