allisonwang-db commented on code in PR #49961:
URL: https://github.com/apache/spark/pull/49961#discussion_r1972421058


##########
python/pyspark/sql/datasource.py:
##########
@@ -280,6 +306,9 @@ class DataSourceReader(ABC):
 
     .. versionadded: 4.0.0
     """
 
+    def pushdownFilters(self, filters: List["Filter"]) -> Iterable["Filter"]:

Review Comment:
   Does it follow the same naming convention as DSv2?
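For context, DSv2's `SupportsPushDownFilters#pushFilters` takes the candidate filters and returns the filters that still need to be evaluated after scanning. A minimal sketch of a Python reader method that mirrors that convention; the name `pushFilters` and the default body are illustrative, not what the PR currently has:

```python
from abc import ABC
from dataclasses import dataclass
from typing import Iterable, List


@dataclass(frozen=True)
class Filter(ABC):
    """Stand-in for the Filter base class proposed in this PR."""


class DataSourceReader(ABC):
    def pushFilters(self, filters: List[Filter]) -> Iterable[Filter]:
        # Return the filters this reader cannot evaluate. The default accepts
        # nothing, so Spark keeps applying every filter after the scan.
        return filters
```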
##########
python/pyspark/sql/streaming/python_streaming_source_runner.py:
##########
@@ -69,7 +70,7 @@ def latest_offset_func(reader: DataSourceStreamReader, outfile: IO) -> None:
 
 def partitions_func(
     reader: DataSourceStreamReader,
-    data_source: DataSource,
+    data_source_name: str,

Review Comment:
   why do we need this change?


##########
python/pyspark/sql/worker/data_source_pushdown_filters.py:
##########
@@ -0,0 +1,104 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from typing import IO, List
+
+from pyspark.errors import PySparkAssertionError, PySparkValueError
+from pyspark.serializers import (
+    UTF8Deserializer,
+    read_int,
+    write_int,
+)
+from pyspark.sql.datasource import DataSourceReader, EqualTo, Filter
+from pyspark.sql.worker.internal.data_source_reader_info import DataSourceReaderInfo
+from pyspark.sql.worker.internal.data_source_worker import worker_main
+from pyspark.worker_util import (
+    read_command,
+    pickleSer,
+)
+
+utf8_deserializer = UTF8Deserializer()
+
+
+@worker_main
+def main(infile: IO, outfile: IO) -> None:
+    # Receive the data source reader instance.
+    reader_info = read_command(pickleSer, infile)
+    if not isinstance(reader_info, DataSourceReaderInfo):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "a Python data source reader info of type 'DataSourceReaderInfo'",
+                "actual": f"'{type(reader_info).__name__}'",
+            },
+        )
+
+    reader = reader_info.reader
+    if not isinstance(reader, DataSourceReader):
+        raise PySparkAssertionError(
+            errorClass="DATA_SOURCE_TYPE_MISMATCH",
+            messageParameters={
+                "expected": "a Python data source reader of type 'DataSourceReader'",
+                "actual": f"'{type(reader_info).__name__}'",
+            },
+        )
+
+    # Receive the pushdown filters.
+    num_filters = read_int(infile)
+    filters: List[Filter] = []
+    for _ in range(num_filters):
+        name = utf8_deserializer.loads(infile)
+        if name == "EqualTo":

Review Comment:
   This seems fragile
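For illustration, one way to make this dispatch less string-coupled is to drive it from a single registry of supported filter classes, so the name check and the constructor cannot drift apart. A hypothetical sketch; names such as `SUPPORTED_FILTERS` and `parse_filter` are not from the PR:

```python
from dataclasses import dataclass
from typing import Any, Dict, Tuple, Type

ColumnPath = Tuple[str, ...]


@dataclass(frozen=True)
class Filter:
    """Simplified stand-in for the Filter base class in this PR."""


@dataclass(frozen=True)
class EqualTo(Filter):
    attribute: ColumnPath
    value: Any


# Single source of truth for the filter types the Python worker understands.
# The JVM side sends a class name; the worker looks it up here instead of
# hard-coding string comparisons inline.
SUPPORTED_FILTERS: Dict[str, Type[Filter]] = {cls.__name__: cls for cls in (EqualTo,)}


def parse_filter(name: str, *args: Any) -> Filter:
    try:
        return SUPPORTED_FILTERS[name](*args)
    except KeyError:
        raise ValueError(f"unsupported filter type: {name}") from None
```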
##########
python/pyspark/sql/tests/test_python_datasource.py:
##########
@@ -246,6 +248,137 @@ def reader(self, schema) -> "DataSourceReader":
         assertDataFrameEqual(df, [Row(x=0, y="0"), Row(x=1, y="1")])
         self.assertEqual(df.select(spark_partition_id()).distinct().count(), 2)
 
+    def test_filter_pushdown(self):
+        class TestDataSourceReader(DataSourceReader):
+            def __init__(self):
+                self.has_filter = False
+
+            def pushdownFilters(self, filters: List[Filter]) -> Iterable[Filter]:
+                assert set(filters) == {
+                    EqualTo(("x",), 1),
+                    EqualTo(("y",), 2),
+                }, filters
+                self.has_filter = True
+                # pretend we support x = 1 filter but in fact we don't
+                # so we only return y = 2 filter
+                yield filters[filters.index(EqualTo(("y",), 2))]
+
+            def partitions(self):
+                assert self.has_filter
+                return super().partitions()
+
+            def read(self, partition):
+                assert self.has_filter
+                yield [1, 1]
+                yield [1, 2]
+                yield [2, 2]
+
+        class TestDataSource(DataSource):
+            @classmethod
+            def name(cls):
+                return "test"
+
+            def schema(self):
+                return "x int, y int"
+
+            def reader(self, schema) -> "DataSourceReader":
+                return TestDataSourceReader()
+
+        self.spark.dataSource.register(TestDataSource)
+        df = self.spark.read.format("test").load().filter("x = 1 and y = 2")

Review Comment:
   I think we should also add a Scala-side test to check the query plan. Please see PythonDataSourceSuite.scala


##########
python/pyspark/sql/datasource.py:
##########
@@ -234,6 +249,35 @@ def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
         )
 
+ColumnPath = Tuple[str, ...]
+
+
+@dataclass(frozen=True)
+class Filter(ABC):
+    """
+    The base class for filters used for filter pushdown.
+
+    .. versionadded: 4.1.0
+

Review Comment:
   Let's also add some examples here


##########
python/pyspark/sql/datasource.py:
##########
@@ -234,6 +249,17 @@ def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
         )
 
+@dataclass(frozen=True)
+class Filter(ABC):
+    pass

Review Comment:
   Can we add examples and a docstring here? Let's also mention the limitations here (currently only EqualTo is supported).


##########
python/pyspark/sql/datasource.py:
##########
@@ -234,6 +249,35 @@ def streamReader(self, schema: StructType) -> "DataSourceStreamReader":
         )
 
+ColumnPath = Tuple[str, ...]
+
+
+@dataclass(frozen=True)
+class Filter(ABC):
+    """
+    The base class for filters used for filter pushdown.
+
+    .. versionadded: 4.1.0
+
+    Notes
+    -----
+    Column references are represented as a tuple of strings. For example, the column
+    `col1` is represented as `("col1",)`, and the nested column `a.b.c` is
+    represented as `("a", "b", "c")`.
+
+    Literal values are represented as Python objects of types such as
+    `int`, `float`, `str`, `bool`, `datetime`, etc.
+    See `Data Types <https://spark.apache.org/docs/latest/sql-ref-datatypes.html>`_
+    for more information about how values are represented in Python.
+    """
+
+
+@dataclass(frozen=True)
+class EqualTo(Filter):
+    lhsColumnPath: ColumnPath
+    rhsValue: Any

Review Comment:
   Can we follow the DSv1 filter design here: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala#L98
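For reference, the linked DSv1 filter is `case class EqualTo(attribute: String, value: Any)`. A sketch of what the Python class could look like if it keeps this PR's tuple-based column paths but follows the DSv1 field naming (illustrative only, not the PR's final shape):

```python
from abc import ABC
from dataclasses import dataclass
from typing import Any, Tuple

ColumnPath = Tuple[str, ...]


@dataclass(frozen=True)
class Filter(ABC):
    """Base class for filters used for filter pushdown."""


@dataclass(frozen=True)
class EqualTo(Filter):
    # DSv1 names the fields (attribute, value); the attribute here stays a
    # tuple so nested columns like a.b.c can be written as ("a", "b", "c").
    attribute: ColumnPath
    value: Any


# The predicate `x = 1` would then be represented as:
eq = EqualTo(("x",), 1)
```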
+ """ + + +@dataclass(frozen=True) +class EqualTo(Filter): + lhsColumnPath: ColumnPath + rhsValue: Any Review Comment: Can we follow the DSv1 filter design here: https://github.com/apache/spark/blob/master/sql/catalyst/src/main/scala/org/apache/spark/sql/sources/filters.scala#L98 ########## sql/core/src/main/scala/org/apache/spark/sql/execution/datasources/v2/python/PythonScanBuilder.scala: ########## @@ -25,6 +26,24 @@ class PythonScanBuilder( ds: PythonDataSourceV2, shortName: String, outputSchema: StructType, - options: CaseInsensitiveStringMap) extends ScanBuilder { + options: CaseInsensitiveStringMap) + extends ScanBuilder + with SupportsPushDownFilters { + private var supported: Array[Filter] = Array.empty + override def build(): Scan = new PythonScan(ds, shortName, outputSchema, options) + + override def pushFilters(filters: Array[Filter]): Array[Filter] = { + val dataSource = ds.getOrCreateDataSourceInPython(shortName, options, Some(outputSchema)) + val result = ds.source.pushdownFiltersInPython(dataSource, outputSchema, filters) + ds.setDataSourceInPython(dataSource.copy(dataSource = result.dataSource)) Review Comment: So here we are updating the serialized data source instance to get the new states stored in the data source instance. Can we add some comments here (e.g when will the pushedFilters be invoked during analysis, and why do we need to set the data source instance here) ########## python/pyspark/sql/worker/internal/data_source_worker.py: ########## @@ -0,0 +1,90 @@ +# +# Licensed to the Apache Software Foundation (ASF) under one or more +# contributor license agreements. See the NOTICE file distributed with +# this work for additional information regarding copyright ownership. +# The ASF licenses this file to You under the Apache License, Version 2.0 +# (the "License"); you may not use this file except in compliance with +# the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# + +import faulthandler +from functools import wraps +import os +import sys +from typing import IO, Callable + +from pyspark.accumulators import _accumulatorRegistry +from pyspark.serializers import ( + read_int, + write_int, + SpecialLengths, +) +from pyspark.util import handle_worker_exception, local_connect_and_auth +from pyspark.worker_util import ( + check_python_version, + send_accumulator_updates, + setup_broadcasts, + setup_memory_limits, + setup_spark_files, +) + + +F = Callable[[IO, IO], None] + + +def worker_main(func: F) -> F: Review Comment: Let's avoid refactoring and cleaning up code in this PR :) -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org