ueshin commented on code in PR #52140:
URL: https://github.com/apache/spark/pull/52140#discussion_r2317581146


##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -227,6 +227,61 @@ def load_stream(self, stream):
                     result_batches.append(batch.column(i))
             yield result_batches
 
+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+        if arr.type == arrow_type:
+            return arr
+        else:
+            try:
+                # when safe is True, the cast will fail if there's a overflow 
or other unsafe conversion
+                return arr.cast(target_type=arrow_type, safe=True)
+            except (pa.ArrowInvalid, pa.ArrowTypeError):
+                raise PySparkTypeError(
+                    "Arrow UDTFs require the return type to match the expected 
Arrow type. "
+                    f"Expected: {arrow_type}, but got: {arr.type}."
+                )
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override to handle type coercion for ArrowUDTF outputs.
+        ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) 
tuples.
+        """
+        import pyarrow as pa
+
+        def apply_type_coercion():
+            for batch, arrow_return_type in iterator:
+                assert isinstance(
+                    arrow_return_type, pa.StructType
+                ), f"Expected pa.StructType, got {type(arrow_return_type)}"
+
+                # Handle empty struct case specially
+                if batch.num_columns == 0:
+                    coerced_batch = batch  # skip type coercion
+                else:
+                    expected_field_names = [field.name for field in 
arrow_return_type]
+                    actual_field_names = batch.schema.names
+
+                    if expected_field_names != actual_field_names:
+                        raise PySparkTypeError(
+                            "Target schema's field names are not matching the 
record batch's field names. "
+                            f"Expected: {expected_field_names}, but got: 
{actual_field_names}."
+                        )
+
+                    coerced_arrays = []
+                    for i, field in enumerate(arrow_return_type):
+                        original_array = batch.column(i)
+                        coerced_array = self._create_array(original_array, 
field.type)
+                        coerced_arrays.append(coerced_array)
+                    coerced_batch = pa.RecordBatch.from_arrays(
+                        coerced_arrays, names=arrow_return_type.names

Review Comment:
   nit: should this pass `expected_field_names` or `actual_field_names` (both already computed above) rather than `arrow_return_type.names`?



##########
python/pyspark/sql/pandas/serializers.py:
##########
@@ -227,6 +227,61 @@ def load_stream(self, stream):
                     result_batches.append(batch.column(i))
             yield result_batches
 
+    def _create_array(self, arr, arrow_type):
+        import pyarrow as pa
+
+        assert isinstance(arr, pa.Array)
+        assert isinstance(arrow_type, pa.DataType)
+        if arr.type == arrow_type:
+            return arr
+        else:
+            try:
+                # when safe is True, the cast will fail if there's a overflow 
or other unsafe conversion
+                return arr.cast(target_type=arrow_type, safe=True)
+            except (pa.ArrowInvalid, pa.ArrowTypeError):
+                raise PySparkTypeError(
+                    "Arrow UDTFs require the return type to match the expected 
Arrow type. "
+                    f"Expected: {arrow_type}, but got: {arr.type}."
+                )
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override to handle type coercion for ArrowUDTF outputs.
+        ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) 
tuples.
+        """
+        import pyarrow as pa
+
+        def apply_type_coercion():
+            for batch, arrow_return_type in iterator:
+                assert isinstance(
+                    arrow_return_type, pa.StructType
+                ), f"Expected pa.StructType, got {type(arrow_return_type)}"
+
+                # Handle empty struct case specially
+                if batch.num_columns == 0:
+                    coerced_batch = batch  # skip type coercion
+                else:
+                    expected_field_names = [field.name for field in 
arrow_return_type]

Review Comment:
   nit: can we use `arrow_return_type.names` here instead of building the list with a comprehension?



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org


---------------------------------------------------------------------
To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org
For additional commands, e-mail: reviews-h...@spark.apache.org

Reply via email to