This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 798cd3ca8d00fdc32f0091703c9edf1b543327a1 Author: pawelkocinski <[email protected]> AuthorDate: Wed Jul 30 00:07:46 2025 +0200 SEDONA-738 Fix unit tests. --- .../common/geometrySerde/GeometrySerializer.java | 53 ++--- sedonaworker/serializer.py | 187 ++++++++++++++++++ sedonaworker/worker.py | 154 +++++++++------ .../spark/api/python/SedonaPythonRunner.scala | 32 ++- .../sql/execution/python/SedonaArrowUtils.scala | 216 +++++++++++++++++++++ .../execution/python/SedonaPythonArrowInput.scala | 4 +- .../org/apache/spark/sql/udf/StrategySuite.scala | 68 ++++--- .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 23 +-- 8 files changed, 613 insertions(+), 124 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java index 508a62901d..325098c6ac 100644 --- a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java +++ b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java @@ -32,37 +32,46 @@ import org.locationtech.jts.geom.Point; import org.locationtech.jts.geom.Polygon; import org.locationtech.jts.geom.PrecisionModel; import org.locationtech.jts.io.WKBConstants; +import org.locationtech.jts.io.WKBReader; +import org.locationtech.jts.io.WKBWriter; public class GeometrySerializer { private static final Coordinate NULL_COORDINATE = new Coordinate(Double.NaN, Double.NaN); private static final PrecisionModel PRECISION_MODEL = new PrecisionModel(); public static byte[] serialize(Geometry geometry) { - GeometryBuffer buffer; - if (geometry instanceof Point) { - buffer = serializePoint((Point) geometry); - } else if (geometry instanceof MultiPoint) { - buffer = serializeMultiPoint((MultiPoint) geometry); - } else if (geometry instanceof LineString) { - buffer = serializeLineString((LineString) geometry); - } else if (geometry instanceof MultiLineString) { - buffer = serializeMultiLineString((MultiLineString) geometry); - } else if (geometry instanceof Polygon) { - buffer = serializePolygon((Polygon) geometry); - } else if (geometry instanceof MultiPolygon) { - buffer = serializeMultiPolygon((MultiPolygon) geometry); - } else if (geometry instanceof GeometryCollection) { - buffer = serializeGeometryCollection((GeometryCollection) geometry); - } else { - throw new UnsupportedOperationException( - "Geometry type is not supported: " + geometry.getClass().getSimpleName()); - } - return buffer.toByteArray(); + return new WKBWriter().write(geometry); +// GeometryBuffer buffer; +// if (geometry instanceof Point) { +// buffer = serializePoint((Point) geometry); +// } else if (geometry instanceof MultiPoint) { +// buffer = serializeMultiPoint((MultiPoint) geometry); +// } else if (geometry instanceof LineString) { +// buffer = serializeLineString((LineString) geometry); +// } else if (geometry instanceof MultiLineString) { +// buffer = serializeMultiLineString((MultiLineString) geometry); +// } else if (geometry instanceof Polygon) { +// buffer = serializePolygon((Polygon) geometry); +// } else if (geometry instanceof MultiPolygon) { +// buffer = serializeMultiPolygon((MultiPolygon) geometry); +// } else if (geometry instanceof GeometryCollection) { +// buffer = serializeGeometryCollection((GeometryCollection) geometry); +// } else { +// throw new UnsupportedOperationException( +// "Geometry type is not supported: " + geometry.getClass().getSimpleName()); +// } +// return buffer.toByteArray(); } public static 
Geometry deserialize(byte[] bytes) { - GeometryBuffer buffer = GeometryBufferFactory.wrap(bytes); - return deserialize(buffer); + WKBReader reader = new WKBReader(); + try { + return reader.read(bytes); + } catch (Exception e) { + throw new IllegalArgumentException("Failed to deserialize geometry from bytes", e); + } +// GeometryBuffer buffer = GeometryBufferFactory.wrap(bytes); +// return deserialize(buffer); } public static Geometry deserialize(GeometryBuffer buffer) { diff --git a/sedonaworker/serializer.py b/sedonaworker/serializer.py new file mode 100644 index 0000000000..319de34780 --- /dev/null +++ b/sedonaworker/serializer.py @@ -0,0 +1,187 @@ +import logging + +import pandas as pd +from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer, ArrowStreamSerializer +from pyspark.errors import PySparkTypeError, PySparkValueError +import struct +from pyspark.sql.pandas.types import ( + from_arrow_type, + to_arrow_type, + _create_converter_from_pandas, + _create_converter_to_pandas, +) + +def write_int(value, stream): + stream.write(struct.pack("!i", value)) + +class SpecialLengths: + END_OF_DATA_SECTION = -1 + PYTHON_EXCEPTION_THROWN = -2 + TIMING_DATA = -3 + END_OF_STREAM = -4 + NULL = -5 + START_ARROW_STREAM = -6 + + +class SedonaArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): + """ + Serializer used by Python worker to evaluate Pandas UDFs + """ + + def __init__( + self, + timezone, + safecheck, + assign_cols_by_name, + df_for_struct=False, + struct_in_pandas="dict", + ndarray_as_list=False, + arrow_cast=False, + ): + super(SedonaArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck) + self._assign_cols_by_name = assign_cols_by_name + self._df_for_struct = df_for_struct + self._struct_in_pandas = struct_in_pandas + self._ndarray_as_list = ndarray_as_list + self._arrow_cast = arrow_cast + + def load_stream(self, stream): + """ + Deserialize ArrowRecordBatches to an Arrow table and return as a list of pandas.Series. + """ + import geoarrow.pyarrow as ga + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + import pyarrow as pa + for batch in batches: + yield [ga.to_geopandas(c) for c in pa.Table.from_batches([batch]).itercolumns()] + + def _create_batch(self, series): + """ + Create an Arrow record batch from the given pandas.Series pandas.DataFrame + or list of Series or DataFrame, with optional type. + + Parameters + ---------- + series : pandas.Series or pandas.DataFrame or list + A single series or dataframe, list of series or dataframe, + or list of (series or dataframe, arrow_type) + + Returns + ------- + pyarrow.RecordBatch + Arrow RecordBatch + """ + import pyarrow as pa + + # Make input conform to [(series1, type1), (series2, type2), ...] + if not isinstance(series, (list, tuple)) or ( + len(series) == 2 and isinstance(series[1], pa.DataType) + ): + series = [series] + series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) + + arrs = [] + for s, t in series: + arrs.append(self._create_array(s, t, arrow_cast=self._arrow_cast)) + + return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) + + def _create_array(self, series, arrow_type, spark_type=None, arrow_cast=False): + """ + Create an Arrow Array from the given pandas.Series and optional type. 
+ + Parameters + ---------- + series : pandas.Series + A single series + arrow_type : pyarrow.DataType, optional + If None, pyarrow's inferred type will be used + spark_type : DataType, optional + If None, spark type converted from arrow_type will be used + arrow_cast: bool, optional + Whether to apply Arrow casting when the user-specified return type mismatches the + actual return values. + + Returns + ------- + pyarrow.Array + """ + import pyarrow as pa + from pandas.api.types import is_categorical_dtype + if is_categorical_dtype(series.dtype): + series = series.astype(series.dtypes.categories.dtype) + if arrow_type is not None: + dt = spark_type or from_arrow_type(arrow_type, prefer_timestamp_ntz=True) + # TODO(SPARK-43579): cache the converter for reuse + conv = _create_converter_from_pandas( + dt, timezone=self._timezone, error_on_duplicated_field_names=False + ) + series = conv(series) + + if hasattr(series.array, "__arrow_array__"): + mask = None + else: + mask = series.isnull() + try: + try: + import geopandas as gpd + if isinstance(series, gpd.GeoSeries): + import geoarrow.pyarrow as ga + # If the series is a GeoSeries, convert it to an Arrow array using geoarrow + return ga.array(series) + + array = pa.Array.from_pandas( + series, mask=mask, type=arrow_type, safe=self._safecheck + ) + + return array + except pa.lib.ArrowInvalid: + if arrow_cast: + return pa.Array.from_pandas(series, mask=mask).cast( + target_type=arrow_type, safe=self._safecheck + ) + else: + raise + except TypeError as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e + except ValueError as e: + error_msg = ( + "Exception thrown when converting pandas.Series (%s) " + "with name '%s' to Arrow Array (%s)." + ) + if self._safecheck: + error_msg = error_msg + ( + " It can be caused by overflows or other " + "unsafe conversions warned by Arrow. Arrow safe type check " + "can be disabled by using SQL config " + "`spark.sql.execution.pandas.convertToArrowArraySafely`." + ) + raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + + def init_stream_yield_batches(): + should_write_start_length = True + for series in iterator: + batch = self._create_batch(series) + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) + + def __repr__(self): + return "ArrowStreamPandasUDFSerializer" + +from pyspark.serializers import Serializer, write_int + diff --git a/sedonaworker/worker.py b/sedonaworker/worker.py index 365561f0a6..c08d47588a 100644 --- a/sedonaworker/worker.py +++ b/sedonaworker/worker.py @@ -18,14 +18,18 @@ """ Worker that receives input from Piped RDD. 
""" +import logging import os import sys import time from inspect import currentframe, getframeinfo, getfullargspec import importlib import json +from io import BufferedRWPair from typing import Any, Iterable, Iterator +from sedonaworker.serializer import SedonaArrowStreamPandasUDFSerializer + # 'resource' is a Unix specific module. has_resource_module = True try: @@ -56,11 +60,7 @@ from pyspark.serializers import ( BatchedSerializer, ) from pyspark.sql.pandas.serializers import ( - ArrowStreamPandasUDFSerializer, - ArrowStreamPandasUDTFSerializer, - CogroupUDFSerializer, - ArrowStreamUDFSerializer, - ApplyInPandasWithStateSerializer, + ArrowStreamPandasUDFSerializer, ArrowStreamSerializer ) from pyspark.sql.pandas.types import to_arrow_type from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string @@ -72,6 +72,51 @@ pickleSer = CPickleSerializer() utf8_deserializer = UTF8Deserializer() +class SedonaArrowStreamUDFSerializer(ArrowStreamSerializer): + """ + Same as :class:`ArrowStreamSerializer` but it flattens the struct to Arrow record batch + for applying each function with the raw record arrow batch. See also `DataFrame.mapInArrow`. + """ + + def load_stream(self, stream): + """ + Flatten the struct into Arrow's record batches. + """ + import pyarrow as pa + + batches = super(SedonaArrowStreamUDFSerializer, self).load_stream(stream) + for batch in batches: + struct = batch.column(0) + yield [pa.RecordBatch.from_arrays(struct.flatten(), schema=pa.schema(struct.type))] + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + import pyarrow as pa + + def wrap_and_init_stream(): + should_write_start_length = True + for batch, _ in iterator: + assert isinstance(batch, pa.RecordBatch) + + # Wrap the root struct + struct = pa.StructArray.from_arrays( + batch.columns, fields=pa.struct(list(batch.schema)) + ) + batch = pa.RecordBatch.from_arrays([struct], ["_0"]) + + # Write the first record batch with initialization. 
+ if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + yield batch + + return super(SedonaArrowStreamUDFSerializer, self).dump_stream(wrap_and_init_stream(), stream) + + def report_times(outfile, boot, init, finish): write_int(SpecialLengths.TIMING_DATA, outfile) write_long(int(1000 * boot), outfile) @@ -180,56 +225,45 @@ def assign_cols_by_name(runner_conf): def read_udfs(pickleSer, infile, eval_type): runner_conf = {} - if eval_type in ( - PythonEvalType.SQL_SCALAR_PANDAS_UDF, - ): - - # Load conf used for pandas_udf evaluation - num_conf = read_int(infile) - for i in range(num_conf): - k = utf8_deserializer.loads(infile) - v = utf8_deserializer.loads(infile) - runner_conf[k] = v - - state_object_schema = None - if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: - state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) - - # NOTE: if timezone is set here, that implies respectSessionTimeZone is True - timezone = runner_conf.get("spark.sql.session.timeZone", None) - safecheck = ( - runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() - == "true" - ) + # Load conf used for pandas_udf evaluation + num_conf = read_int(infile) + for i in range(num_conf): + k = utf8_deserializer.loads(infile) + v = utf8_deserializer.loads(infile) + runner_conf[k] = v + + # state_object_schema = None + # if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE: + # state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile))) + + # NOTE: if timezone is set here, that implies respectSessionTimeZone is True + timezone = runner_conf.get("spark.sql.session.timeZone", None) + safecheck = ( + runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() + == "true" + ) - if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF: - ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf)) - else: - # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of - # pandas Series. See SPARK-27240. 
- df_for_struct = ( - eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF - or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF - or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF - ) - # Arrow-optimized Python UDF takes a struct type argument as a Row - struct_in_pandas = ( - "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict" - ) - ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF - # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion - arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF - ser = ArrowStreamPandasUDFSerializer( - timezone, - safecheck, - assign_cols_by_name(runner_conf), - df_for_struct, - struct_in_pandas, - ndarray_as_list, - arrow_cast, - ) - else: - ser = BatchedSerializer(CPickleSerializer(), 100) + df_for_struct = ( + eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF + or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF + or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF + ) + # Arrow-optimized Python UDF takes a struct type argument as a Row + struct_in_pandas = ( + "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict" + ) + ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion + arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF + ser = SedonaArrowStreamPandasUDFSerializer( + timezone, + safecheck, + assign_cols_by_name(runner_conf), + df_for_struct, + struct_in_pandas, + ndarray_as_list, + arrow_cast, + ) num_udfs = read_int(infile) @@ -298,6 +332,7 @@ def read_udfs(pickleSer, infile, eval_type): # profiling is not supported for UDF return func, None, ser, ser + udfs = [] for i in range(num_udfs): udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)) @@ -332,6 +367,7 @@ def main(infile, outfile): sys.exit(-1) version = utf8_deserializer.loads(infile) + if version != "%d.%d" % sys.version_info[:2]: raise PySparkRuntimeError( error_class="PYTHON_VERSION_MISMATCH", @@ -459,11 +495,7 @@ def main(infile, outfile): _accumulatorRegistry.clear() eval_type = read_int(infile) - if eval_type == PythonEvalType.NON_UDF: - func, profiler, deserializer, serializer = read_command(pickleSer, infile) - else: - func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) - + func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) init_time = time.time() def process(): @@ -539,3 +571,7 @@ if __name__ == "__main__": write_int(os.getpid(), sock_file) sock_file.flush() main(sock_file, sock_file) + + +class GeoArrowLoader: + pass diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala index 026518272c..fb01e62b5e 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala @@ -19,6 +19,8 @@ package org.apache.spark.api.python import org.apache.spark._ import org.apache.spark.SedonaSparkEnv +import org.apache.spark.api.python.PythonRDD.writeUTF +import org.apache.spark.input.PortableDataStream import org.apache.spark.internal.Logging import org.apache.spark.internal.config.Python._ import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} @@ -749,7 +751,7 @@ private[spark] class PythonRunner( } protected override def writeIteratorToStream(dataOut: 
DataOutputStream): Unit = { - PythonRDD.writeIteratorToStream(inputIterator, dataOut) + GeoArrowWriter.writeIteratorToStream(inputIterator, dataOut) dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) } } @@ -809,3 +811,31 @@ private[spark] object BarrierTaskContextMessageProtocol { val BARRIER_RESULT_SUCCESS = "success" val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python side." } + +object GeoArrowWriter extends Logging { + def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): Unit = { + + def write(obj: Any): Unit = obj match { + case null => + dataOut.writeInt(SpecialLengths.NULL) + case arr: Array[Byte] => + logError("some random array") + dataOut.writeInt(arr.length) + dataOut.write(arr) + case str: String => + logError("some random string") + writeUTF(str, dataOut) + case stream: PortableDataStream => + logError("some random stream") + write(stream.toArray()) + case (key, value) => + logError("some random key value") + write(key) + write(value) + case other => + throw new SparkException("Unexpected element type " + other.getClass) + } + + iter.foreach(write) + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala new file mode 100644 index 0000000000..58166d173d --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala @@ -0,0 +1,216 @@ +package org.apache.spark.sql.execution.python + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +import java.util.concurrent.atomic.AtomicInteger +import scala.collection.JavaConverters._ +import org.apache.arrow.memory.RootAllocator +import org.apache.arrow.vector.complex.MapVector +import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} +import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} +import org.apache.spark.sql.errors.ExecutionErrors +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types._ + +private[sql] object SedonaArrowUtils { + + val rootAllocator = new RootAllocator(Long.MaxValue) + + // todo: support more types. + + /** Maps data type from Spark to Arrow. 
NOTE: timeZoneId required for TimestampTypes */ + def toArrowType( + dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE + case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE + case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType if timeZoneId == null => + throw new IllegalStateException("Missing timezoneId where it is mandatory.") + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + case TimestampNTZType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case NullType => ArrowType.Null.INSTANCE + case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) + case _ => + throw ExecutionErrors.unsupportedDataTypeError(dt) + } + + def fromArrowType(dt: ArrowType): DataType = dt match { + case ArrowType.Bool.INSTANCE => BooleanType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType + case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType + case float: ArrowType.FloatingPoint + if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType + case ArrowType.Utf8.INSTANCE => StringType + case ArrowType.Binary.INSTANCE => BinaryType + case ArrowType.LargeUtf8.INSTANCE => StringType + case ArrowType.LargeBinary.INSTANCE => BinaryType + case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) + case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType + case ts: ArrowType.Timestamp + if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType + case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType + case ArrowType.Null.INSTANCE => NullType + case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType() + case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() + case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt) + } + + /** Maps field from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType */ + def toArrowField( + name: String, + dt: DataType, + nullable: Boolean, + timeZoneId: String, + largeVarTypes: Boolean = false): Field = { + dt match { + case GeometryUDT => + val jsonData = """{"crs": {"$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", "type": "GeographicCRS", "name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": {"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 (G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153}} [...] + val metadata = Map( + "ARROW:extension:name" -> "geoarrow.wkb", + "ARROW:extension:metadata" -> jsonData, + ).asJava + + val fieldType = new FieldType(nullable, ArrowType.Binary.INSTANCE, null, metadata) + new Field(name, fieldType, Seq.empty[Field].asJava) + + case ArrayType(elementType, containsNull) => + val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) + new Field(name, fieldType, + Seq(toArrowField("element", elementType, containsNull, timeZoneId, + largeVarTypes)).asJava) + case StructType(fields) => + val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) + new Field(name, fieldType, + fields.map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) + }.toSeq.asJava) + case MapType(keyType, valueType, valueContainsNull) => + val mapType = new FieldType(nullable, new ArrowType.Map(false), null) + // Note: Map Type struct can not be null, Struct Type key field can not be null + new Field(name, mapType, + Seq(toArrowField(MapVector.DATA_VECTOR_NAME, + new StructType() + .add(MapVector.KEY_NAME, keyType, nullable = false) + .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), + nullable = false, + timeZoneId, + largeVarTypes)).asJava) + case udt: UserDefinedType[_] => + toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) + case dataType => + val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId, + largeVarTypes), null) + new Field(name, fieldType, Seq.empty[Field].asJava) + } + } + + def fromArrowField(field: Field): DataType = { + field.getType match { + case _: ArrowType.Map => + val elementField = field.getChildren.get(0) + val keyType = fromArrowField(elementField.getChildren.get(0)) + val valueType = fromArrowField(elementField.getChildren.get(1)) + MapType(keyType, valueType, elementField.getChildren.get(1).isNullable) + case ArrowType.List.INSTANCE => + val elementField = field.getChildren().get(0) + val elementType = fromArrowField(elementField) + ArrayType(elementType, containsNull = elementField.isNullable) + case ArrowType.Struct.INSTANCE => + val fields = field.getChildren().asScala.map { child => + val dt = fromArrowField(child) + StructField(child.getName, dt, child.isNullable) + } + StructType(fields.toArray) + case arrowType => fromArrowType(arrowType) + } + } + + /** Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType */ + def toArrowSchema( + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean = false): Schema = { + new Schema(schema.map { field => + toArrowField( + field.name, + deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames), + field.nullable, + timeZoneId, + largeVarTypes) + }.asJava) + } + + def fromArrowSchema(schema: Schema): StructType = { + StructType(schema.getFields.asScala.map { field => + val dt = fromArrowField(field) + StructField(field.getName, dt, field.isNullable) + }.toArray) + } + + private def deduplicateFieldNames( + dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match { + case geometryType: GeometryUDT => geometryType + case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) + case st @ StructType(fields) => + val newNames = if (st.names.toSet.size == st.names.length) { + st.names + } else { + if (errorOnDuplicatedFieldNames) { + throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names) + } + val genNawName = st.names.groupBy(identity).map { + case (name, names) if names.length > 1 => + val i = new AtomicInteger() + name -> { () => s"${name}_${i.getAndIncrement()}" } + case (name, _) => name -> { () => name } + } + st.names.map(genNawName(_)()) + } + val newFields = + fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => + StructField( + name, deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), nullable, metadata) + } + StructType(newFields) + case ArrayType(elementType, containsNull) => + ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull) + case MapType(keyType, valueType, valueContainsNull) => + MapType( + deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames), + deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames), + valueContainsNull) + case _ => dt + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala index d2c390282c..6791015ae9 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -84,9 +84,9 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[ } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema( + val arrowSchema = SedonaArrowUtils.toArrowSchema( schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) - val allocator = ArrowUtils.rootAllocator.newChildAllocator( + val allocator = SedonaArrowUtils.rootAllocator.newChildAllocator( s"stdout writer for $pythonExec", 0, Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index 8d41848de9..525dafaefb 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -18,42 +18,62 @@ */ package org.apache.spark.sql.udf -import org.apache.sedona.sql.TestBaseScala +import org.apache.sedona.spark.SedonaContext 
import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.col import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction import org.locationtech.jts.io.WKTReader +import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers -class StrategySuite extends TestBaseScala with Matchers { +class StrategySuite extends AnyFunSuite with Matchers { val wktReader = new WKTReader() val spark: SparkSession = { - sparkSession.sparkContext.setLogLevel("ALL") - sparkSession + val builder = SedonaContext + .builder() + .master("local[*]") + .appName("sedonasqlScalaTest") + + val spark = SedonaContext.create(builder.getOrCreate()) + + spark.sparkContext.setLogLevel("ALL") + spark } import spark.implicits._ - it("sedona geospatial UDF") { - val df = Seq( - (1, "value", wktReader.read("POINT(21 52)")), - (2, "value1", wktReader.read("POINT(20 50)")), - (3, "value2", wktReader.read("POINT(20 49)")), - (4, "value3", wktReader.read("POINT(20 48)")), - (5, "value4", wktReader.read("POINT(20 47)"))) - .toDF("id", "value", "geom") - .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) - - df.count shouldEqual 5 - - df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") - .as[String] - .collect() should contain theSameElementsAs Seq( - "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", - "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", - "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", - "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", - "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") + test("sedona geospatial UDF") { + spark.sql("select 1").show() + val currentTime = System.currentTimeMillis() + + val df = spark.read.format("parquet") + .load("/Users/pawelkocinski/Desktop/projects/sedona-book/apache-sedona-book/book/chapter10/data/buildings/partitioned") + .select(geoPandasScalaFunction(col("geometry")).alias("area")) + .selectExpr("sum(area) as total_area") + + df.show() + val processingTime = System.currentTimeMillis() - currentTime + println(s"Processing time: $processingTime ms") + +// val df = Seq( +// (1, "value", wktReader.read("POINT(21 52)")), +// (2, "value1", wktReader.read("POINT(20 50)")), +// (3, "value2", wktReader.read("POINT(20 49)")), +// (4, "value3", wktReader.read("POINT(20 48)")), +// (5, "value4", wktReader.read("POINT(20 47)"))) +// .toDF("id", "value", "geom") +// .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) + +// df.count shouldEqual 5 + +// df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") +// .as[String] +// .collect() should contain theSameElementsAs Seq( +// "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", +// "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", +// "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", +// "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", +// "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index 80a4c64106..fdc96ac024 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -25,6 +25,7 @@ import org.apache.spark.broadcast.Broadcast import org.apache.spark.internal.config.Python.{PYTHON_DAEMON_MODULE, PYTHON_USE_DAEMON, PYTHON_WORKER_MODULE} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import 
org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.FloatType import org.apache.spark.util.Utils import java.io.File @@ -82,31 +83,21 @@ object ScalarUDF { "-c", f""" |from pyspark.sql.types import IntegerType + |from pyspark.sql.types import DoubleType + |from pyspark.sql.types import FloatType |from shapely.geometry import Point |from sedona.sql.types import GeometryType |from pyspark.serializers import CloudPickleSerializer - |from sedona.utils import geometry_serde |from shapely import box |import logging |logging.basicConfig(level=logging.INFO) |logger = logging.getLogger(__name__) |logger.info("Loading Sedona Python UDF") |import os - |logger.info(os.getcwd()) - |import sys - |import sys - |print("boring stuff") - |sys.path.append('$additionalModule') - |logger.info(sys.path) |f = open('$path', 'wb'); - |def w(x): - | def apply_function(w): - | geom, offset = geometry_serde.deserialize(w) - | bounds = geom.buffer(1).bounds - | x = box(*bounds) - | return geometry_serde.serialize(x) - | return x.apply(apply_function) - |f.write(CloudPickleSerializer().dumps((w, GeometryType()))) + |def apply_geopandas(x): + | return x.area + |f.write(CloudPickleSerializer().dumps((apply_geopandas, FloatType()))) |""".stripMargin), None, "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! @@ -131,7 +122,7 @@ object ScalarUDF { pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, accumulator = null), - dataType = GeometryUDT, + dataType = FloatType, pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, udfDeterministic = true) }
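
Note on the GeometrySerializer change in this patch: serialize/deserialize now delegate to JTS WKBWriter/WKBReader instead of the custom GeometryBuffer encoding. Below is a minimal, illustrative Python sketch of the same WKB round-trip using shapely; shapely and the sample point are assumptions for illustration only and are not part of the patch.

# Illustrative WKB round-trip mirroring the new GeometrySerializer behaviour
# (JTS `new WKBWriter().write(geometry)` / `new WKBReader().read(bytes)` on the JVM side).
# Requires shapely; the sample point is arbitrary.
from shapely import wkb
from shapely.geometry import Point

def serialize(geom) -> bytes:
    # Counterpart of WKBWriter.write(geometry).
    return wkb.dumps(geom)

def deserialize(data: bytes):
    # Counterpart of WKBReader.read(bytes); malformed input raises here,
    # which the Java side wraps in an IllegalArgumentException.
    return wkb.loads(data)

original = Point(21, 52)
assert deserialize(serialize(original)).equals(original)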

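For the new sedonaworker serializer: SedonaArrowStreamPandasUDFSerializer.load_stream turns incoming Arrow columns into GeoPandas objects via geoarrow.pyarrow.to_geopandas, and _create_array wraps a GeoSeries back into a GeoArrow array via geoarrow.pyarrow.array. The sketch below replays that conversion path outside the worker using the same calls as the patch; it assumes geopandas and geoarrow-pyarrow are installed, the sample geometries are made up, and whether ga.array accepts a GeoSeries directly may depend on the installed geoarrow-pyarrow version.

# Standalone sketch of the GeoSeries <-> Arrow path used by the new serializer.
import geopandas as gpd
import geoarrow.pyarrow as ga
import pyarrow as pa
from shapely.geometry import Point

geoms = gpd.GeoSeries([Point(21, 52), Point(20, 50)])

# worker -> JVM direction: a GeoSeries becomes a GeoArrow-annotated Arrow array
# (what _create_array does when it sees a gpd.GeoSeries).
arrow_array = ga.array(geoms)
batch = pa.RecordBatch.from_arrays([arrow_array], ["geometry"])

# JVM -> worker direction: Arrow columns come back as GeoPandas objects
# (what load_stream does per column of each record batch).
restored = ga.to_geopandas(pa.Table.from_batches([batch]).column("geometry"))
print(restored)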