This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 467f627413d8599d5cfb16b29c5df3ae2261ef75
Author: pawelkocinski <[email protected]>
AuthorDate: Sat Aug 2 01:14:33 2025 +0200

    SEDONA-738 Fix unit tests.
---
 sedonaworker/serializer.py                         | 28 +++++--
 sedonaworker/worker.py                             | 94 +++-------------------
 .../org/apache/spark/sql/udf/StrategySuite.scala   | 15 ++--
 .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 91 ++++++++++++++++++---
 4 files changed, 115 insertions(+), 113 deletions(-)

diff --git a/sedonaworker/serializer.py b/sedonaworker/serializer.py
index 319de34780..0f85344f86 100644
--- a/sedonaworker/serializer.py
+++ b/sedonaworker/serializer.py
@@ -1,14 +1,9 @@
-import logging
-
-import pandas as pd
 from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer, ArrowStreamSerializer
 from pyspark.errors import PySparkTypeError, PySparkValueError
 import struct
 from pyspark.sql.pandas.types import (
     from_arrow_type,
-    to_arrow_type,
     _create_converter_from_pandas,
-    _create_converter_to_pandas,
 )
 
 def write_int(value, stream):
@@ -53,7 +48,18 @@ class SedonaArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
         batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
         import pyarrow as pa
         for batch in batches:
-            yield [ga.to_geopandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
+            table = pa.Table.from_batches([batch])
+            data = []
+
+            for c in table.itercolumns():
+                meta = table.schema.field(c._name).metadata
+                if meta and meta[b"ARROW:extension:name"] == b'geoarrow.wkb':
+                    data.append(ga.to_geopandas(c))
+                    continue
+
+                data.append(self.arrow_to_pandas(c))
+
+            yield data
 
     def _create_batch(self, series):
         """
@@ -72,7 +78,8 @@ class SedonaArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
             Arrow RecordBatch
         """
         import pyarrow as pa
-
+        import geopandas as gpd
+        from shapely.geometry.base import BaseGeometry
         # Make input conform to [(series1, type1), (series2, type2), ...]
         if not isinstance(series, (list, tuple)) or (
                 len(series) == 2 and isinstance(series[1], pa.DataType)
@@ -82,8 +89,13 @@ class SedonaArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
 
         arrs = []
         for s, t in series:
-            arrs.append(self._create_array(s, t, arrow_cast=self._arrow_cast))
+            # TODO here we should look into the return type
+            first_element = s.iloc[0]
+            if isinstance(first_element, BaseGeometry):
+                arrs.append(self._create_array(gpd.GeoSeries(s), t, arrow_cast=self._arrow_cast))
+                continue
 
+            arrs.append(self._create_array(s, t, arrow_cast=self._arrow_cast))
         return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
 
     def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
diff --git a/sedonaworker/worker.py b/sedonaworker/worker.py
index c08d47588a..2904fab764 100644
--- a/sedonaworker/worker.py
+++ b/sedonaworker/worker.py
@@ -142,15 +142,6 @@ def chain(f, g):
     """chain two functions together"""
     return lambda *a: g(f(*a))
 
-
-# def wrap_udf(f, return_type):
-#     if return_type.needConversion():
-#         toInternal = return_type.toInternal
-#         return lambda *a: toInternal(f(*a))
-#     else:
-#         return lambda *a: f(*a)
-
-
 def wrap_scalar_pandas_udf(f, return_type):
     arrow_return_type = to_arrow_type(return_type)
 
@@ -255,84 +246,9 @@ def read_udfs(pickleSer, infile, eval_type):
     ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
     # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
     arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
-    ser = SedonaArrowStreamPandasUDFSerializer(
-        timezone,
-        safecheck,
-        assign_cols_by_name(runner_conf),
-        df_for_struct,
-        struct_in_pandas,
-        ndarray_as_list,
-        arrow_cast,
-    )
 
     num_udfs = read_int(infile)
 
-    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-    is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
-    is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF
-
-    if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter:
-        if is_scalar_iter:
-            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
-        if is_map_pandas_iter:
-            assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
-        if is_map_arrow_iter:
-            assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
-
-        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
-
-        def func(_, iterator):
-            num_input_rows = 0
-
-            def map_batch(batch):
-                nonlocal num_input_rows
-
-                udf_args = [batch[offset] for offset in arg_offsets]
-                num_input_rows += len(udf_args[0])
-                if len(udf_args) == 1:
-                    return udf_args[0]
-                else:
-                    return tuple(udf_args)
-
-            iterator = map(map_batch, iterator)
-            result_iter = udf(iterator)
-
-            num_output_rows = 0
-            for result_batch, result_type in result_iter:
-                num_output_rows += len(result_batch)
-                # This assert is for Scalar Iterator UDF to fail fast.
-                # The length of the entire input can only be explicitly known
-                # by consuming the input iterator in user side. Therefore,
-                # it's very unlikely the output length is higher than
-                # input length.
-                assert (
-                        is_map_pandas_iter or is_map_arrow_iter or num_output_rows <= num_input_rows
-                ), "Pandas SCALAR_ITER UDF outputted more rows than input rows."
-                yield (result_batch, result_type)
-
-            if is_scalar_iter:
-                try:
-                    next(iterator)
-                except StopIteration:
-                    pass
-                else:
-                    raise PySparkRuntimeError(
-                        error_class="STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF",
-                        message_parameters={},
-                    )
-
-                if num_output_rows != num_input_rows:
-                    raise PySparkRuntimeError(
-                        error_class="RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF",
-                        message_parameters={
-                            "output_length": str(num_output_rows),
-                            "input_length": str(num_input_rows),
-                        },
-                    )
-
-        # profiling is not supported for UDF
-        return func, None, ser, ser
-
     udfs = []
     for i in range(num_udfs):
         udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))
@@ -349,6 +265,16 @@ def read_udfs(pickleSer, infile, eval_type):
     def func(_, it):
         return map(mapper, it)
 
+    ser = SedonaArrowStreamPandasUDFSerializer(
+        timezone,
+        safecheck,
+        assign_cols_by_name(runner_conf),
+        df_for_struct,
+        struct_in_pandas,
+        ndarray_as_list,
+        arrow_cast,
+    )
+
     # profiling is not supported for UDF
     return func, None, ser, ser
 
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
index 525dafaefb..350e4a515b 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.udf
 
 import org.apache.sedona.spark.SedonaContext
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.functions.col
-import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction
+import org.apache.spark.sql.functions.{col, expr}
+import org.apache.spark.sql.udf.ScalarUDF.{geometryToGeometryFunction, geometryToNonGeometryFunction, geopandasGeometryToGeometryFunction, nonGeometryToGeometryFunction}
 import org.locationtech.jts.io.WKTReader
 import org.scalatest.funsuite.AnyFunSuite
 import org.scalatest.matchers.should.Matchers
@@ -45,16 +45,15 @@ class StrategySuite extends AnyFunSuite with Matchers {
 
   test("sedona geospatial UDF") {
     spark.sql("select 1").show()
-    val currentTime = System.currentTimeMillis()
-
     val df = spark.read.format("parquet")
       .load("/Users/pawelkocinski/Desktop/projects/sedona-book/apache-sedona-book/book/chapter10/data/buildings/partitioned")
-      .select(geoPandasScalaFunction(col("geometry")).alias("area"))
-      .selectExpr("sum(area) as total_area")
+      .select(
+        geometryToNonGeometryFunction(col("geometry")),
+        geometryToGeometryFunction(col("geometry")),
+        nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+      )
 
     df.show()
-    val processingTime = System.currentTimeMillis() - currentTime
-    println(s"Processing time: $processingTime ms")
 
 //    val df = Seq(
 //      (1, "value", wktReader.read("POINT(21 52)")),
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
index fdc96ac024..3006a14e14 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
@@ -74,7 +74,7 @@ object ScalarUDF {
 
   val additionalModule = "spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf"
 
-  val pandasFunc: Array[Byte] = {
+  val geopandasGeometryToNonGeometry: Array[Byte] = {
     var binaryPandasFunc: Array[Byte] = null
     withTempPath { path =>
       Process(
@@ -82,18 +82,8 @@ object ScalarUDF {
           pythonExec,
           "-c",
           f"""
-            |from pyspark.sql.types import IntegerType
-            |from pyspark.sql.types import DoubleType
             |from pyspark.sql.types import FloatType
-            |from shapely.geometry import Point
-            |from sedona.sql.types import GeometryType
             |from pyspark.serializers import CloudPickleSerializer
-            |from shapely import box
-            |import logging
-            |logging.basicConfig(level=logging.INFO)
-            |logger = logging.getLogger(__name__)
-            |logger.info("Loading Sedona Python UDF")
-            |import os
             |f = open('$path', 'wb');
             |def apply_geopandas(x):
             |    return x.area
@@ -107,15 +97,62 @@ object ScalarUDF {
     binaryPandasFunc
   }
 
+  val geopandasGeometryToGeometryFunction: Array[Byte] = {
+    var binaryPandasFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          pythonExec,
+          "-c",
+          f"""
+             |from sedona.sql.types import GeometryType
+             |from pyspark.serializers import CloudPickleSerializer
+             |f = open('$path', 'wb');
+             |def apply_geopandas(x):
+             |    return x.buffer(1)
+             |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType())))
+             |""".stripMargin),
+        None,
+        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+      binaryPandasFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryPandasFunc != null)
+    binaryPandasFunc
+  }
+
+  val geopandasNonGeometryToGeometryFunction: Array[Byte] = {
+    var binaryPandasFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          pythonExec,
+          "-c",
+          f"""
+             |from sedona.sql.types import GeometryType
+             |from shapely.wkt import loads
+             |from pyspark.serializers import CloudPickleSerializer
+             |f = open('$path', 'wb');
+             |def apply_geopandas(x):
+             |    return x.apply(lambda wkt: loads(wkt).buffer(1))
+             |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType())))
+             |""".stripMargin),
+        None,
+        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+      binaryPandasFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryPandasFunc != null)
+    binaryPandasFunc
+  }
+
   private val workerEnv = new java.util.HashMap[String, String]()
     workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
     SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.worker")
     SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false)
 
-  val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+  val geometryToNonGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
     name = "geospatial_udf",
     func = SimplePythonFunction(
-      command = pandasFunc,
+      command = geopandasGeometryToNonGeometry,
       envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
       pythonIncludes = List.empty[String].asJava,
       pythonExec = pythonExec,
@@ -125,4 +162,32 @@ object ScalarUDF {
     dataType = FloatType,
     pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
     udfDeterministic = true)
+
+  val geometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+    name = "geospatial_udf",
+    func = SimplePythonFunction(
+      command = geopandasGeometryToGeometryFunction,
+      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+      pythonIncludes = List.empty[String].asJava,
+      pythonExec = pythonExec,
+      pythonVer = pythonVer,
+      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+      accumulator = null),
+    dataType = GeometryUDT,
+    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
+    udfDeterministic = true)
+
+  val nonGeometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+    name = "geospatial_udf",
+    func = SimplePythonFunction(
+      command = geopandasNonGeometryToGeometryFunction,
+      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+      pythonIncludes = List.empty[String].asJava,
+      pythonExec = pythonExec,
+      pythonVer = pythonVer,
+      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+      accumulator = null),
+    dataType = GeometryUDT,
+    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
+    udfDeterministic = true)
 }
