This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch add-sedona-worker-daemon-mode
in repository https://gitbox.apache.org/repos/asf/sedona.git
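The commit below adds a SedonaDB-backed vectorized UDF path: a `sedona_db_vectorized_udf` decorator in python/sedona/spark/sql/functions.py, a dedicated Python worker under python/sedona/spark/worker/, and Spark-side Arrow runner/strategy classes. As a rough orientation sketch only, distilled from the tests added in this commit and assuming an already-initialized SedonaContext SparkSession named `spark`, usage looks roughly like:

    import pyarrow as pa
    import shapely
    from pyspark.sql.functions import expr
    from sedona.spark.sql.functions import sedona_db_vectorized_udf
    from sedona.sql import GeometryType

    # Vectorized UDF executed by the SedonaDB worker; geometries arrive as
    # Arrow arrays of WKB, not as one shapely object per row.
    @sedona_db_vectorized_udf(return_type=GeometryType(), input_types=[GeometryType()])
    def centroid_udf(geom):
        geoms = shapely.from_wkb(pa.array(geom.storage.to_array()))
        return pa.array(shapely.to_wkb(shapely.centroid(geoms)))

    df = (spark.createDataFrame([(1, "POINT (1 1)")], ["id", "wkt"])
          .withColumn("wkt", expr("ST_GeomFromWKT(wkt)")))
    df.select(centroid_udf(df.wkt).alias("geom")).show()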
commit 1ca3a5941450e56951b2a361679063a0b3a6753d
Author: pawelkocinski <[email protected]>
AuthorDate: Fri Dec 19 23:13:31 2025 +0100

    SEDONA-738 Add sedonadb worker
---
 pom.xml | 2 +-
 python/pyproject.toml | 52 ++--
 python/sedona/spark/sql/functions.py | 64 ++++-
 python/sedona/spark/worker/__init__.py | 0
 python/sedona/spark/worker/serde.py | 82 ++++++
 python/sedona/spark/worker/udf_info.py | 34 +++
 python/sedona/spark/worker/worker.py | 295 +++++++++++++++++++++
 python/tests/test_base.py | 2 +-
 .../tests/utils/test_sedona_db_vectorized_udf.py | 94 +++++++
 .../org/apache/sedona/spark/SedonaContext.scala | 3 +-
 .../org/apache/sedona/sql/UDF/PythonEvalType.scala | 7 +
 .../execution/python/SedonaArrowPythonRunner.scala | 58 ++++
 .../sql/execution/python/SedonaArrowStrategy.scala | 159 +++++++++++
 .../execution/python/SedonaBasePythonRunner.scala | 121 +++++++++
 .../execution/python/SedonaDBWorkerFactory.scala | 118 +++++++++
 .../execution/python/SedonaPythonArrowInput.scala | 135 ++++++++++
 .../execution/python/SedonaPythonArrowOutput.scala | 171 ++++++++++++
 .../spark/sql/execution/python/WorkerContext.scala | 52 ++++
 .../spark/sql/udf/ExtractSedonaUDFRule.scala | 13 +-
 .../apache/spark/sql/udf/SedonaArrowStrategy.scala | 89 -------
 .../org/apache/sedona/sql/TestBaseScala.scala | 1 +
 .../org/apache/spark/sql/udf/StrategySuite.scala | 32 ++-
 .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 162 +++++++++--
 23 files changed, 1586 insertions(+), 160 deletions(-)

diff --git a/pom.xml b/pom.xml
index d6e4e81319..613e310983 100644
--- a/pom.xml
+++ b/pom.xml
@@ -631,7 +631,7 @@
       <plugin>
         <groupId>org.apache.maven.plugins</groupId>
         <artifactId>maven-javadoc-plugin</artifactId>
-        <version>2.10.4</version>
+        <version>3.12.0</version>
         <executions>
           <execution>
             <id>attach-javadocs</id>

diff --git a/python/pyproject.toml b/python/pyproject.toml
index b988966c4f..76169261c3 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -16,7 +16,7 @@
 # under the License.
[build-system] -requires = ["setuptools>=69", "wheel"] +requires = ["setuptools>=80.9.0", "wheel"] build-backend = "setuptools.build_meta" [project] @@ -26,13 +26,19 @@ description = "Apache Sedona is a cluster computing system for processing large- readme = "README.md" license = { text = "Apache-2.0" } authors = [ { name = "Apache Sedona", email = "[email protected]" } ] -requires-python = ">=3.8" +requires-python = ">=3.12" classifiers = [ "Programming Language :: Python :: 3", "License :: OSI Approved :: Apache Software License", ] dependencies = [ "attrs", + "geoarrow-c>=0.3.1", + "geoarrow-pyarrow>=0.2.0", + "geopandas>=1.1.2", + "pyarrow>=16.1.0", + "pyspark==3.5.4", + "sedonadb", "shapely>=1.7.0", ] @@ -43,38 +49,16 @@ kepler-map = ["geopandas", "keplergl==0.3.2"] flink = ["apache-flink>=1.19.0"] db = ["sedonadb[geopandas]; python_version >= '3.9'"] all = [ - "pyspark>=3.4.0,<4.1.0", - "geopandas", - "pydeck==0.8.0", - "keplergl==0.3.2", - "rasterio>=1.2.10", +# "pyspark>=3.4.0,<4.1.0", +# "geopandas", +# "pydeck==0.8.0", +# "keplergl==0.3.2", +# "rasterio>=1.2.10", ] [dependency-groups] dev = [ - "pytest", - "pytest-cov", - "notebook==6.4.12", - "jupyter", - "mkdocs", - "scikit-learn", - "esda", - "libpysal", - "matplotlib", # implicit dependency of esda - # prevent incompatibility with pysal 4.7.0, which is what is resolved to when shapely >2 is specified - "scipy<=1.10.0", - "pandas>=2.0.0", - "numpy<2", - "geopandas", - # https://stackoverflow.com/questions/78949093/how-to-resolve-attributeerror-module-fiona-has-no-attribute-path - # cannot set geopandas>=0.14.4 since it doesn't support python 3.8, so we pin fiona to <1.10.0 - "fiona<1.10.0", - "pyarrow", - "pyspark>=3.4.0,<4.1.0", - "keplergl==0.3.2", - "pydeck==0.8.0", - "pystac==1.5.0", - "rasterio>=1.2.10", + "pytest>=9.0.2", ] [project.urls] @@ -99,3 +83,11 @@ sources = [ "src/geom_buf.c", "src/geos_c_dyn.c", ] + +[tool.uv] +dev-dependencies = [ + "pytest>=9.0.2", +] + +[tool.uv.sources] +sedonadb = { path = "../../../sedona-db/target/wheels/sedonadb-0.3.0-cp312-cp312-macosx_11_0_arm64.whl" } diff --git a/python/sedona/spark/sql/functions.py b/python/sedona/spark/sql/functions.py index 2420301d52..7c480e1700 100644 --- a/python/sedona/spark/sql/functions.py +++ b/python/sedona/spark/sql/functions.py @@ -21,11 +21,14 @@ from enum import Enum import pandas as pd -from sedona.spark.sql.types import GeometryType from sedona.spark.utils import geometry_serde -from pyspark.sql.udf import UserDefinedFunction -from pyspark.sql.types import DataType from shapely.geometry.base import BaseGeometry +from pyspark.sql.udf import UserDefinedFunction +import pyarrow as pa +import geoarrow.pyarrow as ga +from sedonadb import udf as sedona_udf_module +from sedona.spark.sql.types import GeometryType +from pyspark.sql.types import DataType, FloatType, DoubleType, IntegerType, StringType SEDONA_SCALAR_EVAL_TYPE = 5200 @@ -142,3 +145,58 @@ def serialize_to_geometry_if_geom(data, return_type: DataType): return geometry_serde.serialize(data) return data + + +def infer_pa_type(spark_type: DataType): + if isinstance(spark_type, GeometryType): + return ga.wkb() + elif isinstance(spark_type, FloatType): + return pa.float32() + elif isinstance(spark_type, DoubleType): + return pa.float64() + elif isinstance(spark_type, IntegerType): + return pa.int32() + elif isinstance(spark_type, StringType): + return pa.string() + else: + raise NotImplementedError(f"Type {spark_type} is not supported yet.") + +def infer_input_type(spark_type: DataType): + if 
isinstance(spark_type, GeometryType): + return sedona_udf_module.GEOMETRY + elif isinstance(spark_type, FloatType) or isinstance(spark_type, DoubleType) or isinstance(spark_type, IntegerType): + return sedona_udf_module.NUMERIC + elif isinstance(spark_type, StringType): + return sedona_udf_module.STRING + else: + raise NotImplementedError(f"Type {spark_type} is not supported yet.") + +def infer_input_types(spark_types: list[DataType]): + pa_types = [] + for spark_type in spark_types: + pa_type = infer_input_type(spark_type) + pa_types.append(pa_type) + + return pa_types + + +def sedona_db_vectorized_udf( + return_type: DataType, + input_types: list[DataType] +): + def apply_fn(fn): + out_type = infer_pa_type(return_type) + input_types_sedona_db = infer_input_types(input_types) + + @sedona_udf_module.arrow_udf(out_type, input_types=input_types_sedona_db) + def shapely_udf(*args, **kwargs): + return fn(*args, **kwargs) + + udf = UserDefinedFunction( + lambda: shapely_udf, return_type, "SedonaPandasArrowUDF", evalType=6200 + ) + + return udf + + + return apply_fn diff --git a/python/sedona/spark/worker/__init__.py b/python/sedona/spark/worker/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/python/sedona/spark/worker/serde.py b/python/sedona/spark/worker/serde.py new file mode 100644 index 0000000000..31038b7fcd --- /dev/null +++ b/python/sedona/spark/worker/serde.py @@ -0,0 +1,82 @@ +import socket + +from pyspark.serializers import write_int, SpecialLengths +from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer + +from sedona.spark.worker.udf_info import UDFInfo + + +def read_available(buf, chunk=4096): + # buf.raw._sock.settimeout(0.01) # non-blocking-ish + data = bytearray() + index = 0 + while True: + index+=1 + try: + chunk_bytes = buf.read(chunk) + except socket.timeout: + break + + if not chunk_bytes and index > 10: + break + + data.extend(chunk_bytes) + + return bytes(data) + +class SedonaDBSerializer(ArrowStreamPandasSerializer): + def __init__(self, timezone, safecheck, db, udf_info: UDFInfo): + super(SedonaDBSerializer, self).__init__(timezone, safecheck) + self.db = db + self.udf_info = udf_info + + def load_stream(self, stream): + import pyarrow as pa + + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + index = 0 + for batch in batches: + table = pa.Table.from_batches(batches=[batch]) + import pyarrow as pa + df = self.db.create_data_frame(table) + table_name = f"my_table_{index}" + + df.to_view(table_name) + + sql_expression = self.udf_info.sedona_db_transformation_expr(table_name) + + index += 1 + + yield self.db.sql(sql_expression) + + def arrow_dump_stream(self, iterator, stream): + import pyarrow as pa + + writer = None + try: + for batch in iterator: + if writer is None: + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + # stream.flush() + finally: + if writer is not None: + writer.close() + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. 
+ """ + + def init_stream_yield_batches(): + should_write_start_length = True + for batch in iterator: + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + return self.arrow_dump_stream(init_stream_yield_batches(), stream) diff --git a/python/sedona/spark/worker/udf_info.py b/python/sedona/spark/worker/udf_info.py new file mode 100644 index 0000000000..d354bcea7e --- /dev/null +++ b/python/sedona/spark/worker/udf_info.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass + +from sedona.spark import GeometryType + + +@dataclass +class UDFInfo: + arg_offsets: list + geom_offsets: dict + function: object + return_type: object + name: str + + def get_function_call_sql(self, table_name: str) -> str: + arg_offset_str = ", ".join([f"_{el}" for el in self.arg_offsets]) + function_expr = f"{self.name}({arg_offset_str})" + if isinstance(self.return_type, GeometryType): + return f"SELECT ST_GeomToSedonaSpark({function_expr}) AS _0 FROM {table_name}" + + return f"SELECT {function_expr} AS _0 FROM {table_name}" + + def sedona_db_transformation_expr(self, table_name: str) -> str: + fields = [] + for arg in self.arg_offsets: + if arg in self.geom_offsets: + crs = self.geom_offsets[arg] + fields.append(f"ST_GeomFromSedonaSpark(_{arg}, 'EPSG:{crs}') AS _{arg}") + continue + + fields.append(f"_{arg}") + + + fields_expr = ", ".join(fields) + return f"SELECT {fields_expr} FROM {table_name}" diff --git a/python/sedona/spark/worker/worker.py b/python/sedona/spark/worker/worker.py new file mode 100644 index 0000000000..74a61b02ee --- /dev/null +++ b/python/sedona/spark/worker/worker.py @@ -0,0 +1,295 @@ +import importlib +import os +import sys +import time + +import sedonadb +from pyspark import TaskContext, shuffle, SparkFiles +from pyspark.errors import PySparkRuntimeError +from pyspark.java_gateway import local_connect_and_auth +from pyspark.resource import ResourceInformation +from pyspark.serializers import read_int, UTF8Deserializer, read_bool, read_long, CPickleSerializer, write_int, \ + write_long, SpecialLengths + +from sedona.spark.worker.serde import SedonaDBSerializer +from sedona.spark.worker.udf_info import UDFInfo + + +def apply_iterator(db, iterator, udf_info: UDFInfo): + i = 0 + for df in iterator: + i+=1 + table_name = f"output_table_{i}" + df.to_view(table_name) + + function_call_sql = udf_info.get_function_call_sql(table_name) + + df_out = db.sql(function_call_sql) + df_out.to_view(f"view_{i}") + at = df_out.to_arrow_table() + batches = at.combine_chunks().to_batches() + + for batch in batches: + yield batch + + +def check_python_version(utf_serde: UTF8Deserializer, infile) -> str: + version = utf_serde.loads(infile) + + python_major, python_minor = sys.version_info[:2] + + if version != f"{python_major}.{python_minor}": + raise PySparkRuntimeError( + error_class="PYTHON_VERSION_MISMATCH", + message_parameters={ + "worker_version": str(sys.version_info[:2]), + "driver_version": str(version), + }, + ) + + return version + +def check_barrier_flag(infile): + is_barrier = read_bool(infile) + bound_port = read_int(infile) + secret = UTF8Deserializer().loads(infile) + + if is_barrier: + raise PySparkRuntimeError( + error_class="BARRIER_MODE_NOT_SUPPORTED", + message_parameters={ + "worker_version": str(sys.version_info[:2]), + "message": "Barrier mode is not supported by SedonaDB vectorized functions.", + }, + ) + + return is_barrier + +def assign_task_context(utf_serde: UTF8Deserializer, infile): + stage_id 
= read_int(infile) + partition_id = read_int(infile) + attempt_number = read_long(infile) + task_attempt_id = read_int(infile) + cpus = read_int(infile) + + task_context = TaskContext._getOrCreate() + task_context._stage_id = stage_id + task_context._partition_id = partition_id + task_context._attempt_number = attempt_number + task_context._task_attempt_id = task_attempt_id + task_context._cpus = cpus + + for r in range(read_int(infile)): + key = utf_serde.loads(infile) + name = utf_serde.loads(infile) + addresses = [] + task_context._resources = {} + for a in range(read_int(infile)): + addresses.append(utf_serde.loads(infile)) + task_context._resources[key] = ResourceInformation(name, addresses) + + task_context._localProperties = dict() + for i in range(read_int(infile)): + k = utf_serde.loads(infile) + v = utf_serde.loads(infile) + task_context._localProperties[k] = v + + return task_context + +def resolve_python_path(utf_serde: UTF8Deserializer, infile): + def add_path(path: str): + # worker can be used, so do not add path multiple times + if path not in sys.path: + # overwrite system packages + sys.path.insert(1, path) + + spark_files_dir = utf_serde.loads(infile) + # _accumulatorRegistry.clear() + + SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True + + add_path(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) + for _ in range(num_python_includes): + filename = utf_serde.loads(infile) + add_path(os.path.join(spark_files_dir, filename)) + + importlib.invalidate_caches() + + +def check_broadcast_variables(infile): + needs_broadcast_decryption_server = read_bool(infile) + num_broadcast_variables = read_int(infile) + + if needs_broadcast_decryption_server or num_broadcast_variables > 0: + raise PySparkRuntimeError( + error_class="BROADCAST_VARS_NOT_SUPPORTED", + message_parameters={ + "worker_version": str(sys.version_info[:2]), + "message": "Broadcast variables are not supported by SedonaDB vectorized functions.", + }, + ) + +def get_runner_conf(utf_serde: UTF8Deserializer, infile): + runner_conf = {} + num_conf = read_int(infile) + for i in range(num_conf): + k = utf_serde.loads(infile) + v = utf_serde.loads(infile) + runner_conf[k] = v + return runner_conf + + +def read_command(serializer, infile): + command = serializer._read_with_length(infile) + return command + +def read_udf(infile, pickle_ser) -> UDFInfo: + num_arg = read_int(infile) + arg_offsets = [read_int(infile) for i in range(num_arg)] + + function = None + return_type = None + + for i in range(read_int(infile)): + function, return_type = read_command(pickle_ser, infile) + + sedona_db_udf_expression = function() + + return UDFInfo( + arg_offsets=arg_offsets, + function=sedona_db_udf_expression, + return_type=return_type, + name=sedona_db_udf_expression._name, + geom_offsets=[0] + ) + +# def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): +# num_arg = read_int(infile) +# arg_offsets = [read_int(infile) for i in range(num_arg)] +# chained_func = None +# for i in range(read_int(infile)): +# f, return_type = read_command(pickleSer, infile) +# if chained_func is None: +# chained_func = f +# else: +# chained_func = chain(chained_func, f) +# +# func = chained_func +# +# # the last returnType will be the return type of UDF +# if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: +# return arg_offsets, func, return_type +# else: +# raise ValueError("Unknown eval type: {}".format(eval_type)) +# + +def 
register_sedona_db_udf(infile, pickle_ser) -> UDFInfo: + num_udfs = read_int(infile) + + udf = None + for _ in range(num_udfs): + udf = read_udf(infile, pickle_ser) + # Here we would register the UDF with SedonaDB's internal context + + + return udf + + +def report_times(outfile, boot, init, finish): + write_int(SpecialLengths.TIMING_DATA, outfile) + write_long(int(1000 * boot), outfile) + write_long(int(1000 * init), outfile) + write_long(int(1000 * finish), outfile) + + +def write_statistics(infile, outfile, boot_time, init_time) -> None: + TaskContext._setTaskContext(None) + finish_time = time.time() + report_times(outfile, boot_time, init_time, finish_time) + write_long(shuffle.MemoryBytesSpilled, outfile) + write_long(shuffle.DiskBytesSpilled, outfile) + + # Mark the beginning of the accumulators section of the output + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + # write_int(len(_accumulatorRegistry), outfile) + # for (aid, accum) in _accumulatorRegistry.items(): + # pickleSer._write_with_length((aid, accum._value), outfile) + + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + outfile.flush() + else: + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + outfile.flush() + sys.exit(-1) + + +def main(infile, outfile): + boot_time = time.time() + sedona_db = sedonadb.connect() + # + utf8_deserializer = UTF8Deserializer() + pickle_ser = CPickleSerializer() + + split_index = read_int(infile) + # + check_python_version(utf8_deserializer, infile) + # + check_barrier_flag(infile) + + task_context = assign_task_context(utf_serde=utf8_deserializer, infile=infile) + shuffle.MemoryBytesSpilled = 0 + shuffle.DiskBytesSpilled = 0 + + resolve_python_path(utf8_deserializer, infile) + # + check_broadcast_variables(infile) + + eval_type = read_int(infile) + + runner_conf = get_runner_conf(utf8_deserializer, infile) + + udf = register_sedona_db_udf(infile, pickle_ser) + + sedona_db.register_udf(udf.function) + init_time = time.time() + + serde = SedonaDBSerializer( + timezone=runner_conf.get("spark.sql.session.timeZone", "UTC"), + safecheck=False, + db=sedona_db, + udf_info=udf + ) + + number_of_geometries = read_int(infile) + geom_offsets = {} + for i in range(number_of_geometries): + geom_index = read_int(infile) + geom_srid = read_int(infile) + + geom_offsets[geom_index] = geom_srid + + udf.geom_offsets = geom_offsets + + iterator = serde.load_stream(infile) + out_iterator = apply_iterator(db=sedona_db, iterator=iterator, udf_info=udf) + + serde.dump_stream(out_iterator, outfile) + + write_statistics( + infile, outfile, boot_time=boot_time, init_time=init_time + ) + + +if __name__ == "__main__": + # add file handler + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + (sock_file, sc) = local_connect_and_auth(java_port, auth_secret) + + write_int(os.getpid(), sock_file) + sock_file.flush() + + main(sock_file, sock_file) diff --git a/python/tests/test_base.py b/python/tests/test_base.py index cc2b09e422..a6dbae6597 100644 --- a/python/tests/test_base.py +++ b/python/tests/test_base.py @@ -22,7 +22,7 @@ from typing import Iterable, Union import pyspark from pyspark.sql import DataFrame -from sedona.spark import * +from sedona.spark import SedonaContext from sedona.spark.utils.decorators import classproperty SPARK_REMOTE = os.getenv("SPARK_REMOTE") diff --git a/python/tests/utils/test_sedona_db_vectorized_udf.py b/python/tests/utils/test_sedona_db_vectorized_udf.py 
new file mode 100644 index 0000000000..749d45420e --- /dev/null +++ b/python/tests/utils/test_sedona_db_vectorized_udf.py @@ -0,0 +1,94 @@ +from sedona.spark.sql.functions import sedona_db_vectorized_udf +from tests.test_base import TestBase +import pyarrow as pa +import shapely +from sedona.sql import GeometryType +from pyspark.sql.functions import expr, lit +from pyspark.sql.types import DoubleType, IntegerType + + +class TestSedonaDBArrowFunction(TestBase): + def test_vectorized_udf(self): + @sedona_db_vectorized_udf(return_type=GeometryType(), input_types=[GeometryType(), IntegerType()]) + def my_own_function(geom, distance): + geom_wkb = pa.array(geom.storage.to_array()) + distance = pa.array(distance.to_array()) + geom = shapely.from_wkb(geom_wkb) + + result_shapely = shapely.centroid(geom) + + return pa.array(shapely.to_wkb(result_shapely)) + + df = self.spark.createDataFrame( + [ + (1, "POINT (1 1)"), + (2, "POINT (2 2)"), + (3, "POINT (3 3)"), + ], + ["id", "wkt"], + ).withColumn("wkt", expr("ST_GeomFromWKT(wkt)")) + + df.select(my_own_function(df.wkt, lit(100)).alias("geom")).show() + + def test_geometry_to_double(self): + @sedona_db_vectorized_udf(return_type=DoubleType(), input_types=[GeometryType()]) + def geometry_to_non_geometry_udf(geom): + geom_wkb = pa.array(geom.storage.to_array()) + geom = shapely.from_wkb(geom_wkb) + + result_shapely = shapely.get_x(shapely.centroid(geom)) + + return pa.array(result_shapely, pa.float64()) + + df = self.spark.createDataFrame( + [(1, "POINT (1 1)"), (2, "POINT (2 2)"), (3, "POINT (3 3)")], + ["id", "wkt"], + ).withColumn("wkt", expr("ST_GeomFromWKT(wkt)")) + + values = df.select(geometry_to_non_geometry_udf(df.wkt).alias("x_coord")) \ + .collect() + + values_list = [row["x_coord"] for row in values] + + assert values_list == [1.0, 2.0, 3.0] + + def test_geometry_to_int(self): + @sedona_db_vectorized_udf(return_type=IntegerType(), input_types=[GeometryType()]) + def geometry_to_int(geom): + geom_wkb = pa.array(geom.storage.to_array()) + geom = shapely.from_wkb(geom_wkb) + + result_shapely = shapely.get_num_points(geom) + + return pa.array(result_shapely, pa.int32()) + + df = self.spark.createDataFrame( + [(1, "POINT (1 1)"), (2, "POINT (2 2)"), (3, "POINT (3 3)")], + ["id", "wkt"], + ).withColumn("wkt", expr("ST_GeomFromWKT(wkt)")) + + values = df.select(geometry_to_int(df.wkt)) \ + .collect() + + values_list = [row[0] for row in values] + + assert values_list == [0, 0, 0] + + def test_geometry_crs_preservation(self): + @sedona_db_vectorized_udf(return_type=GeometryType(), input_types=[GeometryType()]) + def return_same_geometry(geom): + geom_wkb = pa.array(geom.storage.to_array()) + geom = shapely.from_wkb(geom_wkb) + + return pa.array(shapely.to_wkb(geom)) + + df = self.spark.createDataFrame( + [(1, "POINT (1 1)"), (2, "POINT (2 2)"), (3, "POINT (3 3)")], + ["id", "wkt"], + ).withColumn("wkt", expr("ST_SetSRID(ST_GeomFromWKT(wkt), 3857)")) + + result_df = df.select(return_same_geometry(df.wkt).alias("geom")) + + crs_list = result_df.selectExpr("ST_SRID(geom)").rdd.flatMap(lambda x: x).collect() + + assert crs_list == [3857, 3857, 3857] diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala index b0e46cf6e9..add3caf225 100644 --- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala +++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala @@ -41,7 +41,6 @@ class InternalApi( extends 
StaticAnnotation object SedonaContext { - private def customOptimizationsWithSession(sparkSession: SparkSession) = Seq( new TransformNestedUDTParquet(sparkSession), @@ -72,7 +71,7 @@ object SedonaContext { val sedonaArrowStrategy = Try( Class - .forName("org.apache.spark.sql.udf.SedonaArrowStrategy") + .forName("org.apache.spark.sql.execution.python.SedonaArrowStrategy") .getDeclaredConstructor() .newInstance() .asInstanceOf[SparkStrategy]) diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala index aece26267d..11263dd7f6 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala @@ -23,7 +23,14 @@ object PythonEvalType { val SQL_SCALAR_SEDONA_UDF = 5200 val SEDONA_UDF_TYPE_CONSTANT = 5000 + // sedona db eval types + val SQL_SCALAR_SEDONA_DB_UDF = 6200 + val SEDONA_DB_UDF_TYPE_CONSTANT = 6000 + def toString(pythonEvalType: Int): String = pythonEvalType match { case SQL_SCALAR_SEDONA_UDF => "SQL_SCALAR_GEO_UDF" + case SQL_SCALAR_SEDONA_DB_UDF => "SQL_SCALAR_SEDONA_DB_UDF" } + + def evals(): Set[Int] = Set(SQL_SCALAR_SEDONA_UDF, SQL_SCALAR_SEDONA_DB_UDF) } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala new file mode 100644 index 0000000000..0d3960d2d8 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala @@ -0,0 +1,58 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.api.python._ +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.types._ +import org.apache.spark.sql.vectorized.ColumnarBatch + +/** + * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. 
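+ * It also forwards the (column index, SRID) pairs in `geometryFields` so the SedonaDB worker can rebuild CRS-aware geometry columns.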
+ */ +class SedonaArrowPythonRunner( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + protected override val schema: StructType, + protected override val timeZoneId: String, + protected override val largeVarTypes: Boolean, + protected override val workerConf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String], + geometryFields: Seq[(Int, Int)]) + extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + funcs, + evalType, + argOffsets, + jobArtifactUUID, + geometryFields) + with SedonaBasicPythonArrowInput + with SedonaBasicPythonArrowOutput { + + override val errorOnDuplicatedFieldNames: Boolean = true + + override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize + require( + bufferSize >= 4, + "Pandas execution requires more than 4 bytes. Please set higher buffer. " + + s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala new file mode 100644 index 0000000000..bb897931b6 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala @@ -0,0 +1,159 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.sedona.sql.UDF.PythonEvalType.{SQL_SCALAR_SEDONA_DB_UDF, SQL_SCALAR_SEDONA_UDF} +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.udf.SedonaArrowEvalPython +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT + +import scala.collection.JavaConverters.asScalaIteratorConverter + +// We use custom Strategy to avoid Apache Spark assert on types, we +// can consider extending this to support other engines working with +// arrow data +class SedonaArrowStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case SedonaArrowEvalPython(udfs, output, child, evalType) => + SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil + case _ => Nil + } +} + +// It's modification og Apache Spark's ArrowEvalPythonExec, we remove the check on the types to allow geometry types +// here, it's initial version to allow the vectorized udf for Sedona geometry types. We can consider extending this +// to support other engines working with arrow data +case class SedonaArrowEvalPythonExec( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: SparkPlan, + evalType: Int) + extends EvalPythonExec + with PythonSQLMetrics { + + private val batchSize = conf.arrowMaxRecordsPerBatch + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val largeVarTypes = conf.arrowUseLargeVarTypes + private val pythonRunnerConf = + Map[String, String](SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + private def inferCRS(iterator: Iterator[InternalRow], schema: StructType): Seq[(Int, Int)] = { + // this triggers the iterator + if (!iterator.hasNext) { + return Seq.empty + } + + val row = iterator.next() + + val rowMatched = row match { + case generic: GenericInternalRow => + Some(generic) + case _ => None + } + + schema + .filter { field => + field.dataType == GeometryUDT + } + .zipWithIndex + .map { case (_, index) => + if (rowMatched.isEmpty || rowMatched.get.values(index) == null) (index, 0) + else { + val geom = rowMatched.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] + val preambleByte = geom(0) & 0xff + val hasSrid = (preambleByte & 0x01) != 0 + + var srid = 0 + if (hasSrid) { + val srid2 = (geom(1) & 0xff) << 16 + val srid1 = (geom(2) & 0xff) << 8 + val srid0 = geom(3) & 0xff + srid = srid2 | srid1 | srid0 + } + + (index, srid) + } + } + } + + protected override def evaluate( + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + val (probe, full) = iter.duplicate + + val geometryFields = inferCRS(probe, schema) + + val batchIter = if (batchSize > 0) new BatchIterator(full, batchSize) else Iterator(full) + + evalType match { + case SQL_SCALAR_SEDONA_DB_UDF => + val columnarBatchIter = new SedonaArrowPythonRunner( 
+ funcs, + evalType - PythonEvalType.SEDONA_DB_UDF_TYPE_CONSTANT, + argOffsets, + schema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID, + geometryFields).compute(batchIter, context.partitionId(), context) + + val result = columnarBatchIter.flatMap { batch => + batch.rowIterator.asScala + } + + result + + case SQL_SCALAR_SEDONA_UDF => + val columnarBatchIter = new ArrowPythonRunner( + funcs, + evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, + argOffsets, + schema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID).compute(batchIter, context.partitionId(), context) + + val iter = columnarBatchIter.flatMap { batch => + batch.rowIterator.asScala + } + + iter + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala new file mode 100644 index 0000000000..8ecc110e39 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala @@ -0,0 +1,121 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import java.net._ +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ +import org.apache.spark._ +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.EXECUTOR_CORES +import org.apache.spark.internal.config.Python._ +import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} +import org.apache.spark.util._ + +private object SedonaBasePythonRunner { + + private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler") +} + +private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + jobArtifactUUID: Option[String], + val geometryFields: Seq[(Int, Int)] = Seq.empty) + extends BasePythonRunner[IN, OUT](funcs, evalType, argOffsets, jobArtifactUUID) + with Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + private val conf = SparkEnv.get.conf + private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) + + private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = { + mem.map(_ / cores) + } + + import java.io._ + + override def compute( + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): Iterator[OUT] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + + val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY)) + val memoryMb = Option(context.getLocalProperty(PYSPARK_MEMORY_LOCAL_PROPERTY)).map(_.toLong) + + if (simplifiedTraceback) { + envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") + } + // SPARK-30299 this could be wrong with standalone mode when executor + // cores might not be correct because it defaults to all cores on the box. 
+ val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES)) + val workerMemoryMb = getWorkerMemoryMb(memoryMb, execCores) + if (workerMemoryMb.isDefined) { + envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString) + } + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) + envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) + if (faultHandlerEnabled) { + envVars.put("PYTHON_FAULTHANDLER_DIR", SedonaBasePythonRunner.faultHandlerLogDir.toString) + } + + envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) + + val (worker: Socket, pid: Option[Int]) = { + WorkerContext.createPythonWorker(pythonExec, envVars.asScala.toMap) + } + + val releasedOrClosed = new AtomicBoolean(false) + + // Start a thread to feed the process input from our parent's iterator + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + + context.addTaskCompletionListener[Unit] { _ => + writerThread.shutdownOnTaskCompletion() + if (releasedOrClosed.compareAndSet(false, true)) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val stdoutIterator = newReaderIterator( + stream, + writerThread, + startTime, + env, + worker, + pid, + releasedOrClosed, + context) + new InterruptibleIterator(context, stdoutIterator) + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala new file mode 100644 index 0000000000..add09a7cb2 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala @@ -0,0 +1,118 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.{SparkException, SparkFiles} +import org.apache.spark.api.python.{PythonUtils, PythonWorkerFactory} +import org.apache.spark.util.Utils + +import java.io.{DataInputStream, File} +import java.net.{InetAddress, ServerSocket, Socket} +import java.util.Arrays +import java.io.InputStream +import scala.collection.JavaConverters._ +import scala.collection.mutable +import org.apache.spark._ +import org.apache.spark.security.SocketAuthHelper +import org.apache.spark.util.RedirectThread + +class SedonaDBWorkerFactory(pythonExec: String, envVars: Map[String, String]) + extends PythonWorkerFactory(pythonExec, envVars) { + self => + + private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]() + private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + + private val sedonaUDFWorkerModule = + SparkEnv.get.conf.get("sedona.python.worker.udf.module", "sedona.spark.worker.worker") + + private val pythonPath = PythonUtils.mergePythonPaths( + PythonUtils.sparkPythonPath, + envVars.getOrElse("PYTHONPATH", ""), + sys.env.getOrElse("PYTHONPATH", "")) + + override def create(): (Socket, Option[Int]) = { + createSimpleWorker(sedonaUDFWorkerModule) + } + + private def createSimpleWorker(workerModule: String): (Socket, Option[Int]) = { + var serverSocket: ServerSocket = null + try { + serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress()) + + // Create and start the worker + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) + val jobArtifactUUID = envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default") + if (jobArtifactUUID != "default") { + val f = new File(SparkFiles.getRootDirectory(), jobArtifactUUID) + f.mkdir() + pb.directory(f) + } + val workerEnv = pb.environment() + workerEnv.putAll(envVars.asJava) + workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support -u: + workerEnv.put("PYTHONUNBUFFERED", "YES") + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) + if (Utils.preferIPv6) { + workerEnv.put("SPARK_PREFER_IPV6", "True") + } + val worker = pb.start() + + // Redirect worker stdout and stderr + redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) + + // Wait for it to connect to our socket, and validate the auth secret. 
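+      // The accept() below waits at most 10 seconds for the worker to call back before the connection attempt is reported as failed.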
+ serverSocket.setSoTimeout(10000) + + try { + val socket = serverSocket.accept() + authHelper.authClient(socket) + // TODO: When we drop JDK 8, we can just use worker.pid() + val pid = new DataInputStream(socket.getInputStream).readInt() + if (pid < 0) { + throw new IllegalStateException("Python failed to launch worker with code " + pid) + } + self.synchronized { + simpleWorkers.put(socket, worker) + } + + (socket, Some(pid)) + } catch { + case e: Exception => + throw new SparkException("Python worker failed to connect back.", e) + } + } finally { + if (serverSocket != null) { + serverSocket.close() + } + } + } + + private def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream): Unit = { + try { + new RedirectThread(stdout, System.err, "stdout reader for " + pythonExec).start() + new RedirectThread(stderr, System.err, "stderr reader for " + pythonExec).start() + } catch { + case e: Exception => + logError("Exception in redirecting streams", e) + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala new file mode 100644 index 0000000000..18db42ae0d --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -0,0 +1,135 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.python + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
+ */ + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.execution.arrow.ArrowWriter +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.util.Utils +import org.apache.spark.{SparkEnv, TaskContext} + +import java.io.DataOutputStream +import java.net.Socket + +private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { + self: SedonaBasePythonRunner[IN, _] => + protected override def newWriterThread( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread = { + new WriterThread(env, worker, inputIterator, partitionIndex, context) { + + protected override def writeCommand(dataOut: DataOutputStream): Unit = { + handleMetadataBeforeExec(dataOut) + writeUDF(dataOut, funcs, argOffsets) + + // write + dataOut.writeInt(self.geometryFields.length) + // write geometry field indices and their SRIDs + geometryFields.foreach { case (index, srid) => + dataOut.writeInt(index) + dataOut.writeInt(srid) + } + } + + protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { + val arrowSchema = + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdout writer for $pythonExec", + 0, + Long.MaxValue) + val root = VectorSchemaRoot.create(arrowSchema, allocator) + + Utils.tryWithSafeFinally { + val writer = new ArrowStreamWriter(root, null, dataOut) + writer.start() + + writeIteratorToArrowStream(root, writer, dataOut, inputIterator) + + // end writes footer to the output stream and doesn't clean any resources. + // It could throw exception if the output stream is closed, so it should be + // in the try block. + writer.end() + } { + // If we close root and allocator in TaskCompletionListener, there could be a race + // condition where the writer thread keeps writing to the VectorSchemaRoot while + // it's being closed by the TaskCompletion listener. + // Closing root and allocator here is cleaner because root and allocator is owned + // by the writer thread and is only visible to the writer thread. + // + // If the writer thread is interrupted by TaskCompletionListener, it should either + // (1) in the try block, in which case it will get an InterruptedException when + // performing io, and goes into the finally block or (2) in the finally block, + // in which case it will ignore the interruption and close the resources. 
+ root.close() + allocator.close() + } + } + } + } +} + +private[python] trait SedonaBasicPythonArrowInput + extends SedonaPythonArrowInput[Iterator[InternalRow]] { + self: SedonaBasePythonRunner[Iterator[InternalRow], _] => + + protected def writeIteratorToArrowStream( + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[Iterator[InternalRow]]): Unit = { + val arrowWriter = ArrowWriter.create(root) + while (inputIterator.hasNext) { + val startData = dataOut.size() + val nextBatch = inputIterator.next() + + while (nextBatch.hasNext) { + arrowWriter.write(nextBatch.next()) + } + + arrowWriter.finish() + writer.writeBatch() + arrowWriter.reset() + val deltaData = dataOut.size() - startData + pythonMetrics("pythonDataSent") += deltaData + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala new file mode 100644 index 0000000000..a9421df0af --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala @@ -0,0 +1,171 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import java.io.DataInputStream +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamReader +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} + +private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => + + protected def pythonMetrics: Map[String, SQLMetric] + + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[OUT] = { + + new ReaderIterator( + stream, + writerThread, + startTime, + env, + worker, + pid, + releasedOrClosed, + context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", + 0, + Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + private var eos = false + private var nextObj: OUT = _ + + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + def handleEndOfDataSectionSedona(): Unit = { + if (stream.readInt() == SpecialLengths.END_OF_STREAM) {} + + eos = true + } + + protected override def handleEndOfDataSection(): Unit = { + handleEndOfDataSectionSedona() + } + + override def hasNext: Boolean = nextObj != null || { + if (!eos) { + nextObj = read() + hasNext + } else { + false + } + } + + override def next(): OUT = { + if (hasNext) { + val obj = nextObj + nextObj = null.asInstanceOf[OUT] + obj + } else { + Iterator.empty.next() + } + } + + protected override def read(): OUT = { + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && batchLoaded) { + val bytesReadStart = reader.bytesRead() + batchLoaded = reader.loadNextBatch() + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + val rowCount = root.getRowCount + batch.setNumRows(root.getRowCount) + val bytesReadEnd = reader.bytesRead() + pythonMetrics("pythonNumRowsReceived") += rowCount + pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart + deserializeColumnarBatch(batch, schema) + } else { + reader.close(false) + allocator.close() + read() + } + } else { + val specialSign = stream.readInt() + + specialSign match { + case SpecialLengths.START_ARROW_STREAM => + reader = new ArrowStreamReader(stream, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root + .getFieldVectors() + .asScala + .map { vector => + new ArrowColumnVector(vector) + } + .toArray[ColumnVector] + + read() + case SpecialLengths.TIMING_DATA => + handleTimingData() + read() + case SpecialLengths.PYTHON_EXCEPTION_THROWN => + throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + 
handleEndOfDataSection() + null.asInstanceOf[OUT] + } + } + } catch handleException + } + } + } +} + +private[python] trait SedonaBasicPythonArrowOutput + extends SedonaPythonArrowOutput[ColumnarBatch] { + self: BasePythonRunner[_, ColumnarBatch] => + + protected def deserializeColumnarBatch( + batch: ColumnarBatch, + schema: StructType): ColumnarBatch = batch +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala new file mode 100644 index 0000000000..dbad8358d6 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala @@ -0,0 +1,52 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.python + +import java.net.Socket +import scala.collection.mutable + +object WorkerContext { + + def createPythonWorker( + pythonExec: String, + envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.getOrElseUpdate(key, new SedonaDBWorkerFactory(pythonExec, envVars)).create() + } + } + + private[spark] def destroyPythonWorker( + pythonExec: String, + envVars: Map[String, String], + worker: Socket): Unit = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers + .get(key) + .foreach(workerFactory => { + workerFactory.stopWorker(worker) + }) + } + } + + private val pythonWorkers = + mutable.HashMap[(String, Map[String, String]), SedonaDBWorkerFactory]() + +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala index 3d3301580c..ebb5a568e1 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala @@ -44,9 +44,7 @@ class ExtractSedonaUDFRule extends Rule[LogicalPlan] with Logging { } def isScalarPythonUDF(e: Expression): Boolean = { - e.isInstanceOf[PythonUDF] && e - .asInstanceOf[PythonUDF] - .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF + e.isInstanceOf[PythonUDF] && PythonEvalType.evals.contains(e.asInstanceOf[PythonUDF].evalType) } private def collectEvaluableUDFsFromExpressions( @@ -168,13 +166,12 @@ class ExtractSedonaUDFRule extends Rule[LogicalPlan] with Logging { evalTypes.mkString(",")) } val evalType = evalTypes.head - val evaluation = evalType match { - case PythonEvalType.SQL_SCALAR_SEDONA_UDF => - SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType) - case _ => - throw new IllegalStateException("Unexpected UDF evalType") + if 
(!PythonEvalType.evals().contains(evalType)) { + throw new IllegalStateException(s"Unexpected UDF evalType: $evalType") } + val evaluation = SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType) + attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs) evaluation } else { diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala deleted file mode 100644 index a403fa6b9e..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala +++ /dev/null @@ -1,89 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.udf - -import org.apache.sedona.sql.UDF.PythonEvalType -import org.apache.spark.api.python.ChainedPythonFunctions -import org.apache.spark.{JobArtifactSet, TaskContext} -import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, EvalPythonExec, PythonSQLMetrics} -import org.apache.spark.sql.types.StructType - -import scala.collection.JavaConverters.asScalaIteratorConverter - -// We use custom Strategy to avoid Apache Spark assert on types, we -// can consider extending this to support other engines working with -// arrow data -class SedonaArrowStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case SedonaArrowEvalPython(udfs, output, child, evalType) => - SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil - case _ => Nil - } -} - -// It's modification og Apache Spark's ArrowEvalPythonExec, we remove the check on the types to allow geometry types -// here, it's initial version to allow the vectorized udf for Sedona geometry types. 
We can consider extending this -// to support other engines working with arrow data -case class SedonaArrowEvalPythonExec( - udfs: Seq[PythonUDF], - resultAttrs: Seq[Attribute], - child: SparkPlan, - evalType: Int) - extends EvalPythonExec - with PythonSQLMetrics { - - private val batchSize = conf.arrowMaxRecordsPerBatch - private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val largeVarTypes = conf.arrowUseLargeVarTypes - private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) - private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) - - protected override def evaluate( - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]], - iter: Iterator[InternalRow], - schema: StructType, - context: TaskContext): Iterator[InternalRow] = { - - val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) - - val columnarBatchIter = new ArrowPythonRunner( - funcs, - evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, - argOffsets, - schema, - sessionLocalTimeZone, - largeVarTypes, - pythonRunnerConf, - pythonMetrics, - jobArtifactUUID).compute(batchIter, context.partitionId(), context) - - columnarBatchIter.flatMap { batch => - batch.rowIterator.asScala - } - } - - override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = - copy(child = newChild) -} diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index 28943ff11d..e0b81c5e47 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -46,6 +46,7 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { // We need to be explicit about broadcasting in tests. 
.config("sedona.join.autoBroadcastJoinThreshold", "-1") .config("spark.sql.extensions", "org.apache.sedona.sql.SedonaSqlExtensions") + .config("sedona.python.worker.udf.module", "sedonaworker.worker") .config(keyParserExtension, ThreadLocalRandom.current().nextBoolean()) .getOrCreate() diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index 8d41848de9..7719b2199c 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.udf import org.apache.sedona.sql.TestBaseScala import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions.col -import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction +import org.apache.spark.sql.functions.{col, expr, lit} +import org.apache.spark.sql.udf.ScalarUDF.{geoPandasScalaFunction, sedonaDBGeometryToGeometryFunction} import org.locationtech.jts.io.WKTReader import org.scalatest.matchers.should.Matchers @@ -35,7 +35,8 @@ class StrategySuite extends TestBaseScala with Matchers { import spark.implicits._ - it("sedona geospatial UDF") { + + it("sedona geospatial UDF - geopandas") { val df = Seq( (1, "value", wktReader.read("POINT(21 52)")), (2, "value1", wktReader.read("POINT(20 50)")), @@ -43,11 +44,13 @@ class StrategySuite extends TestBaseScala with Matchers { (4, "value3", wktReader.read("POINT(20 48)")), (5, "value4", wktReader.read("POINT(20 47)"))) .toDF("id", "value", "geom") + + val geopandasUDFDF = df .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) - df.count shouldEqual 5 + geopandasUDFDF.count shouldEqual 5 - df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") + geopandasUDFDF.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") .as[String] .collect() should contain theSameElementsAs Seq( "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", @@ -56,4 +59,23 @@ class StrategySuite extends TestBaseScala with Matchers { "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") } + + it("sedona geospatial UDF - sedona db") { + val df = Seq( + (1, "value", wktReader.read("POINT(21 52)")), + (2, "value1", wktReader.read("POINT(20 50)")), + (3, "value2", wktReader.read("POINT(20 49)")), + (4, "value3", wktReader.read("POINT(20 48)")), + (5, "value4", wktReader.read("POINT(20 47)"))) + .toDF("id", "value", "geom") + + val dfVectorized = df + .withColumn("geometry", expr("ST_SetSRID(geom, '4326')")) + .select(sedonaDBGeometryToGeometryFunction(col("geometry"), lit(100)).alias("geom")) + + dfVectorized.selectExpr("ST_X(ST_Centroid(geom)) AS x") + .selectExpr("sum(x)") + .as[Double] + .collect().head shouldEqual 101 + } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index c0a2d8f260..23aac14bbe 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -19,11 +19,13 @@ package org.apache.spark.sql.udf import org.apache.sedona.sql.UDF -import org.apache.spark.TestUtils +import org.apache.spark.{SparkEnv, TestUtils} import org.apache.spark.api.python._ import org.apache.spark.broadcast.Broadcast +import 
org.apache.spark.internal.config.Python.{PYTHON_USE_DAEMON, PYTHON_WORKER_MODULE} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.FloatType import org.apache.spark.util.Utils import java.io.File @@ -43,6 +45,9 @@ object ScalarUDF { } } + SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false) + SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.work") + private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "") protected lazy val sparkHome: String = { sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) @@ -54,7 +59,7 @@ object ScalarUDF { private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec) - lazy val pythonVer: String = if (isPythonAvailable) { + val pythonVer: String = if (isPythonAvailable) { Process( Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"), None, @@ -70,31 +75,85 @@ object ScalarUDF { finally Utils.deleteRecursively(path) } - val pandasFunc: Array[Byte] = { + val additionalModule = "spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf" + + val vectorizedFunction: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + f""" + |from pyspark.sql.types import FloatType + |from pyspark.serializers import CloudPickleSerializer + |f = open('$path', 'wb'); + | + |def apply_function_on_number(x): + | return x + 1.0 + |f.write(CloudPickleSerializer().dumps((apply_function_on_number, FloatType()))) + |""".stripMargin), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! + binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } + + val sedonaDBGeometryToGeometryFunctionBytes: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + f""" + |import pyarrow as pa + |import shapely + |import geoarrow.pyarrow as ga + |from sedonadb import udf + |from sedona.sql.types import GeometryType + |from pyspark.serializers import CloudPickleSerializer + |from pyspark.sql.types import DoubleType, IntegerType + |from sedonadb import udf as sedona_udf_module + | + |@sedona_udf_module.arrow_udf(ga.wkb(), [udf.GEOMETRY, udf.NUMERIC]) + |def geometry_udf(geom, distance): + | geom_wkb = pa.array(geom.storage.to_array()) + | distance = pa.array(distance.to_array()) + | geom = shapely.from_wkb(geom_wkb) + | result_shapely = shapely.buffer(geom, distance) + | + | return pa.array(shapely.to_wkb(result_shapely)) + | + |f = open('$path', 'wb'); + |f.write(CloudPickleSerializer().dumps((lambda: geometry_udf, GeometryType()))) + |""".stripMargin), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! 
+ binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } + + val geopandasNonGeometryToGeometryFunction: Array[Byte] = { var binaryPandasFunc: Array[Byte] = null withTempPath { path => - println(path) Process( Seq( pythonExec, "-c", f""" - |from pyspark.sql.types import IntegerType - |from shapely.geometry import Point - |from sedona.sql.types import GeometryType - |from pyspark.serializers import CloudPickleSerializer - |from sedona.utils import geometry_serde - |from shapely import box - |f = open('$path', 'wb'); - |def w(x): - | def apply_function(w): - | geom, offset = geometry_serde.deserialize(w) - | bounds = geom.buffer(1).bounds - | x = box(*bounds) - | return geometry_serde.serialize(x) - | return x.apply(apply_function) - |f.write(CloudPickleSerializer().dumps((w, GeometryType()))) - |""".stripMargin), + |from sedona.sql.types import GeometryType + |from shapely.wkt import loads + |from pyspark.serializers import CloudPickleSerializer + |f = open('$path', 'wb'); + |def apply_geopandas(x): + | return x.apply(lambda wkt: loads(wkt).buffer(1)) + |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType()))) + |""".stripMargin), None, "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! binaryPandasFunc = Files.readAllBytes(path.toPath) @@ -104,7 +163,39 @@ object ScalarUDF { } private val workerEnv = new java.util.HashMap[String, String]() - workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") + + val pandasFunc: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + println(path) + Process( + Seq( + pythonExec, + "-c", + f""" + |from pyspark.sql.types import IntegerType + |from shapely.geometry import Point + |from sedona.sql.types import GeometryType + |from pyspark.serializers import CloudPickleSerializer + |from sedona.utils import geometry_serde + |from shapely import box + |f = open('$path', 'wb'); + |def w(x): + | def apply_function(w): + | geom, offset = geometry_serde.deserialize(w) + | bounds = geom.buffer(1).bounds + | x = box(*bounds) + | return geometry_serde.serialize(x) + | return x.apply(apply_function) + |f.write(CloudPickleSerializer().dumps((w, GeometryType()))) + |""".stripMargin), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! 
+ binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( name = "geospatial_udf", @@ -119,4 +210,33 @@ object ScalarUDF { dataType = GeometryUDT, pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, udfDeterministic = true) + + val nonGeometryVectorizedUDF: UserDefinedPythonFunction = UserDefinedPythonFunction( + name = "vectorized_udf", + func = SimplePythonFunction( + command = vectorizedFunction, + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = FloatType, + pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, + udfDeterministic = false) + + val sedonaDBGeometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( + name = "geospatial_udf", + func = SimplePythonFunction( + command = sedonaDBGeometryToGeometryFunctionBytes, + envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], + pythonIncludes = List.empty[String].asJava, + pythonExec = pythonExec, + pythonVer = pythonVer, + broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, + accumulator = null), + dataType = GeometryUDT, + pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_DB_UDF, + udfDeterministic = true) + }
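
Note (illustrative only, not part of the patch): the sedonadb Arrow UDF that TestScalarPandasUDF pickles inline above is buried in a Scala string literal, so a standalone, commented sketch of that worker-side function is reproduced below. It assumes exactly the APIs the test already relies on (geoarrow.pyarrow WKB typing and sedonadb.udf.arrow_udf) and mirrors the test's logic rather than defining anything new.

    # Sketch of the worker-side UDF the test ships as CloudPickle bytes.
    import pyarrow as pa
    import shapely
    import geoarrow.pyarrow as ga
    from sedonadb import udf

    # Declares a UDF returning geoarrow WKB, taking a geometry column and a
    # numeric buffer distance, as in the test's embedded script.
    @udf.arrow_udf(ga.wkb(), [udf.GEOMETRY, udf.NUMERIC])
    def geometry_udf(geom, distance):
        # Unwrap the geoarrow-typed column into a plain pyarrow WKB array.
        geom_wkb = pa.array(geom.storage.to_array())
        distance = pa.array(distance.to_array())
        # Decode to shapely geometries, buffer by the per-row distance, re-encode to WKB.
        geoms = shapely.from_wkb(geom_wkb)
        result_shapely = shapely.buffer(geoms, distance)
        return pa.array(shapely.to_wkb(result_shapely))

On the driver side, StrategySuite exercises this function through sedonaDBGeometryToGeometryFunction(col("geometry"), lit(100)), which ships the pickled bytes to the new SedonaDB worker path under eval type SQL_SCALAR_SEDONA_DB_UDF.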
