[ https://issues.apache.org/jira/browse/SPARK-53426?page=com.atlassian.jira.plugin.system.issuetabpanels:all-tabpanel ]
Allison Wang updated SPARK-53426: --------------------------------- Summary: Support named table argument with asTable() API (was: Support named argument for Arrow Python UDTFs) > Support named table argument with asTable() API > ----------------------------------------------- > > Key: SPARK-53426 > URL: https://issues.apache.org/jira/browse/SPARK-53426 > Project: Spark > Issue Type: Sub-task > Components: PySpark > Affects Versions: 4.1.0 > Reporter: Allison Wang > Priority: Major > > Named arguments do not work for Python UDTF table arguments: > > {code:java} > def test_arrow_udtf_with_named_arguments(self): > @arrow_udtf(returnType="result_id bigint, multiplier_used int") > class NamedArgsUDTF: > def eval( > self, > table_data: "pa.RecordBatch", > multiplier: "pa.Array" > ) -> Iterator["pa.Table"]: > assert isinstance( > table_data, pa.RecordBatch > ), f"Expected pa.RecordBatch for table_data, got > {type(table_data)}" > assert isinstance( > multiplier, pa.Array > ), f"Expected pa.Array for multiplier, got > {type(multiplier)}" multiplier_val = multiplier[0].as_py() > # Convert record batch to table > table = pa.table(table_data) > id_column = table.column("id") # Multiply each > id by the multiplier > multiplied_ids = pa.compute.multiply(id_column, > pa.scalar(multiplier_val)) result_table = pa.table({ > "result_id": multiplied_ids, > "multiplier_used": pa.array([multiplier_val] * > table.num_rows, type=pa.int32()) > }) > yield result_table # Test with DataFrame API using > named arguments > input_df = self.spark.range(3) # [0, 1, 2] > > result_df = NamedArgsUDTF(table_data=input_df.asTable(), > > multiplier=lit(5)) > > python/pyspark/sql/tests/arrow/test_arrow_udtf.py:812: > _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ > _ _ _ _ _ _ _ _ _ _ _ _ _ > python/pyspark/sql/udtf.py:450: in __call__ > j_named_arg = sc._jvm.PythonSQLUtils.namedArgumentExpression(key, j_arg) > 
../../.virtualenvs/spark/lib/python3.11/site-packages/py4j/java_gateway.py:1322: > in __call__ > return_value = get_return_value( > python/pyspark/errors/exceptions/captured.py:288: in deco > return f(*a, **kw) > _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ _ > _ _ _ _ _ _ _ _ _ _ _ _ _answer = 'xspy4j.Py4JException: Method > namedArgumentExpression([class java.lang.String, class > org.apache.spark.sql.TableArg]) > d....ClientServerConnection.run(ClientServerConnection.java:108)\\n\tat > java.base/java.lang.Thread.run(Thread.java:840)\\n' > gateway_client = <py4j.clientserver.JavaClient object at 0x109e20090> > target_id = 'z:org.apache.spark.sql.api.python.PythonSQLUtils', name = > 'namedArgumentExpression' def get_return_value(answer, gateway_client, > target_id=None, name=None): > """Converts an answer received from the Java gateway into a Python > object. For example, string representation of integers are converted > to Python > integer, string representation of objects are converted to JavaObject > instances, etc. :param answer: the string returned by the Java > gateway > :param gateway_client: the gateway client used to communicate with > the Java > Gateway. Only necessary if the answer is a reference (e.g., > object, > list, map) > :param target_id: the name of the object from which the answer comes > from > (e.g., *object1* in `object1.hello()`). Optional. > :param name: the name of the member from which the answer comes from > (e.g., *hello* in `object1.hello()`). Optional. > """ > if is_error(answer)[0]: > if len(answer) > 1: > type = answer[1] > value = OUTPUT_CONVERTER[type](answer[2:], gateway_client) > if answer[1] == REFERENCE_TYPE: > raise Py4JJavaError( > "An error occurred while calling {0}{1}{2}.\n". > format(target_id, ".", name), value) > else: > > raise Py4JError( > "An error occurred while calling {0}{1}{2}. > Trace:\n{3}\n". 
> format(target_id, ".", name, value)) > E py4j.protocol.Py4JError: An error occurred while calling > z:org.apache.spark.sql.api.python.PythonSQLUtils.namedArgumentExpression. > Trace: > E py4j.Py4JException: Method namedArgumentExpression([class > java.lang.String, class org.apache.spark.sql.TableArg]) does not exist > E at > py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:321) > E at > py4j.reflection.ReflectionEngine.getMethod(ReflectionEngine.java:342) > E at py4j.Gateway.invoke(Gateway.java:276) > E at > py4j.commands.AbstractCommand.invokeMethod(AbstractCommand.java:132) > E at > py4j.commands.CallCommand.execute(CallCommand.java:79) > E at > py4j.ClientServerConnection.waitForCommands(ClientServerConnection.java:184) > E at > py4j.ClientServerConnection.run(ClientServerConnection.java:108) > E at java.base/java.lang.Thread.run(Thread.java:840) > {code} > > -- This message was sent by Atlassian Jira (v8.20.10#820010) --------------------------------------------------------------------- To unsubscribe, e-mail: issues-unsubscr...@spark.apache.org For additional commands, e-mail: issues-h...@spark.apache.org