HyukjinKwon commented on code in PR #50099:
URL: https://github.com/apache/spark/pull/50099#discussion_r2053763675


##########
python/pyspark/worker.py:
##########
@@ -1417,6 +1434,153 @@ def mapper(_, it):
 
         return mapper, None, ser, ser
 
+    elif eval_type == PythonEvalType.SQL_ARROW_TABLE_UDF and not legacy_pandas_conversion:
+
+        def wrap_arrow_udtf(f, return_type):
+            import pyarrow as pa
+
+            arrow_return_type = to_arrow_type(
+                return_type, prefers_large_types=use_large_var_types(runner_conf)
+            )
+            return_type_size = len(return_type)
+
+            def verify_result(result):
+                if not isinstance(result, pa.RecordBatch):
+                    raise PySparkTypeError(
+                        errorClass="INVALID_ARROW_UDTF_RETURN_TYPE",
+                        messageParameters={
+                            "return_type": type(result).__name__,
+                            "value": str(result),
+                            "func": f.__name__,
+                        },
+                    )
+
+                # Validate the output schema when the result batch has either output
+                # rows or columns. Note that we avoid a plain emptiness check here
+                # because the result may contain an empty row. For example, when a
+                # UDTF is defined as follows: def eval(self): yield tuple().
+                if len(result) > 0 or len(result.columns) > 0:
+                    if len(result.columns) != return_type_size:
+                        raise PySparkRuntimeError(
+                            errorClass="UDTF_RETURN_SCHEMA_MISMATCH",
+                            messageParameters={
+                                "expected": str(return_type_size),
+                                "actual": str(len(result.columns)),
+                                "func": f.__name__,
+                            },
+                        )
+
+                # Verify the type and the schema of the result.
+                verify_arrow_result(
+                    pa.Table.from_batches([result], schema=pa.schema(list(arrow_return_type))),
+                    assign_cols_by_name=False,
+                    expected_cols_and_types=[
+                        (col.name, to_arrow_type(col.dataType)) for col in return_type.fields
+                    ],
+                )
+                return result
+
+            # Wrap the exception thrown from the UDTF in a PySparkRuntimeError.
+            def func(*a: Any) -> Any:
+                try:
+                    return f(*a)
+                except SkipRestOfInputTableException:
+                    raise
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        errorClass="UDTF_EXEC_ERROR",
+                        messageParameters={"method_name": f.__name__, "error": 
str(e)},
+                    )
+
+            def check_return_value(res):
+                # Check whether the result of an arrow UDTF is iterable before
+                # using it to construct Arrow record batches.
+                if res is not None:
+                    if not isinstance(res, Iterable):
+                        raise PySparkRuntimeError(
+                            errorClass="UDTF_RETURN_NOT_ITERABLE",
+                            messageParameters={
+                                "type": type(res).__name__,
+                                "func": f.__name__,
+                            },
+                        )
+                    if check_output_row_against_schema is not None:
+                        for row in res:
+                            if row is not None:
+                                check_output_row_against_schema(row)
+                            yield row
+                    else:
+                        yield from res
+
+            def convert_to_arrow(data: Iterable):
+                data = list(check_return_value(data))
+                if len(data) == 0:
+                    return [
+                        pa.RecordBatch.from_pylist(data, schema=pa.schema(list(arrow_return_type)))
+                    ]
+                try:
+                    ret = LocalDataToArrowConversion.convert(
+                        data, return_type, prefers_large_var_types
+                    ).to_batches()
+                    if len(return_type.fields) == 0:
+                        return [pa.RecordBatch.from_struct_array(pa.array([{}] * len(data)))]
+                    return ret
+                except Exception as e:
+                    raise PySparkRuntimeError(
+                        errorClass="UDTF_ARROW_TYPE_CONVERSION_ERROR",
+                        messageParameters={
+                            "data": str(data),
+                            "schema": return_type.simpleString(),
+                            "arrow_schema": str(arrow_return_type),
+                        },
+                    ) from e
+
+            def evaluate(*args: pa.ChunkedArray):
+                if len(args) == 0:
+                    for batch in convert_to_arrow(func()):
+                        yield verify_result(batch), arrow_return_type
+                else:
+                    list_args = list(args)
+                    names = [f"_{n}" for n in range(len(list_args))]
+                    t = pa.Table.from_arrays(list_args, names=names)
+                    schema = from_arrow_schema(t.schema, prefers_large_var_types)
+                    rows = ArrowTableToRowsConversion.convert(t, schema=schema)
+                    for row in rows:
+                        row = tuple(row)  # type: ignore[assignment]
+                        for batch in convert_to_arrow(func(*row)):
+                            yield verify_result(batch), arrow_return_type

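A minimal sketch of the zero-column special case handled in convert_to_arrow above (plain pyarrow, nothing Spark-specific; the rows list is a made-up stand-in for a UDTF with no output columns whose eval yields tuple() repeatedly): an Arrow record batch can carry rows without carrying any columns.

    import pyarrow as pa

    # Mirror the diff's zero-field path: an array of empty dicts infers
    # a struct type with no fields, and wrapping it in a record batch
    # yields num_rows == 3 with num_columns == 0.
    rows = [(), (), ()]
    batch = pa.RecordBatch.from_struct_array(pa.array([{}] * len(rows)))
    print(batch.num_rows, batch.num_columns)         # 3 0

    # This is why verify_result avoids a plain emptiness check: the
    # batch has no columns, yet it still carries three rows.
    print(len(batch) > 0 or len(batch.columns) > 0)  # True
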
Review Comment:
   Just checked. https://github.com/apache/arrow/blob/d2ddee62329eb711572b4d71d6380673d7f7edd1/cpp/src/arrow/table.cc#L612-L638
   
   The batch size defaults to the maximum value of a 64-bit long, which I believe is pretty safe. An Arrow batch cannot contain more rows than a long can represent anyway.
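   
   As a minimal sketch of what that default means for Table.to_batches() (plain pyarrow, independent of this PR; the chunk layout below is just a consequence of how the tables are built):
   
       import pyarrow as pa
   
       t = pa.table({"x": list(range(5))})
       # With no max_chunksize, the INT64_MAX default never forces a
       # split, so batches follow the table's existing chunk boundaries.
       print([len(b) for b in t.to_batches()])                  # [5]
   
       t2 = pa.concat_tables([t, t])                            # two chunks of 5 rows
       print([len(b) for b in t2.to_batches()])                 # [5, 5]
   
       # An explicit cap splits within each chunk, never across chunks.
       print([len(b) for b in t2.to_batches(max_chunksize=3)])  # [3, 2, 3, 2]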



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: reviews-unsubscr...@spark.apache.org

For queries about this service, please contact Infrastructure at:
us...@infra.apache.org

