sahnib commented on code in PR #47133:
URL: https://github.com/apache/spark/pull/47133#discussion_r1679584589


##########
python/pyspark/worker.py:
##########
@@ -1609,6 +1645,35 @@ def mapper(a):
             vals = [a[o] for o in parsed_offsets[0][1]]
             return f(keys, vals)
 
+    elif eval_type == PythonEvalType.SQL_TRANSFORM_WITH_STATE:
+        from itertools import tee
+
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandasExec for how arg_offsets are used to

Review Comment:
   [nit] This comment should say "See TransformWithStateInPandasExec" rather than FlatMapGroupsInPandasExec. 



##########
python/pyspark/sql/streaming/stateful_processor.py:
##########
@@ -0,0 +1,101 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+from abc import ABC, abstractmethod
+from typing import Any, TYPE_CHECKING, Iterator, Union
+
+from pyspark.sql.streaming.state_api_client import StateApiClient
+from pyspark.sql.streaming.value_state_client import ValueStateClient
+
+import pandas as pd
+from pyspark.sql.types import (
+    StructType, StructField, IntegerType, LongType, ShortType,
+    FloatType, DoubleType, DecimalType, StringType, BooleanType,
+    DateType, TimestampType
+)
+
+if TYPE_CHECKING:
+    from pyspark.sql.pandas._typing import DataFrameLike as PandasDataFrameLike
+
+
+class ValueState:

Review Comment:
   We need docstrings here, since users will refer to these classes to 
understand how the different state variables work. 



##########
sql/api/src/main/scala/org/apache/spark/sql/catalyst/streaming/InternalTimeModes.scala:
##########
@@ -0,0 +1,49 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.catalyst.streaming
+
+import java.util.Locale
+
+import org.apache.spark.SparkIllegalArgumentException
+import org.apache.spark.sql.streaming.TimeMode
+
+/**
+ * Internal helper class to generate objects representing various `TimeMode`s,
+ */
+private[sql] object InternalTimeModes {

Review Comment:
   Is this mainly for parsing a String into a TimeMode? If so, should we add 
it to the `TimeMode.scala` file instead? 



##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -0,0 +1,152 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+import random
+import shutil
+import string
+import sys
+import tempfile
+import pandas as pd
+from pyspark.sql.streaming import StatefulProcessor, StatefulProcessorHandle
+from typing import Iterator
+
+import unittest
+from typing import cast
+
+from pyspark import SparkConf
+from pyspark.sql.streaming.state import GroupStateTimeout, GroupState
+from pyspark.sql.types import (
+    LongType,
+    StringType,
+    StructType,
+    StructField,
+    Row,
+)
+from pyspark.testing.sqlutils import (
+    ReusedSQLTestCase,
+    have_pandas,
+    have_pyarrow,
+    pandas_requirement_message,
+    pyarrow_requirement_message,
+)
+from pyspark.testing.utils import eventually
+
+
+@unittest.skipIf(
+    not have_pandas or not have_pyarrow,
+    cast(str, pandas_requirement_message or pyarrow_requirement_message),
+)
+class TransformWithStateInPandasTestsMixin:
+    @classmethod
+    def conf(cls):
+        cfg = SparkConf()
+        cfg.set("spark.sql.shuffle.partitions", "5")
+        cfg.set("spark.sql.streaming.stateStore.providerClass",
+                
"org.apache.spark.sql.execution.streaming.state.RocksDBStateStoreProvider")
+        return cfg
+
+    def _test_apply_in_pandas_with_state_basic(self, func, check_results):
+        input_path = tempfile.mkdtemp()
+
+        def prepare_test_resource():
+            with open(input_path + "/text-test.txt", "w") as fw:
+                fw.write("hello\n")
+                fw.write("this\n")
+
+        prepare_test_resource()
+
+        df = self.spark.readStream.format("text").load(input_path)
+
+        for q in self.spark.streams.active:
+            q.stop()
+        self.assertTrue(df.isStreaming)
+
+        output_type = StructType(
+            [StructField("key", StringType()), StructField("countAsString", 
StringType())]
+        )
+        state_type = StructType([StructField("c", LongType())])
+
+        q = (
+            df.groupBy(df["value"])
+            .transformWithStateInPandas(stateful_processor = 
SimpleStatefulProcessor(),
+                                        outputStructType=output_type,
+                                        outputMode="Update",
+                                        timeMode="None")
+            .writeStream.queryName("this_query")
+            .foreachBatch(check_results)
+            .outputMode("update")
+            .start()
+        )
+
+        self.assertEqual(q.name, "this_query")
+        self.assertTrue(q.isActive)
+        q.processAllAvailable()
+        self.assertTrue(q.exception() is None)
+
+    def test_apply_in_pandas_with_state_basic(self):
+        def func(key, pdf_iter, state):
+            assert isinstance(state, GroupState)
+
+            total_len = 0
+            for pdf in pdf_iter:
+                total_len += len(pdf)
+
+            state.update((total_len,))
+            assert state.get[0] == 1
+            yield pd.DataFrame({"key": [key[0]], "countAsString": 
[str(total_len)]})
+
+        def check_results(batch_df, _):
+            assert set(batch_df.sort("key").collect()) == {
+                Row(key="hello", countAsString="1"),
+                Row(key="this", countAsString="1"),
+            }
+
+        self._test_apply_in_pandas_with_state_basic(func, check_results)
+
+
+class SimpleStatefulProcessor(StatefulProcessor):
+  def init(self, handle: StatefulProcessorHandle) -> None:
+    state_schema = StructType([
+      StructField("value", StringType(), True)
+    ])
+    self.value_state = handle.getValueState("testValueState", state_schema)
+  def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+    self.value_state.update("test_value")
+    exists = self.value_state.exists()
+    value = self.value_state.get()
+    self.value_state.clear()

Review Comment:
   What are we testing here? Can we add a test that sets different values in 
`ValueState` and verifies that the values read back after each update are correct? 



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1116,3 +1121,88 @@ def init_stream_yield_batches(batches):
         batches_to_write = init_stream_yield_batches(serialize_batches())
 
         return ArrowStreamSerializer.dump_stream(self, batches_to_write, 
stream)
+
+
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for 
transformWithStateInPandasSerializer.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single 
ArrowRecordBatch in memory.
+    """
+
+    def __init__(
+            self,
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            arrow_max_records_per_batch):
+        super(
+            TransformWithStateInPandasSerializer,
+            self
+        ).__init__(timezone, safecheck, assign_cols_by_name)
+
+        # self.state_server_port = state_server_port
+
+        # # open client connection to state server socket

Review Comment:
   [nit] These commented-out lines look redundant; can we remove them? 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/PythonArrowInput.scala:
##########
@@ -71,17 +71,17 @@ private[python] trait PythonArrowInput[IN] { self: 
BasePythonRunner[IN, _] =>
   protected val root = VectorSchemaRoot.create(arrowSchema, allocator)
   protected var writer: ArrowStreamWriter = _
 
-protected def close(): Unit = {
-  Utils.tryWithSafeFinally {
-    // end writes footer to the output stream and doesn't clean any resources.
-    // It could throw exception if the output stream is closed, so it should be
-    // in the try block.
-    writer.end()
-  } {
-    root.close()
-    allocator.close()
+  protected def close(): Unit = {

Review Comment:
   Thanks for fixing the indentation here. 



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -358,6 +362,141 @@ def applyInPandasWithState(
         )
         return DataFrame(jdf, self.session)
 
+
+    def transformWithStateInPandas(self, 
+            stateful_processor: StatefulProcessor,
+            outputStructType: Union[StructType, str],
+            outputMode: str,
+            timeMode: str) -> DataFrame:
+        """
+        Invokes methods defined in the stateful processor used in arbitrary 
state API v2.
+        We allow the user to act on per-group set of input rows along with 
keyed state and the
+        user can choose to output/return 0 or more rows.
+
+        For a streaming dataframe, we will repeatedly invoke the interface 
methods for new rows
+        in each trigger and the user's state/state variables will be stored 
persistently across
+        invocations.
+
+        The `stateful_processor` should be a Python class that implements the 
interface defined in
+        pyspark.sql.streaming.stateful_processor. The stateful processor 
consists 3 functions:
+        `init`, `handleInputRows`, and `close`.
+
+        The `init` function will be invoked as the first method that allows 
for users to initialize
+        all their state variables and perform other init actions before 
handling data.
+
+        The `handleInputRows` function will allow users to interact with input 
data rows. It should
+        take parameters (key, Iterator[`pandas.DataFrame`]) and return another
+        Iterator[`pandas.DataFrame`]. For each group, all columns are passed 
together as
+        `pandas.DataFrame` to the `handleInputRows` function, and the returned 
`pandas.DataFrame`
+        across all invocations are combined as a :class:`DataFrame`. Note that 
the `handleInputRows`
+        function should not make a guess of the number of elements in the 
iterator. To process all
+        data, the `handleInputRows` function needs to iterate all elements and 
process them. On the
+        other hand, the `handleInputRows` function is not strictly required 
toiterate through all
+        elements in the iterator if it intends to read a part of data.

Review Comment:
   Thanks for adding the documentation. 
   
   [nit] `toiterate` -> `to iterate`. 
   Should we keep the documentation for the StatefulProcessor class where it is 
defined? Here we can just mention that the user needs to extend the abstract class 
StatefulProcessor.  
   



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1116,3 +1121,88 @@ def init_stream_yield_batches(batches):
         batches_to_write = init_stream_yield_batches(serialize_batches())
 
         return ArrowStreamSerializer.dump_stream(self, batches_to_write, 
stream)
+
+
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for 
transformWithStateInPandasSerializer.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single 
ArrowRecordBatch in memory.
+    """
+
+    def __init__(
+            self,
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            arrow_max_records_per_batch):
+        super(
+            TransformWithStateInPandasSerializer,
+            self
+        ).__init__(timezone, safecheck, assign_cols_by_name)
+
+        # self.state_server_port = state_server_port
+
+        # # open client connection to state server socket
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.key_offsets = None
+
+    # Nothing special here, we need to create the handle and read
+    # data in groups.
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a 
list of pair
+        (data chunk, state), and convert the data into a list of pandas.Series.
+
+        Please refer the doc of inner function `gen_data_and_state` for more 
details how
+        this function works in overall.
+
+        In addition, this function further groups the return of 
`gen_data_and_state` by the state
+        instance (same semantic as grouping by grouping key) and produces an 
iterator of data
+        chunks for each group, so that the caller can lazily materialize the 
data chunk.

Review Comment:
   It seems like this documentation describes the ApplyInPandasWithState 
serializer, which transfers both state and data. This serializer only yields 
(key, data) groups, so the docstring should be updated to match. 



##########
sql/core/src/main/java/org/apache/spark/sql/execution/streaming/StateMessage.proto:
##########
@@ -0,0 +1,86 @@
+syntax = "proto3";
+
+package org.apache.spark.sql.execution.streaming.state;

Review Comment:
   I think the PoC has this file twice too; if I recall correctly, I needed a 
package prefix for Java (but not for Python). Do you know if we can just use a 
single file, or do we need to duplicate the definitions for Java and Python? 



##########
python/pyspark/sql/pandas/group_ops.py:
##########
@@ -358,6 +362,141 @@ def applyInPandasWithState(
         )
         return DataFrame(jdf, self.session)
 
+
+    def transformWithStateInPandas(self, 
+            stateful_processor: StatefulProcessor,
+            outputStructType: Union[StructType, str],
+            outputMode: str,
+            timeMode: str) -> DataFrame:
+        """
+        Invokes methods defined in the stateful processor used in arbitrary 
state API v2.
+        We allow the user to act on per-group set of input rows along with 
keyed state and the
+        user can choose to output/return 0 or more rows.
+
+        For a streaming dataframe, we will repeatedly invoke the interface 
methods for new rows
+        in each trigger and the user's state/state variables will be stored 
persistently across
+        invocations.
+
+        The `stateful_processor` should be a Python class that implements the 
interface defined in
+        pyspark.sql.streaming.stateful_processor. The stateful processor 
consists 3 functions:
+        `init`, `handleInputRows`, and `close`.
+
+        The `init` function will be invoked as the first method that allows 
for users to initialize
+        all their state variables and perform other init actions before 
handling data.
+
+        The `handleInputRows` function will allow users to interact with input 
data rows. It should
+        take parameters (key, Iterator[`pandas.DataFrame`]) and return another
+        Iterator[`pandas.DataFrame`]. For each group, all columns are passed 
together as
+        `pandas.DataFrame` to the `handleInputRows` function, and the returned 
`pandas.DataFrame`
+        across all invocations are combined as a :class:`DataFrame`. Note that 
the `handleInputRows`
+        function should not make a guess of the number of elements in the 
iterator. To process all
+        data, the `handleInputRows` function needs to iterate all elements and 
process them. On the
+        other hand, the `handleInputRows` function is not strictly required 
toiterate through all
+        elements in the iterator if it intends to read a part of data.
+
+        The `close` function will be called as the last method that allows for 
users to perform any
+        cleanup or teardown operations.
+
+        The `outputStructType` should be a :class:`StructType` describing the 
schema of all
+        elements in the returned value, `pandas.DataFrame`. The column labels 
of all elements in
+        returned `pandas.DataFrame` must either match the field names in the 
defined schema if
+        specified as strings, or match the field data types by position if not 
strings,
+        e.g. integer indices.
+
+        The size of each `pandas.DataFrame` in both the input and output can 
be arbitrary. The
+        number of `pandas.DataFrame` in both the input and output can also be 
arbitrary.
+
+        .. versionadded:: 4.0.0
+
+        Parameters
+        ----------
+        stateful_processor : StatefulProcessor
+            Instance of statefulProcessor whose functions will be invoked by 
the operator.
+        outputStructType : :class:`pyspark.sql.types.DataType` or str
+            the type of the output records. The value can be either a
+            :class:`pyspark.sql.types.DataType` object or a DDL-formatted type 
string.
+        outputMode : str
+            the output mode of the stateful processor.
+        timeMode : str
+            The time mode semantics of the stateful processor for timers and 
TTL.
+
+        Examples
+        --------
+        >>> import pandas as pd
+        >>> from pyspark.sql.streaming import StatefulProcessor, 
StatefulProcessorHandle
+        >>> from pyspark.sql.types import StructType, StructField, LongType, 
StringType
+        >>> from typing import Iterator
+        >>> output_schema = StructType([
+        ...     StructField("value", LongType(), True)
+        ... ])
+        >>> state_schema = StructType([
+        ...     StructField("value", StringType(), True)
+        ... ])
+        >>> class SimpleStatefulProcessor(StatefulProcessor):
+        ...   def init(self, handle: StatefulProcessorHandle) -> None:
+        ...     self.value_state = handle.getValueState("testValueState", 
state_schema)
+        ...   def handleInputRows(self, key, rows) -> Iterator[pd.DataFrame]:
+        ...     self.value_state.update("test_value")
+        ...     exists = self.value_state.exists()
+        ...     value = self.value_state.get()
+        ...     self.value_state.clear()
+        ...     return rows
+        ...   def close(self) -> None:
+        ...     pass

Review Comment:
   [nit] It might be more useful to provide a running-count example that keeps 
track, in state, of values above a specified threshold (i.e., violations), for 
example by processing temperature sensor readings from a stream. 



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -1116,3 +1121,88 @@ def init_stream_yield_batches(batches):
         batches_to_write = init_stream_yield_batches(serialize_batches())
 
         return ArrowStreamSerializer.dump_stream(self, batches_to_write, 
stream)
+
+
+class TransformWithStateInPandasSerializer(ArrowStreamPandasUDFSerializer):
+    """
+    Serializer used by Python worker to evaluate UDF for 
transformWithStateInPandasSerializer.
+
+    Parameters
+    ----------
+    timezone : str
+        A timezone to respect when handling timestamp values
+    safecheck : bool
+        If True, conversion from Arrow to Pandas checks for overflow/truncation
+    assign_cols_by_name : bool
+        If True, then Pandas DataFrames will get columns by name
+    arrow_max_records_per_batch : int
+        Limit of the number of records that can be written to a single 
ArrowRecordBatch in memory.
+    """
+
+    def __init__(
+            self,
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            arrow_max_records_per_batch):
+        super(
+            TransformWithStateInPandasSerializer,
+            self
+        ).__init__(timezone, safecheck, assign_cols_by_name)
+
+        # self.state_server_port = state_server_port
+
+        # # open client connection to state server socket
+        self.arrow_max_records_per_batch = arrow_max_records_per_batch
+        self.key_offsets = None
+
+    # Nothing special here, we need to create the handle and read
+    # data in groups.
+    def load_stream(self, stream):
+        """
+        Read ArrowRecordBatches from stream, deserialize them to populate a 
list of pair
+        (data chunk, state), and convert the data into a list of pandas.Series.
+
+        Please refer the doc of inner function `gen_data_and_state` for more 
details how
+        this function works in overall.
+
+        In addition, this function further groups the return of 
`gen_data_and_state` by the state
+        instance (same semantic as grouping by grouping key) and produces an 
iterator of data
+        chunks for each group, so that the caller can lazily materialize the 
data chunk.
+        """
+        import pyarrow as pa
+        from itertools import tee
+
+        def generate_data_batches(batches):
+            for batch in batches:
+                data_pandas = [self.arrow_to_pandas(c) for c in 
pa.Table.from_batches([batch]).itercolumns()]
+                key_series = [data_pandas[o] for o in self.key_offsets]
+                batch_key = tuple(s[0] for s in key_series)
+                yield (batch_key, data_pandas)
+
+        print("Generating data batches...")
+        _batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+        data_batches = generate_data_batches(_batches)
+
+        print("Returning data batches...")
+        for k, g in groupby(data_batches, key=lambda x: x[0]):
+            yield (k, g)
+
+
+    def dump_stream(self, iterator, stream):
+        """
+        Read through an iterator of (iterator of pandas DataFrame, state), 
serialize them to Arrow
+        RecordBatches, and write batches to stream.
+        """
+        result = [(b, t) for x in iterator for y, t in x for b in y]    
+        super().dump_stream(result, stream)
+    
+class ImplicitGroupingKeyTracker:

Review Comment:
   I don't think we are using this anymore; it's probably safe to remove. 



##########
sql/catalyst/src/main/scala/org/apache/spark/sql/catalyst/encoders/ExpressionEncoder.scala:
##########
@@ -201,7 +202,7 @@ object ExpressionEncoder {
    * object.  Thus, the caller should copy the result before making another 
call if required.
    */
   class Serializer[T](private val expressions: Seq[Expression])
-    extends (T => InternalRow) with Serializable {
+    extends (T => InternalRow) with Serializable with Logging {

Review Comment:
   [nit] Is adding `with Logging` here intentional? 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.DataOutputStream
+import java.nio.file.{Files, Path}
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.ExecutionContext
+
+import jnr.unixsocket.UnixServerSocketChannel
+import jnr.unixsocket.UnixSocketAddress
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
+import 
org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType,
 OutType}
+import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Python runner implementation for TransformWithStateInPandas.
+ */
+class TransformWithStateInPandasPythonRunner(
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    _schema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType)
+  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, 
argOffsets, jobArtifactUUID)
+    with PythonArrowInput[InType]
+    with BasicPythonArrowOutput
+    with Logging {
+
+  private val sqlConf = SQLConf.get
+  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+
+  private val serverId = 
TransformWithStateInPandasStateServer.allocateServerId()
+
+  private val socketPath = s"./uds_$serverId.sock"

Review Comment:
   We should append the operator name to the socket path to ensure it does not 
conflict with other operators in the future. 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasPythonRunner.scala:
##########
@@ -0,0 +1,175 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.DataOutputStream
+import java.nio.file.{Files, Path}
+import java.util.concurrent.TimeUnit
+
+import scala.concurrent.ExecutionContext
+
+import jnr.unixsocket.UnixServerSocketChannel
+import jnr.unixsocket.UnixSocketAddress
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+
+import org.apache.spark.TaskContext
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions}
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
+import 
org.apache.spark.sql.execution.python.TransformWithStateInPandasPythonRunner.{InType,
 OutType}
+import org.apache.spark.sql.execution.streaming.StatefulProcessorHandleImpl
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.vectorized.ColumnarBatch
+import org.apache.spark.util.ThreadUtils
+
+/**
+ * Python runner implementation for TransformWithStateInPandas.
+ */
+class TransformWithStateInPandasPythonRunner(
+    funcs: Seq[(ChainedPythonFunctions, Long)],
+    evalType: Int,
+    argOffsets: Array[Array[Int]],
+    _schema: StructType,
+    processorHandle: StatefulProcessorHandleImpl,
+    _timeZoneId: String,
+    initialWorkerConf: Map[String, String],
+    override val pythonMetrics: Map[String, SQLMetric],
+    jobArtifactUUID: Option[String],
+    groupingKeySchema: StructType)
+  extends BasePythonRunner[InType, OutType](funcs.map(_._1), evalType, 
argOffsets, jobArtifactUUID)
+    with PythonArrowInput[InType]
+    with BasicPythonArrowOutput
+    with Logging {
+
+  private val sqlConf = SQLConf.get
+  private val arrowMaxRecordsPerBatch = sqlConf.arrowMaxRecordsPerBatch
+
+  private val serverId = 
TransformWithStateInPandasStateServer.allocateServerId()
+
+  private val socketPath = s"./uds_$serverId.sock"
+
+  override protected val workerConf: Map[String, String] = initialWorkerConf +
+    (SQLConf.ARROW_EXECUTION_MAX_RECORDS_PER_BATCH.key -> 
arrowMaxRecordsPerBatch.toString)
+
+  private val arrowWriter: arrow.ArrowWriter = ArrowWriter.create(root)
+
+  // Use lazy val to initialize the fields before these are accessed in 
[[PythonArrowInput]]'s
+  // constructor.
+  override protected lazy val schema: StructType = _schema
+  override protected lazy val timeZoneId: String = _timeZoneId
+  override protected val errorOnDuplicatedFieldNames: Boolean = true
+  override protected val largeVarTypes: Boolean = sqlConf.arrowUseLargeVarTypes
+
+  override protected def handleMetadataBeforeExec(stream: DataOutputStream): 
Unit = {
+    super.handleMetadataBeforeExec(stream)
+    // Also write the port number for state server
+    stream.writeInt(serverId)
+  }
+
+  override def compute(
+      inputIterator: Iterator[InType],
+      partitionIndex: Int,
+      context: TaskContext): Iterator[OutType] = {
+    var serverChannel: UnixServerSocketChannel = null
+    var failed = false
+    try {
+      val socketFile = Path.of(socketPath)
+      Files.deleteIfExists(socketFile)
+      val serverAddress = new UnixSocketAddress(socketPath)
+      serverChannel = UnixServerSocketChannel.open()
+      serverChannel.socket().bind(serverAddress)
+    } catch {
+      case e: Exception =>
+        failed = true
+        throw e
+    } finally {
+      if (failed) {
+        closeServerSocketChannelSilently(serverChannel)
+      }
+    }
+
+    val executor = 
ThreadUtils.newDaemonSingleThreadExecutor("stateConnectionListenerThread")
+    val executionContext = ExecutionContext.fromExecutor(executor)
+
+    executionContext.execute(
+      new TransformWithStateInPandasStateServer(serverChannel, processorHandle,
+        groupingKeySchema))
+
+    context.addTaskCompletionListener[Unit] { _ =>
+      logWarning(s"completion listener called")
+      executor.awaitTermination(10, TimeUnit.SECONDS)
+      executor.shutdownNow()
+      val socketFile = Path.of(socketPath)
+      Files.deleteIfExists(socketFile)
+    }
+
+    super.compute(inputIterator, partitionIndex, context)
+  }
+
+  private def closeServerSocketChannelSilently(serverChannel: 
UnixServerSocketChannel): Unit = {
+    try {
+      logWarning(s"closing the state server socket")
+      serverChannel.close()
+    } catch {
+      case e: Exception =>
+        logError(s"failed to close state server socket", e)
+    }
+  }
+
+  override protected def writeUDF(dataOut: DataOutputStream): Unit = {
+    PythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets, None)
+  }
+
+  override protected def writeNextInputToArrowStream(
+      root: VectorSchemaRoot,
+      writer: ArrowStreamWriter,
+      dataOut: DataOutputStream,
+      inputIterator: Iterator[InType]): Boolean = {
+
+    if (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      val next = inputIterator.next()
+      val nextBatch = next._2
+
+      while (nextBatch.hasNext) {
+        arrowWriter.write(nextBatch.next())
+      }

Review Comment:
   This sends all of the data for a key to the Python client at once, which 
results in the data for the entire grouping key being buffered in memory. We 
should chunk it (similar to ApplyInPandasWithStatePythonRunner). 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.nio.channels.Channels
+
+import scala.collection.mutable
+
+import com.google.protobuf.ByteString
+import jnr.unixsocket.UnixServerSocketChannel
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Encoder, Encoders, Row}
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleState}
+import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, 
ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, 
StateVariableRequest, ValueStateCall}
+import org.apache.spark.sql.streaming.ValueState
+import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, 
FloatType, IntegerType, LongType, StructType}
+
+/**
+ * This class is used to handle the state requests from the Python side.
+ */
+class TransformWithStateInPandasStateServer(
+    private val serverChannel: UnixServerSocketChannel,
+    private val statefulProcessorHandle: StatefulProcessorHandleImpl,
+    private val groupingKeySchema: StructType)
+  extends Runnable
+  with Logging{
+
+  private var inputStream: DataInputStream = _
+  private var outputStream: DataOutputStream = _
+
+  private val valueStates = mutable.HashMap[String, ValueState[Any]]()
+
+  def run(): Unit = {
+    logWarning(s"Waiting for connection from Python worker")
+    val channel = serverChannel.accept()
+    logWarning(s"listening on channel - ${channel.getLocalAddress}")
+
+    inputStream = new DataInputStream(
+      Channels.newInputStream(channel))
+    outputStream = new DataOutputStream(
+      Channels.newOutputStream(channel)
+    )
+
+    while (channel.isConnected &&
+      statefulProcessorHandle.getHandleState != 
StatefulProcessorHandleState.CLOSED) {
+
+      try {
+        logWarning(s"reading the version")
+        val version = inputStream.readInt()
+
+        if (version != -1) {
+          logWarning(s"version = ${version}")
+          assert(version == 0)
+          val messageLen = inputStream.readInt()
+          logWarning(s"parsing a message of ${messageLen} bytes")
+
+          val messageBytes = new Array[Byte](messageLen)
+          inputStream.read(messageBytes)
+          logWarning(s"read bytes = ${messageBytes.mkString("Array(", ", ", 
")")}")
+
+          val message = 
StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
+
+          logWarning(s"read message = $message")
+          handleRequest(message)
+          logWarning(s"flush output stream")
+
+          outputStream.flush()
+        }
+      } catch {
+        case _: EOFException =>
+          logWarning(s"No more data to read from the socket")
+          
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
+          return
+        case e: Exception =>
+          logWarning(s"Error reading message: ${e.getMessage}")
+          sendResponse(1, e.getMessage)
+          outputStream.flush()
+          
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
+          return
+      }
+    }
+    logWarning(s"done from the state server thread")
+  }
+
+  private def handleRequest(message: StateRequest): Unit = {

Review Comment:
   I think we should split this method into sub-methods. A good starting point 
would be to split by `methodCase`. 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.nio.channels.Channels
+
+import scala.collection.mutable
+
+import com.google.protobuf.ByteString
+import jnr.unixsocket.UnixServerSocketChannel
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Encoder, Encoders, Row}
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleState}
+import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, 
ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, 
StateVariableRequest, ValueStateCall}
+import org.apache.spark.sql.streaming.ValueState
+import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, 
FloatType, IntegerType, LongType, StructType}
+
+/**
+ * This class is used to handle the state requests from the Python side.
+ */
+class TransformWithStateInPandasStateServer(
+    private val serverChannel: UnixServerSocketChannel,
+    private val statefulProcessorHandle: StatefulProcessorHandleImpl,
+    private val groupingKeySchema: StructType)
+  extends Runnable
+  with Logging{
+
+  private var inputStream: DataInputStream = _
+  private var outputStream: DataOutputStream = _
+
+  private val valueStates = mutable.HashMap[String, ValueState[Any]]()
+
+  def run(): Unit = {
+    logWarning(s"Waiting for connection from Python worker")
+    val channel = serverChannel.accept()
+    logWarning(s"listening on channel - ${channel.getLocalAddress}")
+
+    inputStream = new DataInputStream(
+      Channels.newInputStream(channel))
+    outputStream = new DataOutputStream(
+      Channels.newOutputStream(channel)
+    )
+
+    while (channel.isConnected &&
+      statefulProcessorHandle.getHandleState != 
StatefulProcessorHandleState.CLOSED) {
+
+      try {
+        logWarning(s"reading the version")
+        val version = inputStream.readInt()
+
+        if (version != -1) {
+          logWarning(s"version = ${version}")
+          assert(version == 0)
+          val messageLen = inputStream.readInt()
+          logWarning(s"parsing a message of ${messageLen} bytes")
+
+          val messageBytes = new Array[Byte](messageLen)
+          inputStream.read(messageBytes)
+          logWarning(s"read bytes = ${messageBytes.mkString("Array(", ", ", 
")")}")
+
+          val message = 
StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
+
+          logWarning(s"read message = $message")
+          handleRequest(message)
+          logWarning(s"flush output stream")
+
+          outputStream.flush()
+        }
+      } catch {
+        case _: EOFException =>
+          logWarning(s"No more data to read from the socket")
+          
statefulProcessorHandle.setHandleState(StatefulProcessorHandleState.CLOSED)
+          return
+        case e: Exception =>
+          logWarning(s"Error reading message: ${e.getMessage}")

Review Comment:
   [nit] This should be logged at error level. 



##########
sql/core/src/main/scala/org/apache/spark/sql/execution/python/TransformWithStateInPandasStateServer.scala:
##########
@@ -0,0 +1,287 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.sql.execution.python
+
+import java.io.{DataInputStream, DataOutputStream, EOFException}
+import java.nio.channels.Channels
+
+import scala.collection.mutable
+
+import com.google.protobuf.ByteString
+import jnr.unixsocket.UnixServerSocketChannel
+
+import org.apache.spark.internal.Logging
+import org.apache.spark.sql.{Encoder, Encoders, Row}
+import org.apache.spark.sql.execution.streaming.{ImplicitGroupingKeyTracker, 
StatefulProcessorHandleImpl, StatefulProcessorHandleState}
+import 
org.apache.spark.sql.execution.streaming.state.StateMessage.{HandleState, 
ImplicitGroupingKeyRequest, StatefulProcessorCall, StateRequest, StateResponse, 
StateVariableRequest, ValueStateCall}
+import org.apache.spark.sql.streaming.ValueState
+import org.apache.spark.sql.types.{BooleanType, DataType, DoubleType, 
FloatType, IntegerType, LongType, StructType}
+
+/**
+ * This class is used to handle the state requests from the Python side.
+ */
+class TransformWithStateInPandasStateServer(
+    private val serverChannel: UnixServerSocketChannel,
+    private val statefulProcessorHandle: StatefulProcessorHandleImpl,
+    private val groupingKeySchema: StructType)
+  extends Runnable
+  with Logging{
+
+  private var inputStream: DataInputStream = _
+  private var outputStream: DataOutputStream = _
+
+  private val valueStates = mutable.HashMap[String, ValueState[Any]]()
+
+  def run(): Unit = {
+    logWarning(s"Waiting for connection from Python worker")
+    val channel = serverChannel.accept()
+    logWarning(s"listening on channel - ${channel.getLocalAddress}")
+
+    inputStream = new DataInputStream(
+      Channels.newInputStream(channel))
+    outputStream = new DataOutputStream(
+      Channels.newOutputStream(channel)
+    )
+
+    while (channel.isConnected &&
+      statefulProcessorHandle.getHandleState != 
StatefulProcessorHandleState.CLOSED) {
+
+      try {
+        logWarning(s"reading the version")
+        val version = inputStream.readInt()
+
+        if (version != -1) {
+          logWarning(s"version = ${version}")
+          assert(version == 0)
+          val messageLen = inputStream.readInt()
+          logWarning(s"parsing a message of ${messageLen} bytes")
+
+          val messageBytes = new Array[Byte](messageLen)
+          inputStream.read(messageBytes)
+          logWarning(s"read bytes = ${messageBytes.mkString("Array(", ", ", 
")")}")
+
+          val message = 
StateRequest.parseFrom(ByteString.copyFrom(messageBytes))
+
+          logWarning(s"read message = $message")

Review Comment:
   We can create a sub-method for reading and parsing the Protobuf message. 



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org


Reply via email to