Allison Wang created SPARK-53426: ------------------------------------ Summary: Support named argument for Arrow Python UDTFs Key: SPARK-53426 URL: https://issues.apache.org/jira/browse/SPARK-53426 Project: Spark Issue Type: Sub-task Components: PySpark Affects Versions: 4.1.0 Reporter: Allison Wang
Named arguments do not work for Python UDTF table arguments: {code:java} def test_arrow_udtf_with_named_arguments(self): @arrow_udtf(returnType="result_id bigint, multiplier_used int") class NamedArgsUDTF: def eval( self, table_data: "pa.RecordBatch", multiplier: "pa.Array" ) -> Iterator["pa.Table"]: assert isinstance( table_data, pa.RecordBatch ), f"Expected pa.RecordBatch for table_data, got {type(table_data)}" assert isinstance( multiplier, pa.Array ), f"Expected pa.Array for multiplier, got {type(multiplier)}" multiplier_val = multiplier[0].as_py() # Convert record batch to table table = pa.table(table_data) id_column = table.column("id") # Multiply each id by the multiplier multiplied_ids = pa.compute.multiply(id_column, pa.scalar(multiplier_val)) result_table = pa.table({ "result_id": multiplied_ids, "multiplier_used": pa.array([multiplier_val] * table.num_rows, type=pa.int32()) }) yield result_table # Test with DataFrame API using named arguments input_df = self.spark.range(3) # [0, 1, 2] > result_df = NamedArgsUDTF(table_data=input_df.asTable(), > multiplier=lit(5)) python/pyspark/sql/tests/arrow/test_arrow_udtf.py:812: _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ python/pyspark/sql/udtf.py:450: in __call__ j_named_arg = sc._jvm.PythonSQLUtils.namedArgumentExpression(key, j_arg) ../../.virtualenvs/spark/lib/python3.11/site-packages/py4j/java_gateway.py:1322: in __call__ return_value = get_return_value( python/pyspark/errors/exceptions/captured.py:288: in deco return f(*a, **kw) _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ answer = 'xspy4j.Py4JException: Method namedArgumentExpression([class java.lang.String, class org.apache.spark.sql.TableArg]) d....ClientServerConnection.run(ClientServerConnection.java:108)\\n\tat java.base/java.lang.Thread.run(Thread.java:840)\\n' gateway_client = <py4j.clientserver.JavaClient object at 0x109e20090> 
target_id = 'z:org.apache.spark.sql.api.python.PythonSQLUtils', name = 'namedArgumentExpression' def get_return_value(answer, gateway_client, target_id=None, name=None): """Converts an answer received from the Java gateway into a Python object. For example, string representation of integers are converted to Python integer, string representation of objects are converted to JavaObject instances, etc. :param answer: the string returned by the Java gateway :param gateway_client: the gateway client used to communicate with the Java Gateway. Only necessary if the answer is a reference (e.g., object, list, map) :param target_id: the name of the object from which the answer comes from (e.g., *object1* in `object1.hello()`). Optional. :param name: the name of the member from which the answer comes from (e.g., *hello* in `object1.hello()`). Optional. """ if is_error(answer)[0]: if len(answer) > 1: type = answer[1] value = OUTPUT_CONVERTER[type](answer[2:], gateway_client) if answer[1] == REFERENCE_TYPE: raise Py4JJavaError( "An error occurred while calling {0}{1}{2}.\n". format(target_id, ".", name), value) else: > raise Py4JError( "An error occurred while calling {0}{1}{2}. Trace:\n{3}\n". format(target_id, ".", name, value)) E py4j.protocol.Py4JError: An error occurred while calling z:org.apache.spark.sql.api.python.PythonSQLUtils.namedArgumentExpression. 
Trace: E py4j.Py4JException: Method namedArgumentExpression([class java.lang.String, class org.apache.spark.sql.TableArg]) does not exist E at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:321) E at py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:342) E at py4j.Gateway.invoke(Gateway.java:276) E at py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) E at py4j.commands.CallCommand.execute(CallCommand.java:79) E at py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:184) E at py4j.ClientServerConnection.run(ClientServerConnection.java:108) E at java.base/java.lang.Thread.run(Thread.java:840) {code} -- This message was sent by Atlassian Jira (v8.20.10#820010) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org For additional commands, e-mail: issues-h...@spark.apache.org