This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 798cd3ca8d00fdc32f0091703c9edf1b543327a1
Author: pawelkocinski <[email protected]>
AuthorDate: Wed Jul 30 00:07:46 2025 +0200

    SEDONA-738 Fix unit tests.
---
 .../common/geometrySerde/GeometrySerializer.java   |  53 ++---
 sedonaworker/serializer.py                         | 187 ++++++++++++++++++
 sedonaworker/worker.py                             | 154 +++++++++------
 .../spark/api/python/SedonaPythonRunner.scala      |  32 ++-
 .../sql/execution/python/SedonaArrowUtils.scala    | 216 +++++++++++++++++++++
 .../execution/python/SedonaPythonArrowInput.scala  |   4 +-
 .../org/apache/spark/sql/udf/StrategySuite.scala   |  68 ++++---
 .../apache/spark/sql/udf/TestScalarPandasUDF.scala |  23 +--
 8 files changed, 613 insertions(+), 124 deletions(-)

diff --git a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
index 508a62901d..325098c6ac 100644
--- a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
+++ b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
@@ -32,37 +32,46 @@ import org.locationtech.jts.geom.Point;
 import org.locationtech.jts.geom.Polygon;
 import org.locationtech.jts.geom.PrecisionModel;
 import org.locationtech.jts.io.WKBConstants;
+import org.locationtech.jts.io.WKBReader;
+import org.locationtech.jts.io.WKBWriter;
 
 public class GeometrySerializer {
   private static final Coordinate NULL_COORDINATE = new Coordinate(Double.NaN, Double.NaN);
   private static final PrecisionModel PRECISION_MODEL = new PrecisionModel();
 
   public static byte[] serialize(Geometry geometry) {
-    GeometryBuffer buffer;
-    if (geometry instanceof Point) {
-      buffer = serializePoint((Point) geometry);
-    } else if (geometry instanceof MultiPoint) {
-      buffer = serializeMultiPoint((MultiPoint) geometry);
-    } else if (geometry instanceof LineString) {
-      buffer = serializeLineString((LineString) geometry);
-    } else if (geometry instanceof MultiLineString) {
-      buffer = serializeMultiLineString((MultiLineString) geometry);
-    } else if (geometry instanceof Polygon) {
-      buffer = serializePolygon((Polygon) geometry);
-    } else if (geometry instanceof MultiPolygon) {
-      buffer = serializeMultiPolygon((MultiPolygon) geometry);
-    } else if (geometry instanceof GeometryCollection) {
-      buffer = serializeGeometryCollection((GeometryCollection) geometry);
-    } else {
-      throw new UnsupportedOperationException(
-          "Geometry type is not supported: " + 
geometry.getClass().getSimpleName());
-    }
-    return buffer.toByteArray();
+    return new WKBWriter().write(geometry);
+//    GeometryBuffer buffer;
+//    if (geometry instanceof Point) {
+//      buffer = serializePoint((Point) geometry);
+//    } else if (geometry instanceof MultiPoint) {
+//      buffer = serializeMultiPoint((MultiPoint) geometry);
+//    } else if (geometry instanceof LineString) {
+//      buffer = serializeLineString((LineString) geometry);
+//    } else if (geometry instanceof MultiLineString) {
+//      buffer = serializeMultiLineString((MultiLineString) geometry);
+//    } else if (geometry instanceof Polygon) {
+//      buffer = serializePolygon((Polygon) geometry);
+//    } else if (geometry instanceof MultiPolygon) {
+//      buffer = serializeMultiPolygon((MultiPolygon) geometry);
+//    } else if (geometry instanceof GeometryCollection) {
+//      buffer = serializeGeometryCollection((GeometryCollection) geometry);
+//    } else {
+//      throw new UnsupportedOperationException(
+//          "Geometry type is not supported: " + 
geometry.getClass().getSimpleName());
+//    }
+//    return buffer.toByteArray();
   }
 
   public static Geometry deserialize(byte[] bytes) {
-    GeometryBuffer buffer = GeometryBufferFactory.wrap(bytes);
-    return deserialize(buffer);
+    WKBReader reader = new WKBReader();
+    try {
+      return reader.read(bytes);
+    } catch (Exception e) {
+      throw new IllegalArgumentException("Failed to deserialize geometry from bytes", e);
+    }
+//    GeometryBuffer buffer = GeometryBufferFactory.wrap(bytes);
+//    return deserialize(buffer);
   }
 
   public static Geometry deserialize(GeometryBuffer buffer) {
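
With the JVM side now delegating to JTS WKBWriter/WKBReader, geometry columns cross the JVM/Python boundary as standard WKB. A minimal Python sketch of the matching round-trip, assuming shapely is installed (the sample point is illustrative, not part of the patch):

    from shapely import wkb
    from shapely.geometry import Point

    geom = Point(21.0, 52.0)
    payload = wkb.dumps(geom)      # bytes the JTS WKBReader above can parse
    restored = wkb.loads(payload)  # bytes from the JTS WKBWriter parse the same way
    assert restored.equals(geom)
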
diff --git a/sedonaworker/serializer.py b/sedonaworker/serializer.py
new file mode 100644
index 0000000000..319de34780
--- /dev/null
+++ b/sedonaworker/serializer.py
@@ -0,0 +1,187 @@
+import logging
+
+import pandas as pd
+from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer, ArrowStreamSerializer
+from pyspark.errors import PySparkTypeError, PySparkValueError
+import struct
+from pyspark.sql.pandas.types import (
+    from_arrow_type,
+    to_arrow_type,
+    _create_converter_from_pandas,
+    _create_converter_to_pandas,
+)
+
+def write_int(value, stream):
+    stream.write(struct.pack("!i", value))
+
+class SpecialLengths:
+    END_OF_DATA_SECTION = -1
+    PYTHON_EXCEPTION_THROWN = -2
+    TIMING_DATA = -3
+    END_OF_STREAM = -4
+    NULL = -5
+    START_ARROW_STREAM = -6
+
+
+class SedonaArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
+    """
+    Serializer used by Python worker to evaluate Pandas UDFs
+    """
+
+    def __init__(
+            self,
+            timezone,
+            safecheck,
+            assign_cols_by_name,
+            df_for_struct=False,
+            struct_in_pandas="dict",
+            ndarray_as_list=False,
+            arrow_cast=False,
+    ):
+        super(SedonaArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
+        self._assign_cols_by_name = assign_cols_by_name
+        self._df_for_struct = df_for_struct
+        self._struct_in_pandas = struct_in_pandas
+        self._ndarray_as_list = ndarray_as_list
+        self._arrow_cast = arrow_cast
+
+    def load_stream(self, stream):
+        """
+        Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series.
+        """
+        import geoarrow.pyarrow as ga
+        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
+        import pyarrow as pa
+        for batch in batches:
+            yield [ga.to_geopandas(c) for c in pa.Table.from_batches([batch]).itercolumns()]
+
+    def _create_batch(self, series):
+        """
+        Create an Arrow record batch from the given pandas.Series pandas.DataFrame
+        or list of Series or DataFrame, with optional type.
+
+        Parameters
+        ----------
+        series : pandas.Series or pandas.DataFrame or list
+            A single series or dataframe, list of series or dataframe,
+            or list of (series or dataframe, arrow_type)
+
+        Returns
+        -------
+        pyarrow.RecordBatch
+            Arrow RecordBatch
+        """
+        import pyarrow as pa
+
+        # Make input conform to [(series1, type1), (series2, type2), ...]
+        if not isinstance(series, (list, tuple)) or (
+                len(series) == 2 and isinstance(series[1], pa.DataType)
+        ):
+            series = [series]
+        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
+
+        arrs = []
+        for s, t in series:
+            arrs.append(self._create_array(s, t, arrow_cast=self._arrow_cast))
+
+        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
+
+    def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False):
+        """
+        Create an Arrow Array from the given pandas.Series and optional type.
+
+        Parameters
+        ----------
+        series : pandas.Series
+            A single series
+        arrow_type : pyarrow.DataType, optional
+            If None, pyarrow's inferred type will be used
+        spark_type : DataType, optional
+            If None, spark type converted from arrow_type will be used
+        arrow_cast: bool, optional
+            Whether to apply Arrow casting when the user-specified return type mismatches the
+            actual return values.
+
+        Returns
+        -------
+        pyarrow.Array
+        """
+        import pyarrow as pa
+        from pandas.api.types import is_categorical_dtype
+        if is_categorical_dtype(series.dtype):
+            series = series.astype(series.dtypes.categories.dtype)
+        if arrow_type is not None:
+            dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True)
+            # TODO(SPARK-43579): cache the converter for reuse
+            conv = _create_converter_from_pandas(
+                dt, timezone=self._timezone, error_on_duplicated_field_names=False
+            )
+            series = conv(series)
+
+        if hasattr(series.array, "__arrow_array__"):
+            mask = None
+        else:
+            mask = series.isnull()
+        try:
+            try:
+                import geopandas as gpd
+                if isinstance(series, gpd.GeoSeries):
+                    import geoarrow.pyarrow as ga
+                    # If the series is a GeoSeries, convert it to an Arrow array using geoarrow
+                    return ga.array(series)
+
+                array = pa.Array.from_pandas(
+                    series, mask=mask, type=arrow_type, safe=self._safecheck
+                )
+
+                return array
+            except pa.lib.ArrowInvalid:
+                if arrow_cast:
+                    return pa.Array.from_pandas(series, mask=mask).cast(
+                        target_type=arrow_type, safe=self._safecheck
+                    )
+                else:
+                    raise
+        except TypeError as e:
+            error_msg = (
+                "Exception thrown when converting pandas.Series (%s) "
+                "with name '%s' to Arrow Array (%s)."
+            )
+            raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e
+        except ValueError as e:
+            error_msg = (
+                "Exception thrown when converting pandas.Series (%s) "
+                "with name '%s' to Arrow Array (%s)."
+            )
+            if self._safecheck:
+                error_msg = error_msg + (
+                    " It can be caused by overflows or other "
+                    "unsafe conversions warned by Arrow. Arrow safe type check 
"
+                    "can be disabled by using SQL config "
+                    "`spark.sql.execution.pandas.convertToArrowArraySafely`."
+                )
+            raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
+        This should be sent after creating the first record batch so in case of an error, it can
+        be sent back to the JVM before the Arrow stream starts.
+        """
+
+        def init_stream_yield_batches():
+            should_write_start_length = True
+            for series in iterator:
+                batch = self._create_batch(series)
+                if should_write_start_length:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    should_write_start_length = False
+                yield batch
+
+        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
+
+    def __repr__(self):
+        return "ArrowStreamPandasUDFSerializer"
+
+from pyspark.serializers import Serializer, write_int
+
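
The load and dump paths above reduce to two geoarrow-pyarrow conversions. A small standalone sketch, assuming geopandas and geoarrow-pyarrow are installed (the sample geometries are illustrative):

    import geopandas as gpd
    import geoarrow.pyarrow as ga
    from shapely.geometry import Point

    series = gpd.GeoSeries([Point(21, 52), Point(20, 50)])
    arr = ga.array(series)       # dump path: GeoSeries -> geoarrow-encoded Arrow array
    back = ga.to_geopandas(arr)  # load path: Arrow column -> GeoSeries
    assert back.iloc[0].equals(series.iloc[0])
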
diff --git a/sedonaworker/worker.py b/sedonaworker/worker.py
index 365561f0a6..c08d47588a 100644
--- a/sedonaworker/worker.py
+++ b/sedonaworker/worker.py
@@ -18,14 +18,18 @@
 """
 Worker that receives input from Piped RDD.
 """
+import logging
 import os
 import sys
 import time
 from inspect import currentframe, getframeinfo, getfullargspec
 import importlib
 import json
+from io import BufferedRWPair
 from typing import Any, Iterable, Iterator
 
+from sedonaworker.serializer import SedonaArrowStreamPandasUDFSerializer
+
 # 'resource' is a Unix specific module.
 has_resource_module = True
 try:
@@ -56,11 +60,7 @@ from pyspark.serializers import (
     BatchedSerializer,
 )
 from pyspark.sql.pandas.serializers import (
-    ArrowStreamPandasUDFSerializer,
-    ArrowStreamPandasUDTFSerializer,
-    CogroupUDFSerializer,
-    ArrowStreamUDFSerializer,
-    ApplyInPandasWithStateSerializer,
+    ArrowStreamPandasUDFSerializer, ArrowStreamSerializer
 )
 from pyspark.sql.pandas.types import to_arrow_type
 from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string
@@ -72,6 +72,51 @@ pickleSer = CPickleSerializer()
 utf8_deserializer = UTF8Deserializer()
 
 
+class SedonaArrowStreamUDFSerializer(ArrowStreamSerializer):
+    """
+    Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch
+    for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`.
+    """
+
+    def load_stream(self, stream):
+        """
+        Flatten the struct into Arrow's record batches.
+        """
+        import pyarrow as pa
+
+        batches = super(SedonaArrowStreamUDFSerializer, self).load_stream(stream)
+        for batch in batches:
+            struct = batch.column(0)
+            yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))]
+
+    def dump_stream(self, iterator, stream):
+        """
+        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
+        This should be sent after creating the first record batch so in case of an error, it can
+        be sent back to the JVM before the Arrow stream starts.
+        """
+        import pyarrow as pa
+
+        def wrap_and_init_stream():
+            should_write_start_length = True
+            for batch, _ in iterator:
+                assert isinstance(batch, pa.RecordBatch)
+
+                # Wrap the root struct
+                struct = pa.StructArray.from_arrays(
+                    batch.columns, fields=pa.struct(list(batch.schema))
+                )
+                batch = pa.RecordBatch.from_arrays([struct], ["_0"])
+
+                # Write the first record batch with initialization.
+                if should_write_start_length:
+                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
+                    should_write_start_length = False
+                yield batch
+
+        return super(SedonaArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream)
+
+
 def report_times(outfile, boot, init, finish):
     write_int(SpecialLengths.TIMING_DATA, outfile)
     write_long(int(1000 * boot), outfile)
@@ -180,56 +225,45 @@ def assign_cols_by_name(runner_conf):
 def read_udfs(pickleSer, infile, eval_type):
     runner_conf = {}
 
-    if eval_type in (
-            PythonEvalType.SQL_SCALAR_PANDAS_UDF,
-    ):
-
-        # Load conf used for pandas_udf evaluation
-        num_conf = read_int(infile)
-        for i in range(num_conf):
-            k = utf8_deserializer.loads(infile)
-            v = utf8_deserializer.loads(infile)
-            runner_conf[k] = v
-
-        state_object_schema = None
-        if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
-            state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
-
-        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
-        timezone = runner_conf.get("spark.sql.session.timeZone", None)
-        safecheck = (
-                runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
-                == "true"
-        )
+    # Load conf used for pandas_udf evaluation
+    num_conf = read_int(infile)
+    for i in range(num_conf):
+        k = utf8_deserializer.loads(infile)
+        v = utf8_deserializer.loads(infile)
+        runner_conf[k] = v
+
+    # state_object_schema = None
+    # if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+    #     state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
+
+    # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
+    timezone = runner_conf.get("spark.sql.session.timeZone", None)
+    safecheck = (
+            runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
+            == "true"
+    )
 
-        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
-            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf))
-        else:
-            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
-            # pandas Series. See SPARK-27240.
-            df_for_struct = (
-                    eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
-                    or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
-                    or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
-            )
-            # Arrow-optimized Python UDF takes a struct type argument as a Row
-            struct_in_pandas = (
-                "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF 
else "dict"
-            )
-            ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
-            # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
-            arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
-            ser = ArrowStreamPandasUDFSerializer(
-                timezone,
-                safecheck,
-                assign_cols_by_name(runner_conf),
-                df_for_struct,
-                struct_in_pandas,
-                ndarray_as_list,
-                arrow_cast,
-            )
-    else:
-        ser = BatchedSerializer(CPickleSerializer(), 100)
+    df_for_struct = (
+            eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+            or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+            or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+    )
+    # Arrow-optimized Python UDF takes a struct type argument as a Row
+    struct_in_pandas = (
+        "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
+    )
+    ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+    # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
+    arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+    ser = SedonaArrowStreamPandasUDFSerializer(
+        timezone,
+        safecheck,
+        assign_cols_by_name(runner_conf),
+        df_for_struct,
+        struct_in_pandas,
+        ndarray_as_list,
+        arrow_cast,
+    )
 
     num_udfs = read_int(infile)
 
@@ -298,6 +332,7 @@ def read_udfs(pickleSer, infile, eval_type):
 
         # profiling is not supported for UDF
         return func, None, ser, ser
+
     udfs = []
     for i in range(num_udfs):
         udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))
@@ -332,6 +367,7 @@ def main(infile, outfile):
             sys.exit(-1)
 
         version = utf8_deserializer.loads(infile)
+
         if version != "%d.%d" % sys.version_info[:2]:
             raise PySparkRuntimeError(
                 error_class="PYTHON_VERSION_MISMATCH",
@@ -459,11 +495,7 @@ def main(infile, outfile):
 
         _accumulatorRegistry.clear()
         eval_type = read_int(infile)
-        if eval_type == PythonEvalType.NON_UDF:
-            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
-        else:
-            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
-
+        func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
         init_time = time.time()
 
         def process():
@@ -539,3 +571,7 @@ if __name__ == "__main__":
     write_int(os.getpid(), sock_file)
     sock_file.flush()
     main(sock_file, sock_file)
+
+
+class GeoArrowLoader:
+    pass
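
The struct wrap and flatten that SedonaArrowStreamUDFSerializer performs can be exercised in isolation. A rough sketch using only pyarrow (the sample batch is made up):

    import pyarrow as pa

    batch = pa.RecordBatch.from_arrays([pa.array([1, 2]), pa.array(["a", "b"])], ["id", "val"])

    # dump_stream: wrap every column into one struct column named "_0"
    struct = pa.StructArray.from_arrays(batch.columns, fields=pa.struct(list(batch.schema)))
    wrapped = pa.RecordBatch.from_arrays([struct], ["_0"])

    # load_stream: flatten the struct column back into a plain record batch
    col = wrapped.column(0)
    flattened = pa.RecordBatch.from_arrays(col.flatten(), schema=pa.schema(col.type))
    assert flattened.equals(batch)
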
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
index 026518272c..fb01e62b5e 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
@@ -19,6 +19,8 @@ package org.apache.spark.api.python
 
 import org.apache.spark._
 import org.apache.spark.SedonaSparkEnv
+import org.apache.spark.api.python.PythonRDD.writeUTF
+import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
@@ -749,7 +751,7 @@ private[spark] class PythonRunner(
       }
 
       protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
-        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+        GeoArrowWriter.writeIteratorToStream(inputIterator, dataOut)
         dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
       }
     }
@@ -809,3 +811,31 @@ private[spark] object BarrierTaskContextMessageProtocol {
   val BARRIER_RESULT_SUCCESS = "success"
   val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side."
 }
+
+object GeoArrowWriter extends Logging {
+  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = {
+
+    def write(obj: Any): Unit = obj match {
+      case null =>
+        dataOut.writeInt(SpecialLengths.NULL)
+      case arr: Array[Byte] =>
+        logError("some random array")
+        dataOut.writeInt(arr.length)
+        dataOut.write(arr)
+      case str: String =>
+        logError("some random string")
+        writeUTF(str, dataOut)
+      case stream: PortableDataStream =>
+        logError("some random stream")
+        write(stream.toArray())
+      case (key, value) =>
+        logError("some random key value")
+        write(key)
+        write(value)
+      case other =>
+        throw new SparkException("Unexpected element type " + other.getClass)
+    }
+
+    iter.foreach(write)
+  }
+}
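
GeoArrowWriter frames each element as a signed big-endian length followed by the raw bytes, reusing the negative SpecialLengths markers for null and end-of-data. A hedged sketch of a matching Python-side reader; the function name and constants mirror SpecialLengths but are illustrative only:

    import struct

    END_OF_DATA_SECTION = -1  # SpecialLengths.END_OF_DATA_SECTION
    NULL = -5                 # SpecialLengths.NULL

    def read_framed(stream):
        """Yield byte payloads written by GeoArrowWriter until END_OF_DATA_SECTION."""
        while True:
            (length,) = struct.unpack("!i", stream.read(4))
            if length == END_OF_DATA_SECTION:
                return
            if length == NULL:
                yield None
                continue
            yield stream.read(length)
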
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
new file mode 100644
index 0000000000..58166d173d
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
@@ -0,0 +1,216 @@
+package org.apache.spark.sql.execution.python
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.util.concurrent.atomic.AtomicInteger
+import scala.collection.JavaConverters._
+import org.apache.arrow.memory.RootAllocator
+import org.apache.arrow.vector.complex.MapVector
+import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
+import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
+import org.apache.spark.sql.errors.ExecutionErrors
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types._
+
+private[sql] object SedonaArrowUtils {
+
+  val rootAllocator = new RootAllocator(Long.MaxValue)
+
+  // todo: support more types.
+
+  /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
+  def toArrowType(
+                   dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match {
+    case BooleanType => ArrowType.Bool.INSTANCE
+    case ByteType => new ArrowType.Int(8, true)
+    case ShortType => new ArrowType.Int(8 * 2, true)
+    case IntegerType => new ArrowType.Int(8 * 4, true)
+    case LongType => new ArrowType.Int(8 * 8, true)
+    case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
+    case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
+    case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE
+    case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE
+    case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE
+    case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE
+    case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
+    case DateType => new ArrowType.Date(DateUnit.DAY)
+    case TimestampType if timeZoneId == null =>
+      throw new IllegalStateException("Missing timezoneId where it is 
mandatory.")
+    case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
+    case TimestampNTZType =>
+      new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
+    case NullType => ArrowType.Null.INSTANCE
+    case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
+    case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
+    case _ =>
+      throw ExecutionErrors.unsupportedDataTypeError(dt)
+  }
+
+  def fromArrowType(dt: ArrowType): DataType = dt match {
+    case ArrowType.Bool.INSTANCE => BooleanType
+    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
+    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
+    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
+    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
+    case float: ArrowType.FloatingPoint
+      if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType
+    case float: ArrowType.FloatingPoint
+      if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
+    case ArrowType.Utf8.INSTANCE => StringType
+    case ArrowType.Binary.INSTANCE => BinaryType
+    case ArrowType.LargeUtf8.INSTANCE => StringType
+    case ArrowType.LargeBinary.INSTANCE => BinaryType
+    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
+    case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
+    case ts: ArrowType.Timestamp
+      if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType
+    case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
+    case ArrowType.Null.INSTANCE => NullType
+    case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
+    case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
+    case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt)
+  }
+
+  /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
+  def toArrowField(
+                    name: String,
+                    dt: DataType,
+                    nullable: Boolean,
+                    timeZoneId: String,
+                    largeVarTypes: Boolean = false): Field = {
+    dt match {
+      case GeometryUDT =>
+        val jsonData = """{"crs": {"$schema": 
"https://proj.org/schemas/v0.7/projjson.schema.json";, "type": "GeographicCRS", 
"name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 
ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": 
{"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 
(G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic 
System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153}} [...]
+        val metadata = Map(
+          "ARROW:extension:name" -> "geoarrow.wkb",
+          "ARROW:extension:metadata" -> jsonData,
+        ).asJava
+
+        val fieldType = new FieldType(nullable, ArrowType.Binary.INSTANCE, null, metadata)
+        new Field(name, fieldType, Seq.empty[Field].asJava)
+
+      case ArrayType(elementType, containsNull) =>
+        val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null)
+        new Field(name, fieldType,
+          Seq(toArrowField("element", elementType, containsNull, timeZoneId,
+            largeVarTypes)).asJava)
+      case StructType(fields) =>
+        val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null)
+        new Field(name, fieldType,
+          fields.map { field =>
+            toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes)
+          }.toSeq.asJava)
+      case MapType(keyType, valueType, valueContainsNull) =>
+        val mapType = new FieldType(nullable, new ArrowType.Map(false), null)
+        // Note: Map Type struct can not be null, Struct Type key field can not be null
+        new Field(name, mapType,
+          Seq(toArrowField(MapVector.DATA_VECTOR_NAME,
+            new StructType()
+              .add(MapVector.KEY_NAME, keyType, nullable = false)
+              .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull),
+            nullable = false,
+            timeZoneId,
+            largeVarTypes)).asJava)
+      case udt: UserDefinedType[_] =>
+        toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes)
+      case dataType =>
+        val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId,
+          largeVarTypes), null)
+        new Field(name, fieldType, Seq.empty[Field].asJava)
+    }
+  }
+
+  def fromArrowField(field: Field): DataType = {
+    field.getType match {
+      case _: ArrowType.Map =>
+        val elementField = field.getChildren.get(0)
+        val keyType = fromArrowField(elementField.getChildren.get(0))
+        val valueType = fromArrowField(elementField.getChildren.get(1))
+        MapType(keyType, valueType, elementField.getChildren.get(1).isNullable)
+      case ArrowType.List.INSTANCE =>
+        val elementField = field.getChildren().get(0)
+        val elementType = fromArrowField(elementField)
+        ArrayType(elementType, containsNull = elementField.isNullable)
+      case ArrowType.Struct.INSTANCE =>
+        val fields = field.getChildren().asScala.map { child =>
+          val dt = fromArrowField(child)
+          StructField(child.getName, dt, child.isNullable)
+        }
+        StructType(fields.toArray)
+      case arrowType => fromArrowType(arrowType)
+    }
+  }
+
+  /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */
+  def toArrowSchema(
+                     schema: StructType,
+                     timeZoneId: String,
+                     errorOnDuplicatedFieldNames: Boolean,
+                     largeVarTypes: Boolean = false): Schema = {
+    new Schema(schema.map { field =>
+      toArrowField(
+        field.name,
+        deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames),
+        field.nullable,
+        timeZoneId,
+        largeVarTypes)
+    }.asJava)
+  }
+
+  def fromArrowSchema(schema: Schema): StructType = {
+    StructType(schema.getFields.asScala.map { field =>
+      val dt = fromArrowField(field)
+      StructField(field.getName, dt, field.isNullable)
+    }.toArray)
+  }
+
+  private def deduplicateFieldNames(
+                                     dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
+    case geometryType: GeometryUDT => geometryType
+    case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames)
+    case st @ StructType(fields) =>
+      val newNames = if (st.names.toSet.size == st.names.length) {
+        st.names
+      } else {
+        if (errorOnDuplicatedFieldNames) {
+          throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names)
+        }
+        val genNawName = st.names.groupBy(identity).map {
+          case (name, names) if names.length > 1 =>
+            val i = new AtomicInteger()
+            name -> { () => s"${name}_${i.getAndIncrement()}" }
+          case (name, _) => name -> { () => name }
+        }
+        st.names.map(genNawName(_)())
+      }
+      val newFields =
+        fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) =>
+          StructField(
+            name, deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), nullable, metadata)
+        }
+      StructType(newFields)
+    case ArrayType(elementType, containsNull) =>
+      ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull)
+    case MapType(keyType, valueType, valueContainsNull) =>
+      MapType(
+        deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames),
+        deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames),
+        valueContainsNull)
+    case _ => dt
+  }
+}
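
The GeometryUDT branch of toArrowField emits a plain binary column tagged with GeoArrow extension metadata, which is what lets geoarrow.pyarrow on the worker recognize geometry columns. A rough pyarrow equivalent (field names are illustrative and the CRS JSON is elided):

    import pyarrow as pa

    geom_field = pa.field(
        "geometry",
        pa.binary(),
        nullable=True,
        metadata={
            b"ARROW:extension:name": b"geoarrow.wkb",
            b"ARROW:extension:metadata": b"{}",  # the patch embeds a full PROJJSON CRS document
        },
    )
    schema = pa.schema([pa.field("id", pa.int32()), geom_field])
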
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
index d2c390282c..6791015ae9 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
@@ -84,9 +84,9 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[
       }
 
       protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
-        val arrowSchema = ArrowUtils.toArrowSchema(
+        val arrowSchema = SedonaArrowUtils.toArrowSchema(
           schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
-        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+        val allocator = SedonaArrowUtils.rootAllocator.newChildAllocator(
           s"stdout writer for $pythonExec", 0, Long.MaxValue)
         val root = VectorSchemaRoot.create(arrowSchema, allocator)
 
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
index 8d41848de9..525dafaefb 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
@@ -18,42 +18,62 @@
  */
 package org.apache.spark.sql.udf
 
-import org.apache.sedona.sql.TestBaseScala
+import org.apache.sedona.spark.SedonaContext
 import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.functions.col
 import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction
 import org.locationtech.jts.io.WKTReader
+import org.scalatest.funsuite.AnyFunSuite
 import org.scalatest.matchers.should.Matchers
 
-class StrategySuite extends TestBaseScala with Matchers {
+class StrategySuite extends AnyFunSuite with Matchers {
   val wktReader = new WKTReader()
 
   val spark: SparkSession = {
-    sparkSession.sparkContext.setLogLevel("ALL")
-    sparkSession
+    val builder = SedonaContext
+      .builder()
+      .master("local[*]")
+      .appName("sedonasqlScalaTest")
+
+    val spark = SedonaContext.create(builder.getOrCreate())
+
+    spark.sparkContext.setLogLevel("ALL")
+    spark
   }
 
   import spark.implicits._
 
-  it("sedona geospatial UDF") {
-    val df = Seq(
-      (1, "value", wktReader.read("POINT(21 52)")),
-      (2, "value1", wktReader.read("POINT(20 50)")),
-      (3, "value2", wktReader.read("POINT(20 49)")),
-      (4, "value3", wktReader.read("POINT(20 48)")),
-      (5, "value4", wktReader.read("POINT(20 47)")))
-      .toDF("id", "value", "geom")
-      .withColumn("geom_buffer", geoPandasScalaFunction(col("geom")))
-
-    df.count shouldEqual 5
-
-    df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))")
-      .as[String]
-      .collect() should contain theSameElementsAs Seq(
-      "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))",
-      "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))",
-      "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))",
-      "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))",
-      "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))")
+  test("sedona geospatial UDF") {
+    spark.sql("select 1").show()
+    val currentTime = System.currentTimeMillis()
+
+    val df = spark.read.format("parquet")
+      .load("/Users/pawelkocinski/Desktop/projects/sedona-book/apache-sedona-book/book/chapter10/data/buildings/partitioned")
+      .select(geoPandasScalaFunction(col("geometry")).alias("area"))
+      .selectExpr("sum(area) as total_area")
+
+    df.show()
+    val processingTime = System.currentTimeMillis() - currentTime
+    println(s"Processing time: $processingTime ms")
+
+//    val df = Seq(
+//      (1, "value", wktReader.read("POINT(21 52)")),
+//      (2, "value1", wktReader.read("POINT(20 50)")),
+//      (3, "value2", wktReader.read("POINT(20 49)")),
+//      (4, "value3", wktReader.read("POINT(20 48)")),
+//      (5, "value4", wktReader.read("POINT(20 47)")))
+//      .toDF("id", "value", "geom")
+//      .withColumn("geom_buffer", geoPandasScalaFunction(col("geom")))
+
+//    df.count shouldEqual 5
+
+//    df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))")
+//      .as[String]
+//      .collect() should contain theSameElementsAs Seq(
+//      "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))",
+//      "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))",
+//      "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))",
+//      "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))",
+//      "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))")
   }
 }
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
index 80a4c64106..fdc96ac024 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
@@ -25,6 +25,7 @@ import org.apache.spark.broadcast.Broadcast
 import org.apache.spark.internal.config.Python.{PYTHON_DAEMON_MODULE, PYTHON_USE_DAEMON, PYTHON_WORKER_MODULE}
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types.FloatType
 import org.apache.spark.util.Utils
 
 import java.io.File
@@ -82,31 +83,21 @@ object ScalarUDF {
           "-c",
           f"""
             |from pyspark.sql.types import IntegerType
+            |from pyspark.sql.types import DoubleType
+            |from pyspark.sql.types import FloatType
             |from shapely.geometry import Point
             |from sedona.sql.types import GeometryType
             |from pyspark.serializers import CloudPickleSerializer
-            |from sedona.utils import geometry_serde
             |from shapely import box
             |import logging
             |logging.basicConfig(level=logging.INFO)
             |logger = logging.getLogger(__name__)
             |logger.info("Loading Sedona Python UDF")
             |import os
-            |logger.info(os.getcwd())
-            |import sys
-            |import sys
-            |print("boring stuff")
-            |sys.path.append('$additionalModule')
-            |logger.info(sys.path)
             |f = open('$path', 'wb');
-            |def w(x):
-            |    def apply_function(w):
-            |        geom, offset = geometry_serde.deserialize(w)
-            |        bounds = geom.buffer(1).bounds
-            |        x = box(*bounds)
-            |        return geometry_serde.serialize(x)
-            |    return x.apply(apply_function)
-            |f.write(CloudPickleSerializer().dumps((w, GeometryType())))
+            |def apply_geopandas(x):
+            |    return x.area
+            |f.write(CloudPickleSerializer().dumps((apply_geopandas, FloatType())))
             |""".stripMargin),
         None,
         "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
@@ -131,7 +122,7 @@ object ScalarUDF {
       pythonVer = pythonVer,
       broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
       accumulator = null),
-    dataType = GeometryUDT,
+    dataType = FloatType,
     pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
     udfDeterministic = true)
 }

