allisonwang-db commented on code in PR #49961:
URL: https://github.com/apache/spark/pull/49961#discussion_r2002140192


##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala:
##########
@@ -213,6 +220,66 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
       parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\""))
   }
 
+  test("data source reader with filter pushdown") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import (
+         |    DataSource,
+         |    DataSourceReader,
+         |    EqualTo,
+         |    InputPartition,
+         |)
+         |
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def partitions(self):
+         |        return [InputPartition(i) for i in range(2)]
+         |
+         |    def pushFilters(self, filters):
+         |        yield filters[filters.index(EqualTo(("id",), 1))]
+         |
+         |    def read(self, partition):
+         |        yield (0, partition.value)
+         |        yield (1, partition.value)
+         |        yield (2, partition.value)
+         |
+         |class SimpleDataSource(DataSource):
+         |    def schema(self):
+         |        return "id int, partition int"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val schema = StructType.fromDDL("id INT, partition INT")
+    val dataSource =
+      createUserDefinedPythonDataSource(name = dataSourceName, pythonScript = dataSourceScript)
+    spark.conf.set(SQLConf.PYTHON_FILTER_PUSHDOWN_ENABLED, true)

Review Comment:
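   Let's use `withSQLConf` here instead of `spark.conf.set`, so the config is restored after the test and does not leak into other tests.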



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala:
##########
@@ -213,6 +220,66 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
       parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\""))
   }
 
+  test("data source reader with filter pushdown") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import (
+         |    DataSource,
+         |    DataSourceReader,
+         |    EqualTo,
+         |    InputPartition,
+         |)
+         |
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def partitions(self):
+         |        return [InputPartition(i) for i in range(2)]
+         |
+         |    def pushFilters(self, filters):
+         |        yield filters[filters.index(EqualTo(("id",), 1))]
+         |
+         |    def read(self, partition):
+         |        yield (0, partition.value)
+         |        yield (1, partition.value)
+         |        yield (2, partition.value)
+         |
+         |class SimpleDataSource(DataSource):
+         |    def schema(self):
+         |        return "id int, partition int"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val schema = StructType.fromDDL("id INT, partition INT")
+    val dataSource =
+      createUserDefinedPythonDataSource(name = dataSourceName, pythonScript = dataSourceScript)
+    spark.conf.set(SQLConf.PYTHON_FILTER_PUSHDOWN_ENABLED, true)

Review Comment:
   Can we also add another test case with this config disabled, to check the behavior when filters are not pushed down?



##########
python/pyspark/sql/tests/test_python_datasource.py:
##########
@@ -246,6 +251,161 @@ def reader(self, schema) -> "DataSourceReader":
         assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
         self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
 
+    def test_filter_pushdown(self):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self):
+                self.has_filter = False
+
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert set(filters) == {
+                    EqualTo(("x",), 1),
+                    EqualTo(("y",), 2),
+                }, filters
+                self.has_filter = True
+                # pretend we support x = 1 filter but in fact we don't
+                # so we only return y = 2 filter
+                yield filters[filters.index(EqualTo(("y",), 2))]
+
+            def partitions(self):
+                assert self.has_filter
+                return super().partitions()
+
+            def read(self, partition):
+                assert self.has_filter
+                yield [1, 1]
+                yield [1, 2]
+                yield [2, 2]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int, y int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.conf.set("spark.sql.python.filterPushdown.enabled", True)

Review Comment:
   Let's use `with self.sql_conf(...)` instead of `self.spark.conf.set`, so the config is restored after the test.



##########
python/pyspark/sql/tests/test_python_datasource.py:
##########
@@ -246,6 +251,161 @@ def reader(self, schema) -> "DataSourceReader":
         assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
         self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
 
+    def test_filter_pushdown(self):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self):
+                self.has_filter = False
+
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert set(filters) == {
+                    EqualTo(("x",), 1),
+                    EqualTo(("y",), 2),
+                }, filters
+                self.has_filter = True
+                # pretend we support x = 1 filter but in fact we don't
+                # so we only return y = 2 filter
+                yield filters[filters.index(EqualTo(("y",), 2))]
+
+            def partitions(self):
+                assert self.has_filter
+                return super().partitions()
+
+            def read(self, partition):
+                assert self.has_filter
+                yield [1, 1]
+                yield [1, 2]
+                yield [2, 2]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int, y int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.conf.set("spark.sql.python.filterPushdown.enabled", True)
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("test").load().filter("x = 1 and y = 2")
+        # only the y = 2 filter is applied post scan
+        assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)])
+
+    def test_extraneous_filter(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                yield EqualTo(("x",), 1)
+
+            def partitions(self):
+                assert False
+
+            def read(self, partition):
+                assert False
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.conf.set("spark.sql.python.filterPushdown.enabled", True)
+        self.spark.dataSource.register(TestDataSource)
+        with self.assertRaisesRegex(Exception, "DATA_SOURCE_EXTRANEOUS_FILTERS"):
+            self.spark.read.format("test").load().filter("x = 1").show()
+
+    def test_filter_pushdown_error(self):
+        error_str = "dummy error"
+
+        class TestDataSourceReader(DataSourceReader):
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                raise Exception(error_str)
+
+            def read(self, partition):
+                yield [1]
+
+        class TestDataSource(DataSource):
+            def schema(self):
+                return "x int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.conf.set("spark.sql.python.filterPushdown.enabled", True)
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("TestDataSource").load().filter("x = 1 or x is null")
+        assertDataFrameEqual(df, [Row(x=1)])  # works when not pushing down filters
+        with self.assertRaisesRegex(Exception, error_str):
+            df.filter("x = 1").explain()
+
+    def test_filter_pushdown_disabled(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert False
+
+            def read(self, partition):
+                assert False
+
+        class TestDataSource(DataSource):
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.conf.set("spark.sql.python.filterPushdown.enabled", False)

Review Comment:
   Ditto: please use `with self.sql_conf(...)` here as well instead of `self.spark.conf.set`.



##########
python/pyspark/sql/tests/test_python_datasource.py:
##########
@@ -246,6 +251,161 @@ def reader(self, schema) -> "DataSourceReader":
         assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
         self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
 
+    def test_filter_pushdown(self):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self):
+                self.has_filter = False
+
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert set(filters) == {
+                    EqualTo(("x",), 1),
+                    EqualTo(("y",), 2),
+                }, filters
+                self.has_filter = True
+                # pretend we support x = 1 filter but in fact we don't
+                # so we only return y = 2 filter
+                yield filters[filters.index(EqualTo(("y",), 2))]
+
+            def partitions(self):
+                assert self.has_filter
+                return super().partitions()
+
+            def read(self, partition):
+                assert self.has_filter
+                yield [1, 1]
+                yield [1, 2]
+                yield [2, 2]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int, y int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.conf.set("spark.sql.python.filterPushdown.enabled", True)
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("test").load().filter("x = 1 and y = 2")
+        # only the y = 2 filter is applied post scan
+        assertDataFrameEqual(df, [Row(x=1, y=2), Row(x=2, y=2)])
+
+    def test_extraneous_filter(self):
+        class TestDataSourceReader(DataSourceReader):
+            def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                yield EqualTo(("x",), 1)
+
+            def partitions(self):
+                assert False
+
+            def read(self, partition):
+                assert False
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.conf.set("spark.sql.python.filterPushdown.enabled", True)

Review Comment:
   Ditto: please use `with self.sql_conf(...)` here as well instead of `self.spark.conf.set`.



##########
sql/core/src/test/scala/org/apache/spark/sql/execution/python/PythonDataSourceSuite.scala:
##########
@@ -213,6 +220,66 @@ class PythonDataSourceSuite extends PythonDataSourceSuiteBase {
       parameters = Map("inputSchema" -> "INT", "dataType" -> "\"INT\""))
   }
 
+  test("data source reader with filter pushdown") {
+    assume(shouldTestPandasUDFs)
+    val dataSourceScript =
+      s"""
+         |from pyspark.sql.datasource import (
+         |    DataSource,
+         |    DataSourceReader,
+         |    EqualTo,
+         |    InputPartition,
+         |)
+         |
+         |class SimpleDataSourceReader(DataSourceReader):
+         |    def partitions(self):
+         |        return [InputPartition(i) for i in range(2)]
+         |
+         |    def pushFilters(self, filters):
+         |        yield filters[filters.index(EqualTo(("id",), 1))]
+         |
+         |    def read(self, partition):
+         |        yield (0, partition.value)
+         |        yield (1, partition.value)
+         |        yield (2, partition.value)
+         |
+         |class SimpleDataSource(DataSource):
+         |    def schema(self):
+         |        return "id int, partition int"
+         |
+         |    def reader(self, schema):
+         |        return SimpleDataSourceReader()
+         |""".stripMargin
+    val schema = StructType.fromDDL("id INT, partition INT")
+    val dataSource =
+      createUserDefinedPythonDataSource(name = dataSourceName, pythonScript = dataSourceScript)
+    spark.conf.set(SQLConf.PYTHON_FILTER_PUSHDOWN_ENABLED, true)
+    spark.dataSource.registerPython(dataSourceName, dataSource)
+    val df =
+      spark.read.format(dataSourceName).schema(schema).load().filter("id = 1 and partition = 0")
+    val plan = df.queryExecution.executedPlan
+
+    val filter = collectFirst(df.queryExecution.executedPlan) {

Review Comment:
   Please also include the expected executed plan as a comment here.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to