HyukjinKwon commented on code in PR #50099:
URL: https://github.com/apache/spark/pull/50099#discussion_r2053763675
########## python/pyspark/worker.py: ##########
@@ -1417,6 +1434,153 @@ def mapper(_, it):
 
         return mapper, None, ser, ser
 
+    elif eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and not legacy_pandas_conversion:
+
+        def wrap_arrow_udtf(f, return_type):
+            import pyarrow as pa
+
+            arrow_return_type = to_arrow_type(
+                return_type, prefers_large_types=use_large_var_types(runner_conf)
+            )
+            return_type_size = len(return_type)
+
+            def verify_result(result):
+                if not isinstance(result, pa.RecordBatch):
+                    raise PySparkTypeError(
+                        errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
+                        messageParameters={
+                            "return_type": type(result).__name__,
+                            "value": str(result),
+                            "func": f.__name__,
+                        },
+                    )
+
+                # Validate the output schema when the result dataframe has either output
+                # rows or columns. Note that we avoid using `df.empty` here because the
+                # result dataframe may contain an empty row. For example, when a UDTF is
+                # defined as follows: def eval(self): yield tuple().
+                if len(result) > 0 or len(result.columns) > 0:
+                    if len(result.columns) != return_type_size:
+                        raise PySparkRuntimeError(
+                            errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
+                            messageParameters={
+                                "expected": str(return_type_size),
+                                "actual": str(len(result.columns)),
+                                "func": f.__name__,
+                            },
+                        )
+
+                # Verify the type and the schema of the result.
+                verify_arrow_result(
+                    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
+                    assign_cols_by_name=False,
+                    expected_cols_and_types=[
+                        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
+                    ],
+                )
+                return result
+
+            # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
+            def func(*a: Any) -> Any:
+                try:
+                    return f(*a)
+                except SkipRestOfInputTableException:
+                    raise
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        errorClass="UDTF_EXEC_ERROR",
+                        messageParameters={"method_name": f.__name__, "error": str(e)},
+                    )
+
+            def check_return_value(res):
+                # Check whether the result of an arrow UDTF is iterable before
+                # using it to construct a pandas DataFrame.
+                if res is not None:
+                    if not isinstance(res, Iterable):
+                        raise PySparkRuntimeError(
+                            errorClass="UDTF_RETURN_NOT_ITERABLE",
+                            messageParameters={
+                                "type": type(res).__name__,
+                                "func": f.__name__,
+                            },
+                        )
+                    if check_output_row_against_schema is not None:
+                        for row in res:
+                            if row is not None:
+                                check_output_row_against_schema(row)
+                            yield row
+                    else:
+                        yield from res
+
+            def convert_to_arrow(data: Iterable):
+                data = list(check_return_value(data))
+                if len(data) == 0:
+                    return [
+                        pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+                    ]
+                try:
+                    ret = LocalDataToArrowConversion.convert(
+                        data, return_type, prefers_large_var_types
+                    ).to_batches()
+                    if len(return_type.fields) == 0:
+                        return [pa.RecordBatch.from_struct_array(pa.array([{}] * len(data)))]
+                    return ret
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR",
+                        messageParameters={
+                            "data": str(data),
+                            "schema": return_type.simpleString(),
+                            "arrow_schema": str(arrow_return_type),
+                        },
+                    ) from e
+
+            def evaluate(*args: pa.ChunkedArray):
+                if len(args) == 0:
+                    for batch in convert_to_arrow(func()):
+                        yield verify_result(batch), arrow_return_type
+
+                else:
+                    list_args = list(args)
+                    names = [f"_{n}" for n in range(len(list_args))]
+                    t = pa.Table.from_arrays(list_args, names=names)
+                    schema = from_arrow_schema(t.schema, prefers_large_var_types)
+                    rows = ArrowTableToRowsConversion.convert(t, schema=schema)
+                    for row in rows:
+                        row = tuple(row)  # type: ignore[assignment]
+                        for batch in convert_to_arrow(func(*row)):
+                            yield verify_result(batch), arrow_return_type

Review Comment:
   Just checked. https://github.com/apache/arrow/blob/d2ddee62329eb711572b4d71d6380673d7f7edd1/cpp/src/arrow/table.cc#L612-L638

   The batch size defaults to the long max, which I believe is pretty safe: an Arrow batch cannot contain more rows than a long can represent in any case.
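   As a minimal standalone sketch (not part of the PR) of that default: with
   max_chunksize left as None, Table.to_batches() emits the table's existing
   chunks unsplit, since the int64-max cap in TableBatchReader never triggers.

       import pyarrow as pa

       schema = pa.schema([("x", pa.int64())])
       chunks = [
           pa.RecordBatch.from_pylist([{"x": 1}, {"x": 2}], schema=schema),
           pa.RecordBatch.from_pylist([{"x": 3}], schema=schema),
       ]
       table = pa.Table.from_batches(chunks, schema=schema)

       # Default max_chunksize=None: output batches mirror the table's
       # chunk layout, so nothing is split.
       assert [len(b) for b in table.to_batches()] == [2, 1]

       # An explicit max_chunksize splits chunks that exceed it.
       assert [len(b) for b in table.to_batches(max_chunksize=1)] == [1, 1, 1]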
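   Relatedly, a hedged sketch of the zero-column path in convert_to_arrow in
   the diff above: pa.array on a list of empty dicts infers an empty struct<>
   array, so from_struct_array yields a batch that carries only a row count,
   which is all a UDTF with an empty output schema needs to report.

       import pyarrow as pa

       num_rows = 3  # stand-in for len(data) in the PR code
       batch = pa.RecordBatch.from_struct_array(pa.array([{}] * num_rows))
       assert batch.num_rows == 3
       assert batch.num_columns == 0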