HyukjinKwon commented on code in PR #50099:
URL: https://github.com/apache/spark/pull/50099#discussion_r2025899979


##########
python/pyspark/worker.py:
##########
@@ -1417,6 +1434,153 @@ def mapper(_, it):
 
         return mapper, None, ser, ser
 
+    elif eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and not legacy_pandas_conversion:
+
+        def wrap_arrow_udtf(f, return_type):
+            import pyarrow as pa
+
+            arrow_return_type = to_arrow_type(
+                return_type, prefers_large_types=use_large_var_types(runner_conf)
+            )
+            return_type_size = len(return_type)
+
+            def verify_result(result):
+                if not isinstance(result, pa.RecordBatch):
+                    raise PySparkTypeError(
+                        errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
+                        messageParameters={
+                            "return_type": type(result).__name__,
+                            "value": str(result),
+                            "func": f.__name__,
+                        },
+                    )
+
+                # Validate the output schema when the result dataframe has either output
+                # rows or columns. Note that we avoid using `df.empty` here because the
+                # result dataframe may contain an empty row. For example, when a UDTF is
+                # defined as follows: def eval(self): yield tuple().
+                if len(result) > 0 or len(result.columns) > 0:
+                    if len(result.columns) != return_type_size:
+                        raise PySparkRuntimeError(
+                            errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
+                            messageParameters={
+                                "expected": str(return_type_size),
+                                "actual": str(len(result.columns)),
+                                "func": f.__name__,
+                            },
+                        )
+
+                # Verify the type and the schema of the result.
+                verify_arrow_result(
+                    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
+                    assign_cols_by_name=False,
+                    expected_cols_and_types=[
+                        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
+                    ],
+                )
+                return result
+
+            # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
+            def func(*a: Any) -> Any:
+                try:
+                    return f(*a)
+                except SkipRestOfInputTableException:
+                    raise
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        errorClass="UDTF_EXEC_ERROR",
+                        messageParameters={"method_name": f.__name__, "error": str(e)},
+                    )
+
+            def check_return_value(res):
+                # Check whether the result of an arrow UDTF is iterable before
+                # using it to construct a pandas DataFrame.
+                if res is not None:
+                    if not isinstance(res, Iterable):
+                        raise PySparkRuntimeError(
+                            errorClass="UDTF_RETURN_NOT_ITERABLE",
+                            messageParameters={
+                                "type": type(res).__name__,
+                                "func": f.__name__,
+                            },
+                        )
+                    if check_output_row_against_schema is not None:
+                        for row in res:
+                            if row is not None:
+                                check_output_row_against_schema(row)
+                            yield row
+                    else:
+                        yield from res
+
+            def convert_to_arrow(data: Iterable):
+                data = list(check_return_value(data))
+                if len(data) == 0:
+                    return [
+                        pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+                    ]
+                try:
+                    ret = LocalDataToArrowConversion.convert(
+                        data, return_type, prefers_large_var_types
+                    ).to_batches()

Review Comment:
   Yeah, if it grows over the default size (https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.to_batches), it can be multiple batches. It should work, though - I wrote the code so that this works via `ArrowStreamUDFSerializer.dump_stream`.
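   For context, here is a minimal standalone pyarrow sketch (not from this PR; the column name and batch sizes are made up) of why one conversion result can come back from `to_batches()` as several `RecordBatch`es, and why writing them as separate messages in one Arrow IPC stream still reads back as a single logical table - which is the behaviour the comment above relies on for `ArrowStreamUDFSerializer.dump_stream`:

   ```python
   import io

   import pyarrow as pa

   # Build a table whose column data is backed by two chunks, roughly what a
   # large LocalDataToArrowConversion.convert() result could look like.
   table = pa.Table.from_batches(
       [
           pa.RecordBatch.from_pylist([{"x": i} for i in range(3)]),
           pa.RecordBatch.from_pylist([{"x": i} for i in range(3, 6)]),
       ]
   )

   # Without max_chunksize, to_batches() splits on the existing chunk layout,
   # so a single logical result can become more than one RecordBatch.
   batches = table.to_batches()
   assert len(batches) == 2

   # Writing the batches to an Arrow IPC stream keeps them as separate
   # messages under one schema; a stream reader still sees one logical table.
   sink = io.BytesIO()
   with pa.ipc.new_stream(sink, table.schema) as writer:
       for batch in batches:
           writer.write_batch(batch)

   sink.seek(0)
   assert pa.ipc.open_stream(sink).read_all().num_rows == 6
   ```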
########## python/pyspark/worker.py: ########## @@ -1417,6 +1434,153 @@ def mapper(_, it): return mapper, None, ser, ser + elif eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and not legacy_pandas_conversion: + + def wrap_arrow_udtf(f, return_type): + import pyarrow as pa + + arrow_return_type = to_arrow_type( + return_type, prefers_large_types=use_large_var_types(runner_conf) + ) + return_type_size = len(return_type) + + def verify_result(result): + if not isinstance(result, pa.RecordBatch): + raise PySparkTypeError( + errorClass="INVALID_ARROW_UDTF_RETURN_TYPE", + messageParameters={ + "return_type": type(result).__name__, + "value": str(result), + "func": f.__name__, + }, + ) + + # Validate the output schema when the result dataframe has either output + # rows or columns. Note that we avoid using `df.empty` here because the + # result dataframe may contain an empty row. For example, when a UDTF is + # defined as follows: def eval(self): yield tuple(). + if len(result) > 0 or len(result.columns) > 0: + if len(result.columns) != return_type_size: + raise PySparkRuntimeError( + errorClass="UDTF_RETURN_SCHEMA_MISMATCH", + messageParameters={ + "expected": str(return_type_size), + "actual": str(len(result.columns)), + "func": f.__name__, + }, + ) + + # Verify the type and the schema of the result. + verify_arrow_result( + pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))), + assign_cols_by_name=False, + expected_cols_and_types=[ + (col.name, to_arrow_type(col.dataType)) for col in return_type.fields + ], + ) + return result + + # Wrap the exception thrown from the UDTF in a PySparkRuntimeError. + def func(*a: Any) -> Any: + try: + return f(*a) + except SkipRestOfInputTableException: + raise + except Exception as e: + raise PySparkRuntimeError( + errorClass="UDTF_EXEC_ERROR", + messageParameters={"method_name": f.__name__, "error": str(e)}, + ) + + def check_return_value(res): + # Check whether the result of an arrow UDTF is iterable before + # using it to construct a pandas DataFrame. + if res is not None: + if not isinstance(res, Iterable): + raise PySparkRuntimeError( + errorClass="UDTF_RETURN_NOT_ITERABLE", + messageParameters={ + "type": type(res).__name__, + "func": f.__name__, + }, + ) + if check_output_row_against_schema is not None: + for row in res: + if row is not None: + check_output_row_against_schema(row) + yield row + else: + yield from res + + def convert_to_arrow(data: Iterable): + data = list(check_return_value(data)) + if len(data) == 0: + return [ + pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type))) + ] + try: + ret = LocalDataToArrowConversion.convert( + data, return_type, prefers_large_var_types + ).to_batches() Review Comment: Yeah, if it grows over the default size (https://arrow.apache.org/docs/python/generated/pyarrow.dataset.Dataset.html#pyarrow.dataset.Dataset.to_batches) it can be multiple batches. It should work though - I wrote the codes that it should work via `ArrowStreamUDFSerializer.dump_stream`. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For queries about this service, please contact Infrastructure at: us...@infra.apache.org --------------------------------------------------------------------- To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org For additional commands, e-mail: reviews-h...@spark.apache.org