anishshri-db commented on code in PR #47133:
URL: https://github.com/apache/spark/pull/47133#discussion_r1680002466


##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -358,6 +362,141 @@ def applyInPandasWithState(
         )
         return DataFrame(jdf, self.session)
 
+
+    def transformWithStateInPandas(self, 
+            stateful_processor: StatefulProcessor,
+            outputStructType: Union[StructType, str],
+            outputMode: str,
+            timeMode: str) -> DataFrame:
+        """
+        Invokes methods defined in the stateful processor used in arbitrary state API v2.
+        We allow the user to act on per-group set of input rows along with keyed state and the
+        user can choose to output/return 0 or more rows.
+
+        For a streaming dataframe, we will repeatedly invoke the interface methods for new rows
+        in each trigger and the user's state/state variables will be stored persistently across
+        invocations.
+
+        The `stateful_processor` should be a Python class that implements the interface defined in
+        pyspark.sql.streaming.stateful_processor. The stateful processor consists 3 functions:
+        `init`, `handleInputRows`, and `close`.

Review Comment:
   we also optionally have `handleInitialState`
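
   A minimal sketch of the interface the docstring describes, with `handleInitialState` included as an optional hook. It assumes plain-Python signatures: `StatefulProcessorSketch`, `RunningCount`, and every parameter name below are hypothetical stand-ins, not the actual classes in `pyspark.sql.streaming.stateful_processor`, and the toy processor keeps its counts in a dict rather than in Spark-managed state.

```python
from abc import ABC, abstractmethod
from typing import Any, Dict, Iterator, Tuple


class StatefulProcessorSketch(ABC):
    """Hypothetical stand-in for the interface described in the docstring."""

    @abstractmethod
    def init(self, handle: Any) -> None:
        """Called once before any input; a real processor would register state here."""

    @abstractmethod
    def handleInputRows(self, key: Any, rows: Iterator[Any]) -> Iterator[Any]:
        """Called per grouping key with that key's new rows; may yield 0 or more rows."""

    def handleInitialState(self, key: Any, initial_state: Any) -> None:
        """Optional hook for seeding state; a no-op unless a subclass overrides it."""

    @abstractmethod
    def close(self) -> None:
        """Called once when processing finishes, to release resources."""


class RunningCount(StatefulProcessorSketch):
    """Toy processor: per-key row counts in a plain dict, not Spark-managed state."""

    def init(self, handle: Any) -> None:
        self.counts: Dict[Any, int] = {}

    def handleInputRows(self, key: Any, rows: Iterator[Any]) -> Iterator[Tuple[Any, int]]:
        # Count this batch's rows and emit the running total for the key.
        self.counts[key] = self.counts.get(key, 0) + sum(1 for _ in rows)
        yield (key, self.counts[key])

    def close(self) -> None:
        self.counts.clear()
```

   Making `handleInitialState` a concrete no-op (rather than abstract) is one way to keep it optional: processors that never receive an initial state do not have to implement it.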



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -33,6 +36,7 @@
         PandasCogroupedMapFunction,
         ArrowGroupedMapFunction,
         ArrowCogroupedMapFunction,
+        DataFrameLike as PandasDataFrameLike

Review Comment:
   why do we need to add this?
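
   One plausible reason, offered as an assumption: the new `transformWithStateUDF` annotates its arguments with the string `"PandasDataFrameLike"`, and the hunk context suggests these aliases sit in an import block that only type checkers evaluate. A minimal sketch of that pattern, using a hypothetical function `passthrough` and `pandas` as the type-only import:

```python
from typing import TYPE_CHECKING, Iterator

if TYPE_CHECKING:
    # Evaluated only by static type checkers (mypy etc.), never at runtime,
    # so importing this module does not require the dependency.
    import pandas as pd


def passthrough(batches: Iterator["pd.DataFrame"]) -> Iterator["pd.DataFrame"]:
    # The string annotations are never resolved at runtime, so this runs without pandas.
    for batch in batches:
        yield batch
```

   Without the type-only import, a checker would flag the quoted name as undefined even though the code runs.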



##########
python/pyspark/sql/streaming/state_api_client.py:
##########
@@ -0,0 +1,142 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from enum import Enum
+import os
+import socket
+from typing import Any, Union, cast
+
+import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+from pyspark.serializers import write_int, read_int, UTF8Deserializer
+from pyspark.sql.types import StructType, _parse_datatype_string
+
+
+class StatefulProcessorHandleState(Enum):
+    CREATED = 1
+    INITIALIZED = 2
+    DATA_PROCESSED = 3
+    CLOSED = 4
+
+

Review Comment:
   is this a pattern for PySpark in general?
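
   Whether this is an established PySpark convention is not settled in this thread. As a general Python pattern, an `Enum` like this pulls its weight when something actually enforces the lifecycle. A minimal sketch, assuming the intended order is init, then any number of data-processing calls, then close; `LifecycleTracker` and the transition table are hypothetical:

```python
from enum import Enum


class HandleState(Enum):
    # Same members as the enum in the diff.
    CREATED = 1
    INITIALIZED = 2
    DATA_PROCESSED = 3
    CLOSED = 4


# Assumed legal transitions: init -> process (repeatable) -> close.
_ALLOWED = {
    HandleState.CREATED: {HandleState.INITIALIZED},
    HandleState.INITIALIZED: {HandleState.DATA_PROCESSED, HandleState.CLOSED},
    HandleState.DATA_PROCESSED: {HandleState.DATA_PROCESSED, HandleState.CLOSED},
    HandleState.CLOSED: set(),
}


class LifecycleTracker:
    def __init__(self) -> None:
        self.state = HandleState.CREATED

    def advance(self, new_state: HandleState) -> None:
        # Reject out-of-order calls instead of silently recording them.
        if new_state not in _ALLOWED[self.state]:
            raise RuntimeError(f"illegal transition {self.state.name} -> {new_state.name}")
        self.state = new_state
```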



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -358,6 +362,55 @@ def applyInPandasWithState(
         )
         return DataFrame(jdf, self.session)
 
+
+    def transformWithStateInPandas(self, 
+            stateful_processor: StatefulProcessor,
+            outputStructType: Union[StructType, str],
+            outputMode: str,
+            timeMode: str) -> DataFrame:
+        
+        from pyspark.sql import GroupedData
+        from pyspark.sql.functions import pandas_udf
+        assert isinstance(self, GroupedData)
+
+        def transformWithStateUDF(state_api_client: StateApiClient, key: Any,
+                                  inputRows: Iterator["PandasDataFrameLike"]) -> Iterator["PandasDataFrameLike"]:
+            handle = StatefulProcessorHandle(state_api_client)
+
+            print(f"checking handle state: {state_api_client.handle_state}")

Review Comment:
   @bogao007 - should we remove these now?
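
   If some of this output is worth keeping for debugging, one option is Python's standard `logging` module instead of `print`. A minimal sketch with a hypothetical helper `log_handle_state`; how PySpark workers surface Python log records is not established here:

```python
import logging

# Module-level logger. Under Python's default configuration the effective level
# is WARNING, so these DEBUG records are discarded unless someone opts in.
logger = logging.getLogger(__name__)


def log_handle_state(handle_state: object) -> None:
    # %-style arguments defer formatting until the record is handled,
    # and skip it entirely when DEBUG is disabled.
    logger.debug("checking handle state: %s", handle_state)
```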



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1116,3 +1121,88 @@ def init_stream_yield_batches(batches):
         batches_to_write = init_stream_yield_batches(serialize_batches())
 
         return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)
+
+
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for transformWithStateInPandasSerializer.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
+    """
+
+    def __init__(
+            self,
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            arrow_max_records_per_batch):
+        super(
+            TransformWithStateInPandasSerializer,
+            self
+        ).__init__(timezone, safecheck, assign_cols_by_name)
+
+        # self.state_server_port = state_server_port
+
+        # # open client connection to state server socket
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.key_offsets = None
+
+    # Nothing special here, we need to create the handle and read
+    # data in groups.
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a list of pair
+        (data chunk, state), and convert the data into a list of pandas.Series.
+
+        Please refer the doc of inner function `gen_data_and_state` for more details how
+        this function works in overall.
+
+        In addition, this function further groups the return of `gen_data_and_state` by the state
+        instance (same semantic as grouping by grouping key) and produces an iterator of data
+        chunks for each group, so that the caller can lazily materialize the data chunk.
+        """
+        import pyarrow as pa
+        from itertools import tee
+
+        def generate_data_batches(batches):
+            for batch in batches:
+                data_pandas = [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
+                key_series = [data_pandas[o] for o in self.key_offsets]
+                batch_key = tuple(s[0] for s in key_series)
+                yield (batch_key, data_pandas)
+
+        print("Generating data batches...")

Review Comment:
   nit: do we intend to keep these?



##########
python/pyspark/sql/streaming/state_api_client.py:
##########
@@ -0,0 +1,142 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from enum import Enum
+import os
+import socket
+from typing import Any, Union, cast
+
+import pyspark.sql.streaming.StateMessage_pb2 as stateMessage
+from pyspark.serializers import write_int, read_int, UTF8Deserializer
+from pyspark.sql.types import StructType, _parse_datatype_string
+
+

Review Comment:
   nit: extra newline?



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1116,3 +1121,88 @@ def init_stream_yield_batches(batches):
         batches_to_write = init_stream_yield_batches(serialize_batches())
 
         return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)
+
+
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for transformWithStateInPandasSerializer.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
+    """
+
+    def __init__(

Review Comment:
   Is this something that each operator needs to define? Could we consolidate the code with other operators, if possible?



##########
python/pyspark/sql/streaming/StateMessage.proto:
##########
@@ -0,0 +1,86 @@
+syntax = "proto3";
+
+package pyspark.sql.streaming;
+
+message StateRequest {

Review Comment:
   is it possible to add some high-level comments here or in some other Python file?
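
   As an example of the kind of high-level comment being asked for, here is a module docstring paired with a framing sketch. The protocol it describes is an assumption inferred from the imports in `state_api_client.py` (`socket`, `write_int`, `read_int`, and the generated `StateMessage_pb2`): each serialized `StateRequest` is sent with a 4-byte big-endian length prefix. `write_message` and `read_message` are hypothetical helpers, not existing PySpark functions.

```python
"""Client side of a (hypothetical) Python-to-JVM state protocol.

Each request is a protobuf message (e.g. StateRequest) serialized to bytes and
framed as: 4-byte big-endian signed length, followed by that many payload bytes.
Responses are assumed to use the same framing.
"""
import io
import struct
from typing import BinaryIO


def write_message(stream: BinaryIO, payload: bytes) -> None:
    # Length prefix first, so the reader knows how many bytes to expect.
    stream.write(struct.pack("!i", len(payload)))
    stream.write(payload)


def read_message(stream: BinaryIO) -> bytes:
    header = stream.read(4)
    if len(header) < 4:
        raise EOFError("stream closed before a length prefix was read")
    (length,) = struct.unpack("!i", header)
    payload = stream.read(length)
    if len(payload) < length:
        raise EOFError("stream closed in the middle of a message")
    return payload


if __name__ == "__main__":
    # Round trip through an in-memory buffer standing in for a socket stream.
    buf = io.BytesIO()
    write_message(buf, b"serialized-request")
    buf.seek(0)
    print(read_message(buf))
```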



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1116,3 +1121,88 @@ def init_stream_yield_batches(batches):
         batches_to_write = init_stream_yield_batches(serialize_batches())
 
         return ArrowStreamSerializer.dump_stream(self, batches_to_write, stream)
+
+
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for transformWithStateInPandasSerializer.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single ArrowRecordBatch in memory.
+    """
+
+    def __init__(
+            self,
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            arrow_max_records_per_batch):
+        super(
+            TransformWithStateInPandasSerializer,
+            self
+        ).__init__(timezone, safecheck, assign_cols_by_name)
+
+        # self.state_server_port = state_server_port
+
+        # # open client connection to state server socket
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.key_offsets = None
+
+    # Nothing special here, we need to create the handle and read
+    # data in groups.
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a list of pair
+        (data chunk, state), and convert the data into a list of pandas.Series.
+
+        Please refer the doc of inner function `gen_data_and_state` for more details how
+        this function works in overall.
+
+        In addition, this function further groups the return of `gen_data_and_state` by the state
+        instance (same semantic as grouping by grouping key) and produces an iterator of data
+        chunks for each group, so that the caller can lazily materialize the data chunk.
+        """
+        import pyarrow as pa
+        from itertools import tee
+
+        def generate_data_batches(batches):
+            for batch in batches:
+                data_pandas = [self.arrow_to_pandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
+                key_series = [data_pandas[o] for o in self.key_offsets]
+                batch_key = tuple(s[0] for s in key_series)
+                yield (batch_key, data_pandas)
+
+        print("Generating data batches...")
+        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+        data_batches = generate_data_batches(_batches)
+
+        print("Returning data batches...")
+        for k, g in groupby(data_batches, key=lambda x: x[0]):
+            yield (k, g)
+
+
+    def dump_stream(self, iterator, stream):
+        """
+        Read through an iterator of (iterator of pandas DataFrame, state), serialize them to Arrow
+        RecordBatches, and write batches to stream.
+        """
+        result = [(b, t) for x in iterator for y, t in x for b in y]    
+        super().dump_stream(result, stream)
+    
+class ImplicitGroupingKeyTracker:

Review Comment:
   yea +1 - can prob remove this one?
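
   On the grouping in the `load_stream` hunk above: `itertools.groupby` merges only adjacent items with equal keys, so the per-key grouping is correct only if all batches for a key arrive contiguously, which this sketch assumes the upstream operator guarantees. (The hunk shown imports only `tee` locally; `groupby` presumably comes from a module-level import outside the diff.) A minimal sketch with a hypothetical `group_batches_by_key`:

```python
from itertools import groupby
from typing import Any, Iterable, Iterator, List, Tuple

Batch = Tuple[Any, List[Any]]  # (grouping key, column data), as yielded in the hunk


def group_batches_by_key(batches: Iterable[Batch]) -> Iterator[Tuple[Any, Iterator[Batch]]]:
    # groupby only merges *adjacent* items with equal keys; a key that reappears
    # later in the stream starts a new, separate group.
    for key, group in groupby(batches, key=lambda kv: kv[0]):
        yield key, group
```

   Each `group` iterator is invalidated once the outer loop advances, so a caller has to consume a key's batches before requesting the next key; that is what lets per-group data be materialized lazily without buffering the whole stream.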



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

