anishshri-db commented on code in PR #50600:
URL: https://github.com/apache/spark/pull/50600#discussion_r2047861034


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1369,8 +1374,212 @@ def flatten_columns(cur_batch, col_name):
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
+
+        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+
+        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+
+
+class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for
+    :meth:`pyspark.sql.GroupedData.transformWithState`.
+
+    Parameters
+    ----------
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single 
ArrowRecordBatch in memory.
+    """
+
+    def __init__(self, arrow_max_records_per_batch):
+        super(TransformWithStateInPySparkRowSerializer, self).__init__()
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.key_offsets = None
+
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a 
list of data chunk, and
+        convert the data into a list of pandas.Series.
+
+        Please refer the doc of inner function `generate_data_batches` for 
more details how
+        this function works in overall.
+        """
+        from pyspark.sql.streaming.stateful_processor_util import (
+            TransformWithStateInPySparkFuncMode,
+        )
+        import itertools
+
+        def generate_data_batches(batches):
+            """
+            Deserialize ArrowRecordBatches and return a generator of Row.
+
+            The deserialization logic assumes that Arrow RecordBatches contain 
the data with the
+            ordering that data chunks for same grouping key will appear 
sequentially.
+
+            This function must avoid materializing multiple Arrow 
RecordBatches into memory at the
+            same time. And data chunks from the same grouping key should 
appear sequentially.
+            """
+            for batch in batches:
+                DataRow = Row(*(batch.column_names))
+
+                # This is supposed to be the same.
+                batch_key = tuple(batch[o][0].as_py() for o in 
self.key_offsets)
+                for row_idx in range(batch.num_rows):
+                    row = DataRow(
+                        *(batch.column(i)[row_idx].as_py() for i in 
range(batch.num_columns))
+                    )
+                    yield (batch_key, row)
+
+        _batches = super(ArrowStreamUDFSerializer, self).load_stream(stream)
+        data_batches = generate_data_batches(_batches)
+
+        for k, g in groupby(data_batches, key=lambda x: x[0]):
+            chained = itertools.chain(g)
+            chained_values = map(lambda x: x[1], chained)
+            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, 
chained_values)
+
+        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+
+        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+
+    def dump_stream(self, iterator, stream):
+        """
+        Read through an iterator of (iterator of Row), serialize them to Arrow
+        RecordBatches, and write batches to stream.
+        """
+        import pyarrow as pa
+
+        def flatten_iterator():
+            # iterator: iter[list[(iter[Row], pdf_type)]]
+            for packed in iterator:
+                iter_row_with_type = packed[0]
+                iter_row = iter_row_with_type[0]
+                pdf_type = iter_row_with_type[1]
+
+                rows_as_dict = []
+                for row in iter_row:
+                    row_as_dict = row.asDict(True)
+                    rows_as_dict.append(row_as_dict)
+
+                pdf_schema = pa.schema(pdf_type.fields)
+                record_batch = pa.RecordBatch.from_pylist(rows_as_dict, 
schema=pdf_schema)
+
+                yield (record_batch, pdf_type)
+
+        return ArrowStreamUDFSerializer.dump_stream(self, flatten_iterator(), 
stream)
+
+
+class 
TransformWithStateInPySparkRowInitStateSerializer(TransformWithStateInPySparkRowSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for
+    
:meth:`pyspark.sql.GroupedData.transformWithStateInPySparkRowInitStateSerializer`.
+    Parameters
+    ----------
+    Same as input parameters in TransformWithStateInPySparkRowSerializer.
+    """
+
+    def __init__(self, arrow_max_records_per_batch):
+        super(TransformWithStateInPySparkRowInitStateSerializer, 
self).__init__(
+            arrow_max_records_per_batch
+        )
+        self.init_key_offsets = None
+
+    def load_stream(self, stream):
+        import itertools
+        import pyarrow as pa
+        from pyspark.sql.streaming.stateful_processor_util import (
+            TransformWithStateInPySparkFuncMode,
+        )
+
+        def generate_data_batches(batches):
+            """
+            Deserialize ArrowRecordBatches and return a generator of Row.
+            The deserialization logic assumes that Arrow RecordBatches contain 
the data with the
+            ordering that data chunks for same grouping key will appear 
sequentially.
+            See `TransformWithStateInPySparkPythonInitialStateRunner` for 
arrow batch schema sent
+             from JVM.
+            This function flatten the columns of input rows and initial state 
rows and feed them

Review Comment:
   nit: `flattens the columns of`



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1369,8 +1374,212 @@ def flatten_columns(cur_batch, col_name):
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
+
+        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+
+        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+
+
+class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for
+    :meth:`pyspark.sql.GroupedData.transformWithState`.
+
+    Parameters
+    ----------
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single 
ArrowRecordBatch in memory.
+    """
+
+    def __init__(self, arrow_max_records_per_batch):
+        super(TransformWithStateInPySparkRowSerializer, self).__init__()
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.key_offsets = None
+
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a 
list of data chunk, and
+        convert the data into a list of pandas.Series.

Review Comment:
   There is no `Pandas` here anymore, right?



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -380,25 +388,29 @@ def init(self, handle: StatefulProcessorHandle) -> None:
     def handleInputRows(
         self,
         key: Any,
-        rows: Iterator["PandasDataFrameLike"],
+        rows: Union[Iterator["PandasDataFrameLike"], Iterator[Row]],
         timerValues: TimerValues,
-    ) -> Iterator["PandasDataFrameLike"]:
+    ) -> Union[Iterator["PandasDataFrameLike"], Iterator[Row]]:
         """
         Function that will allow users to interact with input data rows along 
with the grouping key.
-        It should take parameters (key, Iterator[`pandas.DataFrame`]) and 
return another
-        Iterator[`pandas.DataFrame`]. For each group, all columns are passed 
together as
-        `pandas.DataFrame` to the function, and the returned 
`pandas.DataFrame` across all
-        invocations are combined as a :class:`DataFrame`. Note that the 
function should not make a
-        guess of the number of elements in the iterator. To process all data, 
the `handleInputRows`
-        function needs to iterate all elements and process them. On the other 
hand, the
-        `handleInputRows` function is not strictly required to iterate through 
all elements in the
-        iterator if it intends to read a part of data.
+
+        The type of input data and return are different by which method is 
called.

Review Comment:
   nit: same/similar as above



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -359,6 +359,14 @@ class StatefulProcessor(ABC):
     Class that represents the arbitrary stateful logic that needs to be 
provided by the user to
     perform stateful manipulations on keyed streams.
 
+    NOTE: The type of input data and return are different by which method is 
called, following:

Review Comment:
   nit: `type of input data and return type are different based on which method is called, such as`



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1369,8 +1374,212 @@ def flatten_columns(cur_batch, col_name):
         data_batches = generate_data_batches(_batches)
 
         for k, g in groupby(data_batches, key=lambda x: x[0]):
-            yield (TransformWithStateInPandasFuncMode.PROCESS_DATA, k, g)
+            yield (TransformWithStateInPySparkFuncMode.PROCESS_DATA, k, g)
+
+        yield (TransformWithStateInPySparkFuncMode.PROCESS_TIMER, None, None)
+
+        yield (TransformWithStateInPySparkFuncMode.COMPLETE, None, None)
+
+
+class TransformWithStateInPySparkRowSerializer(ArrowStreamUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for
+    :meth:`pyspark.sql.GroupedData.transformWithState`.
+
+    Parameters
+    ----------
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single 
ArrowRecordBatch in memory.
+    """
+
+    def __init__(self, arrow_max_records_per_batch):
+        super(TransformWithStateInPySparkRowSerializer, self).__init__()
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.key_offsets = None
+
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a 
list of data chunk, and

Review Comment:
   nit: `list of data chunks`



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to