bogao007 commented on code in PR #49560:
URL: https://github.com/apache/spark/pull/49560#discussion_r1938092668
##########
python/pyspark/sql/tests/pandas/test_pandas_transform_with_state.py:
##########
@@ -317,21 +313,16 @@ def check_results(batch_df, batch_id):
                     Row(id="0", countAsString="2"),
                     Row(id="1", countAsString="2"),
                 }
-            elif batch_id == 1:
+            else:
                 assert set(batch_df.sort("id").collect()) == {
                     Row(id="0", countAsString="3"),
                     Row(id="1", countAsString="2"),
                 }
-            else:
-                for q in self.spark.streams.active:
-                    q.stop()

         self._test_transform_with_state_in_pandas_basic(
             SimpleTTLStatefulProcessor(), check_results, False, "processingTime"
         )

-    # TODO SPARK-50908 holistic fix for TTL suite
-    @unittest.skip("test is flaky and it is only a timing issue, skipping until we can resolve")

Review Comment:
   Are we trying to fix the flaky TTL test here? If so, it may be better to do that in a separate PR and track it under the TODO Jira (SPARK-50908).



##########
dev/sparktestsupport/modules.py:
##########
@@ -1096,6 +1096,7 @@ def __hash__(self):
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_scalar",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_grouped_agg",
         "pyspark.sql.tests.connect.pandas.test_parity_pandas_udf_window",
+        "pyspark.sql.tests.connect.pandas.test_parity_pandas_transform_with_state",

Review Comment:
   Have we verified that this test actually runs in CI?
##########
sql/connect/server/src/main/scala/org/apache/spark/sql/connect/planner/SparkConnectPlanner.scala:
##########
@@ -1021,6 +1023,55 @@ class SparkConnectPlanner(
       .logicalPlan
   }

+  private def transformTransformWithStateInPandas(
+      rel: proto.TransformWithStateInPandas): LogicalPlan = {
+    val pythonUdf = transformPythonUDF(rel.getTransformWithStateUdf)
+    val cols =
+      rel.getGroupingExpressionsList.asScala.toSeq.map(expr => Column(transformExpression(expr)))
+
+    val outputSchema = parseSchema(rel.getOutputSchema)
+
+    if (rel.hasInitialInput) {
+      val initialGroupingCols = rel.getInitialGroupingExpressionsList.asScala.toSeq.map(expr =>
+        Column(transformExpression(expr)))
+
+      val input = Dataset
+        .ofRows(session, transformRelation(rel.getInput))
+        .groupBy(cols: _*)

Review Comment:
   Nit: can we share this part between both branches to avoid some duplicated code?



##########
python/pyspark/sql/streaming/stateful_processor_util.py:
##########
@@ -26,3 +42,185 @@ class TransformWithStateInPandasFuncMode(Enum):
     PROCESS_TIMER = 2
     COMPLETE = 3
     PRE_INIT = 4
+
+
+class TransformWithStateInPandasUdfUtils:

Review Comment:
   I see we are sharing this UDF definition util between the client and the server. Could that cause any compatibility issues, since the definitions are copied directly from the server? For example, `pyspark.sql.pandas._typing` is a server-side module, while on the client side we use `pyspark.sql.connect._typing`. cc @HyukjinKwon, who may know more about this.

--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

---------------------------------------------------------------------
For additional commands, e-mail: reviews-h...@spark.apache.org
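[Editor's note on the `_typing` compatibility question above: one common way to keep a shared helper importable on both the Spark Connect client and the server is to confine typing-only imports behind `typing.TYPE_CHECKING`, so neither side needs the other's `_typing` module at runtime. A minimal sketch — the `process_rows` helper is hypothetical, not code from this PR:]

```python
from typing import TYPE_CHECKING, Any

if TYPE_CHECKING:
    # Evaluated only by type checkers, never at runtime, so the connect
    # client can import this module without the server-side library.
    from pyspark.sql.pandas._typing import DataFrameLike

def process_rows(rows: "DataFrameLike") -> Any:
    """The annotation stays a string literal, so no runtime import of
    either pyspark.sql.pandas._typing or pyspark.sql.connect._typing
    is required."""
    return rows
```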