ueshin commented on code in PR #52140: URL: https://github.com/apache/spark/pull/52140#discussion_r2317581146
########## python/pyspark/sql/pandas/serializers.py: ########## @@ -227,6 +227,61 @@ def load_stream(self, stream): result_batches.append(batch.column(i)) yield result_batches + def _create_array(self, arr, arrow_type): + import pyarrow as pa + + assert isinstance(arr, pa.Array) + assert isinstance(arrow_type, pa.DataType) + if arr.type == arrow_type: + return arr + else: + try: + # when safe is True, the cast will fail if there's a overflow or other unsafe conversion + return arr.cast(target_type=arrow_type, safe=True) + except (pa.ArrowInvalid, pa.ArrowTypeError): + raise PySparkTypeError( + "Arrow UDTFs require the return type to match the expected Arrow type. " + f"Expected: {arrow_type}, but got: {arr.type}." + ) + + def dump_stream(self, iterator, stream): + """ + Override to handle type coercion for ArrowUDTF outputs. + ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) tuples. + """ + import pyarrow as pa + + def apply_type_coercion(): + for batch, arrow_return_type in iterator: + assert isinstance( + arrow_return_type, pa.StructType + ), f"Expected pa.StructType, got {type(arrow_return_type)}" + + # Handle empty struct case specially + if batch.num_columns == 0: + coerced_batch = batch # skip type coercion + else: + expected_field_names = [field.name for field in arrow_return_type] + actual_field_names = batch.schema.names + + if expected_field_names != actual_field_names: + raise PySparkTypeError( + "Target schema's field names are not matching the record batch's field names. " + f"Expected: {expected_field_names}, but got: {actual_field_names}." + ) + + coerced_arrays = [] + for i, field in enumerate(arrow_return_type): + original_array = batch.column(i) + coerced_array = self._create_array(original_array, field.type) + coerced_arrays.append(coerced_array) + coerced_batch = pa.RecordBatch.from_arrays( + coerced_arrays, names=arrow_return_type.names Review Comment: nit: `expected_field_names` or `actual_field_names`? ########## python/pyspark/sql/pandas/serializers.py: ########## @@ -227,6 +227,61 @@ def load_stream(self, stream): result_batches.append(batch.column(i)) yield result_batches + def _create_array(self, arr, arrow_type): + import pyarrow as pa + + assert isinstance(arr, pa.Array) + assert isinstance(arrow_type, pa.DataType) + if arr.type == arrow_type: + return arr + else: + try: + # when safe is True, the cast will fail if there's a overflow or other unsafe conversion + return arr.cast(target_type=arrow_type, safe=True) + except (pa.ArrowInvalid, pa.ArrowTypeError): + raise PySparkTypeError( + "Arrow UDTFs require the return type to match the expected Arrow type. " + f"Expected: {arrow_type}, but got: {arr.type}." + ) + + def dump_stream(self, iterator, stream): + """ + Override to handle type coercion for ArrowUDTF outputs. + ArrowUDTF returns iterator of (pa.RecordBatch, arrow_return_type) tuples. + """ + import pyarrow as pa + + def apply_type_coercion(): + for batch, arrow_return_type in iterator: + assert isinstance( + arrow_return_type, pa.StructType + ), f"Expected pa.StructType, got {type(arrow_return_type)}" + + # Handle empty struct case specially + if batch.num_columns == 0: + coerced_batch = batch # skip type coercion + else: + expected_field_names = [field.name for field in arrow_return_type] Review Comment: nit: we can use `arrow_return_type.names` instead? -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org