gaogaotiantian commented on code in PR #54296:
URL: https://github.com/apache/spark/pull/54296#discussion_r2830488892


##########
python/pyspark/worker.py:
##########
@@ -259,6 +260,37 @@ def check(result: Any) -> Iterator:
     return check
 
 
+def verify_scalar_result(result: Any, num_rows: int) -> Any:
+    """
+    Verify a scalar UDF result is array-like and has the expected number of rows.
+
+    Parameters
+    ----------
+    result : Any
+        The UDF result to verify.
+    num_rows : int
+        Expected number of rows (must match input batch size).
+    """
+    if not hasattr(result, "__len__"):

Review Comment:
   I understand this is how it was done before, but now that we are abstracting it into a more generic function (which will probably be reused elsewhere), this piece is inconsistent: the error message says we expect a `pyarrow.Array`, but we only check for `__len__`. That could confuse users.
   
   Also, having `__len__` is not equivalent to being an array-like object. It's possible to implement a data structure in native code that does not expose `__len__`, so `len(result)` could work even when `hasattr(result, "__len__")` is `False`.
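This is not the native-code case described above, but the same mismatch can be reproduced in pure Python, because `len()` looks up `__len__` on the type rather than going through normal instance attribute access. In the sketch below, `HiddenLen` and `InstanceLen` are hypothetical classes written only for illustration; they show the two checks disagreeing in both directions.

```python
class HiddenLen:
    """Supports len(), but hides __len__ from ordinary attribute lookup."""

    def __len__(self) -> int:
        return 3

    def __getattribute__(self, name: str):
        # hasattr() goes through __getattribute__; len() uses the type slot instead.
        if name == "__len__":
            raise AttributeError(name)
        return object.__getattribute__(self, name)


class InstanceLen:
    """Has a __len__ attribute on the instance, which len() ignores."""

    def __init__(self) -> None:
        self.__len__ = lambda: 3


results = {}
for obj in (HiddenLen(), InstanceLen()):
    try:
        length = len(obj)
    except TypeError:
        length = None  # len() does not work on this object
    results[type(obj).__name__] = (hasattr(obj, "__len__"), length)

print(results)
```

With these classes, `hasattr` reports `False` for an object that `len()` accepts, and `True` for one that `len()` rejects, so calling `len()` directly is the more reliable check.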
   
   Even if we do not want to fix the inconsistency between array-like and `pyarrow.Array`, we should do something like:
   
   ```python
   try:
       result_length = len(result)
   except TypeError:
       raise XXX
   ```
   
   We are using the length anyway.
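A minimal sketch of how `verify_scalar_result` could look with that suggestion applied. It asks `len()` directly and reuses the length for the row-count check. The built-in `TypeError`/`ValueError` and their messages are assumptions standing in for whatever PySpark error classes the real helper raises; they are not taken from the PR.

```python
from typing import Any


def verify_scalar_result(result: Any, num_rows: int) -> Any:
    """Check that a scalar UDF result supports len() and has num_rows rows.

    Sketch only: TypeError/ValueError stand in for the PySpark error
    classes the actual implementation would raise.
    """
    try:
        # Call len() directly instead of probing hasattr(result, "__len__").
        result_length = len(result)
    except TypeError as e:
        raise TypeError(
            f"Expected an array-like result, got {type(result).__name__}"
        ) from e
    if result_length != num_rows:
        raise ValueError(
            f"Result length {result_length} does not match the number of "
            f"input rows {num_rows}"
        )
    return result
```

Because the length is computed once inside the `try`, the same value serves both the array-like check and the row-count check.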



##########
python/pyspark/worker.py:
##########
@@ -1420,7 +1412,7 @@ def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
     if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
         return wrap_scalar_pandas_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
     elif eval_type == PythonEvalType.SQL_SCALAR_ARROW_UDF:
-        return wrap_scalar_arrow_udf(func, args_offsets, kwargs_offsets, return_type, runner_conf)
+        return func, args_offsets, kwargs_offsets, return_type

Review Comment:
   It's not ideal to return an arbitrary value here. I remember asking about `SQL_MAP_ARROW_ITER_UDF`, which returns just a `func`; now this branch returns a tuple. We should not return different data depending on the eval type. We can discard the unused data, but we should keep the return format consistent. Just imagine having to type-hint this function.
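One way to keep a single return format, sketched under assumptions: `UDFInfo` and `read_single_udf_sketch` are hypothetical names, and the eval-type strings stand in for the real `PythonEvalType` constants. Every branch returns the same `NamedTuple`, so the function has one honest return annotation and callers simply ignore the fields they do not need.

```python
from typing import Any, Callable, Dict, List, NamedTuple


class UDFInfo(NamedTuple):
    """Hypothetical uniform return type for read_single_udf."""

    func: Callable[..., Any]
    args_offsets: List[int]
    kwargs_offsets: Dict[str, int]
    return_type: Any


def read_single_udf_sketch(
    eval_type: str,
    func: Callable[..., Any],
    args_offsets: List[int],
    kwargs_offsets: Dict[str, int],
    return_type: Any,
) -> UDFInfo:
    if eval_type == "SCALAR_ARROW":
        # Hand back the unwrapped func; the caller wraps it later.
        return UDFInfo(func, args_offsets, kwargs_offsets, return_type)

    # Other eval types wrap func here (a trivial stand-in wrapper),
    # but still return the same shape.
    def wrapped(*args: Any) -> Any:
        return func(*args)

    return UDFInfo(wrapped, args_offsets, kwargs_offsets, return_type)
```

The annotation `-> UDFInfo` holds for every branch, which is exactly what an eval-type-dependent return value makes impossible.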



##########
python/pyspark/sql/conversion.py:
##########
@@ -107,6 +107,64 @@ def wrap_struct(batch: "pa.RecordBatch") -> "pa.RecordBatch":
             struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema)))
         return pa.RecordBatch.from_arrays([struct], ["_0"])
 
+    @classmethod
+    def enforce_schema(
+        cls,
+        batch: "pa.RecordBatch",
+        arrow_schema: "pa.Schema",
+        safecheck: bool = True,
+    ) -> "pa.RecordBatch":
+        """
+        Enforce target schema on a RecordBatch by reordering columns and coercing types.
+
+        Parameters
+        ----------
+        batch : pa.RecordBatch
+            Input RecordBatch to transform.
+        arrow_schema : pa.Schema
+            Target Arrow schema. Callers should pre-compute this once via
+            to_arrow_schema() to avoid repeated conversion.
+        safecheck : bool, default True
+            If True, use safe casting (fails on overflow/truncation).
+
+        Returns
+        -------
+        pa.RecordBatch
+            RecordBatch with columns reordered and types coerced to match target schema.
+        """
+        import pyarrow as pa
+
+        if batch.num_columns == 0 or len(arrow_schema) == 0:
+            return batch
+
+        # Fast path: schema already matches (ignoring metadata), no work needed
+        if batch.schema.equals(arrow_schema, check_metadata=False):
+            return batch
+
+        # Check if columns are in the same order (by name) as the target schema.
+        # If so, use index-based access (faster than name lookup).
+        batch_names = [batch.schema.field(i).name for i in range(batch.num_columns)]
+        target_names = [field.name for field in arrow_schema]
+        use_index = batch_names == target_names
+
+        coerced_arrays = []
+        for i, field in enumerate(arrow_schema):
+            arr = batch.column(i) if use_index else batch.column(field.name)
+            if arr.type != field.type:
+                try:
+                    arr = arr.cast(target_type=field.type, safe=safecheck)
+                except (pa.ArrowInvalid, pa.ArrowTypeError):
+                    raise PySparkRuntimeError(
+                        errorClass="RESULT_COLUMNS_MISMATCH_FOR_ARROW_UDTF",

Review Comment:
   Why `UDTF` here?


