wengh commented on code in PR #50099:
URL: https://github.com/apache/spark/pull/50099#discussion_r1977992529
##########
python/pyspark/worker.py:
##########
@@ -1372,19 +1374,47 @@ def check_return_value(res):
         else:
             yield from res
 
-    def evaluate(*args: pd.Series):
+    def convert_to_arrow(data: Iterable):
+        data = list(check_return_value(data))
+        if len(data) == 0:
+            return [
+                pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+            ]
+        try:
+            ret = LocalDataToArrowConversion.convert(
+                data, return_type, prefers_large_var_types
+            ).to_batches()
+            if len(return_type.fields) == 0:
+                return [pa.RecordBatch.from_struct_array(pa.array([{}] * len(data)))]
+            return ret
+        except Exception as e:
+            raise PySparkRuntimeError(
+                errorClass="UDTF_ARROW_TYPE_CAST_ERROR",
+                messageParameters={
+                    "data": str(data),
+                    "schema": return_type.simpleString(),
+                    "arrow_schema": str(arrow_return_type),
+                },
+            ) from e
+
+    def evaluate(*args: pa.ChunkedArray):
         if len(args) == 0:
-            res = func()
-            yield verify_result(pd.DataFrame(check_return_value(res))), arrow_return_type
+            for batch in convert_to_arrow(func()):
+                yield verify_result(batch), arrow_return_type
+
         else:
-            # Create tuples from the input pandas Series, each tuple
-            # represents a row across all Series.
-            row_tuples = zip(*args)
-            for row in row_tuples:
-                res = func(*row)
-                yield verify_result(
-                    pd.DataFrame(check_return_value(res))
-                ), arrow_return_type
+            list_args = list(args)
+            names = [f"_{n}" for n in range(len(list_args))]
+            t = pa.Table.from_arrays(
+                [pa.StructArray.from_arrays(list_args, names=names)], names=["_0"]
+            )
+            schema = from_arrow_schema(t.schema, prefers_large_var_types)
+            rows = ArrowTableToRowsConversion.convert(t, schema=schema)
+            for row in rows:
+                row = row["_0"]
+                row = tuple(row)  # type: ignore[assignment]
+                for batch in convert_to_arrow(func(*row)):
+                    yield verify_result(batch), arrow_return_type

Review Comment:
   Can we directly make a `pa.Table` from the columns without wrapping in a `pa.StructArray`?
   ```suggestion
               names = [f"_{n}" for n in range(len(list_args))]
               t = pa.Table.from_arrays(list_args, names=names)
               schema = from_arrow_schema(t.schema, prefers_large_var_types)
               rows = ArrowTableToRowsConversion.convert(t, schema=schema)
               for row in rows:
                   row = tuple(row)  # type: ignore[assignment]
                   for batch in convert_to_arrow(func(*row)):
                       yield verify_result(batch), arrow_return_type
   ```
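
A minimal standalone sketch of the reviewer's point, under simplifying assumptions: plain `pyarrow` arrays stand in for the `pa.ChunkedArray` inputs, and `Table.to_pylist()` stands in for pyspark's `ArrowTableToRowsConversion.convert`. Both constructions recover the same row tuples; building the table directly from the columns just skips the `row["_0"]` unwrapping step.

```python
import pyarrow as pa

# Two input columns, as evaluate() would receive them (plain arrays
# here rather than pa.ChunkedArray, for brevity).
a = pa.array([1, 2, 3])
b = pa.array(["x", "y", "z"])
names = ["_0", "_1"]

# Current approach: wrap all columns in a single StructArray column "_0",
# then unwrap each row dict with row["_0"] before building the tuple.
wrapped = pa.Table.from_arrays(
    [pa.StructArray.from_arrays([a, b], names=names)], names=["_0"]
)
rows_wrapped = [tuple(r["_0"].values()) for r in wrapped.to_pylist()]

# Suggested approach: build the table directly from the columns; each
# row dict already maps column names to that row's values, so no
# unwrapping is needed.
direct = pa.Table.from_arrays([a, b], names=names)
rows_direct = [tuple(r.values()) for r in direct.to_pylist()]

assert rows_wrapped == rows_direct == [(1, "x"), (2, "y"), (3, "z")]
```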