This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit b9db627919fd45eda2341fff4d01ad3e9beb91e4
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Jan 4 00:57:30 2026 +0100

    add sedonadb sedona udf worker example
---
 python/pyproject.toml | 52 ++--
 python/sedona/spark/sql/functions.py | 64 ++++-
 .../sedona/spark/worker/__init__.py | 0
 python/sedona/spark/worker/serde.py | 82 ++++++
 python/sedona/spark/worker/udf_info.py | 34 +++
 python/sedona/spark/worker/worker.py | 295 +++++++++++++++++++++
 python/tests/test_base.py | 2 +-
 .../tests/utils/test_sedona_db_vectorized_udf.py | 94 +++++++
 .../sedona/sql/utils/GeometrySerializer.scala | 3 +-
 .../execution/python/SedonaArrowPythonRunner.scala | 6 +-
 .../sql/execution/python/SedonaArrowStrategy.scala | 59 ++++-
 .../execution/python/SedonaBasePythonRunner.scala | 5 +-
 .../execution/python/SedonaDBWorkerFactory.scala | 3 +-
 .../execution/python/SedonaPythonArrowInput.scala | 62 +----
 .../execution/python/SedonaPythonArrowOutput.scala | 86 +++---
 .../spark/sql/execution/python/WorkerContext.scala | 8 +-
 .../apache/sedona/sql/GeoPackageReaderTest.scala | 93 -------
 .../org/apache/sedona/sql/TestBaseScala.scala | 1 -
 .../org/apache/spark/sql/udf/StrategySuite.scala | 179 ++-----------
 .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 34 ++-
 20 files changed, 749 insertions(+), 413 deletions(-)

diff --git a/python/pyproject.toml b/python/pyproject.toml
index 9b8ef8a585..29c7d0c388 100644
--- a/python/pyproject.toml
+++ b/python/pyproject.toml
@@ -16,7 +16,7 @@
 # under the License.
 
 [build-system]
-requires = ["setuptools>=69", "wheel"]
+requires = ["setuptools>=80.9.0", "wheel"]
 build-backend = "setuptools.build_meta"
 
 [project]
@@ -26,13 +26,19 @@ description = "Apache Sedona is a cluster computing system for processing large-
 readme = "README.md"
 license = { text = "Apache-2.0" }
 authors = [ { name = "Apache Sedona", email = "[email protected]" } ]
-requires-python = ">=3.8"
+requires-python = ">=3.12"
 classifiers = [
     "Programming Language :: Python :: 3",
     "License :: OSI Approved :: Apache Software License",
 ]
 dependencies = [
     "attrs",
+    "geoarrow-c>=0.3.1",
+    "geoarrow-pyarrow>=0.2.0",
+    "geopandas>=1.1.2",
+    "pyarrow>=16.1.0",
+    "pyspark==3.5.4",
+    "sedonadb",
     "shapely>=1.7.0",
 ]
 
@@ -43,38 +49,16 @@ kepler-map = ["geopandas", "keplergl==0.3.2"]
 flink = ["apache-flink>=1.19.0"]
 db = ["sedonadb[geopandas]; python_version >= '3.9'"]
 all = [
-    "pyspark>=3.4.0,<4.1.0",
-    "geopandas",
-    "pydeck==0.8.0",
-    "keplergl==0.3.2",
-    "rasterio>=1.2.10",
+#    "pyspark>=3.4.0,<4.1.0",
+#    "geopandas",
+#    "pydeck==0.8.0",
+#    "keplergl==0.3.2",
+#    "rasterio>=1.2.10",
 ]
 
 [dependency-groups]
 dev = [
-    "pytest",
-    "pytest-cov",
-    "notebook==6.4.12",
-    "jupyter",
-    "mkdocs",
-    "scikit-learn",
-    "esda",
-    "libpysal",
-    "matplotlib", # implicit dependency of esda
-    # prevent incompatibility with pysal 4.7.0, which is what is resolved to when shapely >2 is specified
-    "scipy<=1.10.0",
-    "pandas>=2.0.0",
-    "numpy<2",
-    "geopandas",
-    # https://stackoverflow.com/questions/78949093/how-to-resolve-attributeerror-module-fiona-has-no-attribute-path
-    # cannot set geopandas>=0.14.4 since it doesn't support python 3.8, so we pin fiona to <1.10.0
-    "fiona<1.10.0",
-    "pyarrow",
-    "pyspark>=3.4.0,<4.1.0",
-    "keplergl==0.3.2",
-    "pydeck==0.8.0",
-    "pystac==1.5.0",
-    "rasterio>=1.2.10",
+    "pytest>=9.0.2",
 ]
 
 [project.urls]
@@ -99,3 +83,11 @@ sources = [
     "src/geom_buf.c",
     "src/geos_c_dyn.c",
 ]
+
+[tool.uv]
+dev-dependencies = [
+    "pytest>=9.0.2",
+]
+
+[tool.uv.sources]
+sedonadb = { path = 
"../../../sedona-db/target/wheels/sedonadb-0.3.0-cp312-cp312-macosx_11_0_arm64.whl" } diff --git a/python/sedona/spark/sql/functions.py b/python/sedona/spark/sql/functions.py index 2420301d52..7c480e1700 100644 --- a/python/sedona/spark/sql/functions.py +++ b/python/sedona/spark/sql/functions.py @@ -21,11 +21,14 @@ from enum import Enum import pandas as pd -from sedona.spark.sql.types import GeometryType from sedona.spark.utils import geometry_serde -from pyspark.sql.udf import UserDefinedFunction -from pyspark.sql.types import DataType from shapely.geometry.base import BaseGeometry +from pyspark.sql.udf import UserDefinedFunction +import pyarrow as pa +import geoarrow.pyarrow as ga +from sedonadb import udf as sedona_udf_module +from sedona.spark.sql.types import GeometryType +from pyspark.sql.types import DataType, FloatType, DoubleType, IntegerType, StringType SEDONA_SCALAR_EVAL_TYPE = 5200 @@ -142,3 +145,58 @@ def serialize_to_geometry_if_geom(data, return_type: DataType): return geometry_serde.serialize(data) return data + + +def infer_pa_type(spark_type: DataType): + if isinstance(spark_type, GeometryType): + return ga.wkb() + elif isinstance(spark_type, FloatType): + return pa.float32() + elif isinstance(spark_type, DoubleType): + return pa.float64() + elif isinstance(spark_type, IntegerType): + return pa.int32() + elif isinstance(spark_type, StringType): + return pa.string() + else: + raise NotImplementedError(f"Type {spark_type} is not supported yet.") + +def infer_input_type(spark_type: DataType): + if isinstance(spark_type, GeometryType): + return sedona_udf_module.GEOMETRY + elif isinstance(spark_type, FloatType) or isinstance(spark_type, DoubleType) or isinstance(spark_type, IntegerType): + return sedona_udf_module.NUMERIC + elif isinstance(spark_type, StringType): + return sedona_udf_module.STRING + else: + raise NotImplementedError(f"Type {spark_type} is not supported yet.") + +def infer_input_types(spark_types: list[DataType]): + pa_types = [] + for spark_type in spark_types: + pa_type = infer_input_type(spark_type) + pa_types.append(pa_type) + + return pa_types + + +def sedona_db_vectorized_udf( + return_type: DataType, + input_types: list[DataType] +): + def apply_fn(fn): + out_type = infer_pa_type(return_type) + input_types_sedona_db = infer_input_types(input_types) + + @sedona_udf_module.arrow_udf(out_type, input_types=input_types_sedona_db) + def shapely_udf(*args, **kwargs): + return fn(*args, **kwargs) + + udf = UserDefinedFunction( + lambda: shapely_udf, return_type, "SedonaPandasArrowUDF", evalType=6200 + ) + + return udf + + + return apply_fn diff --git a/spark/common/src/test/scala/org/apache/sedona/stats/data/_SUCCESS b/python/sedona/spark/worker/__init__.py similarity index 100% rename from spark/common/src/test/scala/org/apache/sedona/stats/data/_SUCCESS rename to python/sedona/spark/worker/__init__.py diff --git a/python/sedona/spark/worker/serde.py b/python/sedona/spark/worker/serde.py new file mode 100644 index 0000000000..31038b7fcd --- /dev/null +++ b/python/sedona/spark/worker/serde.py @@ -0,0 +1,82 @@ +import socket + +from pyspark.serializers import write_int, SpecialLengths +from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer + +from sedona.spark.worker.udf_info import UDFInfo + + +def read_available(buf, chunk=4096): + # buf.raw._sock.settimeout(0.01) # non-blocking-ish + data = bytearray() + index = 0 + while True: + index+=1 + try: + chunk_bytes = buf.read(chunk) + except socket.timeout: + break + + if not chunk_bytes and index > 
10: + break + + data.extend(chunk_bytes) + + return bytes(data) + +class SedonaDBSerializer(ArrowStreamPandasSerializer): + def __init__(self, timezone, safecheck, db, udf_info: UDFInfo): + super(SedonaDBSerializer, self).__init__(timezone, safecheck) + self.db = db + self.udf_info = udf_info + + def load_stream(self, stream): + import pyarrow as pa + + batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) + index = 0 + for batch in batches: + table = pa.Table.from_batches(batches=[batch]) + import pyarrow as pa + df = self.db.create_data_frame(table) + table_name = f"my_table_{index}" + + df.to_view(table_name) + + sql_expression = self.udf_info.sedona_db_transformation_expr(table_name) + + index += 1 + + yield self.db.sql(sql_expression) + + def arrow_dump_stream(self, iterator, stream): + import pyarrow as pa + + writer = None + try: + for batch in iterator: + if writer is None: + writer = pa.RecordBatchStreamWriter(stream, batch.schema) + writer.write_batch(batch) + # stream.flush() + finally: + if writer is not None: + writer.close() + + def dump_stream(self, iterator, stream): + """ + Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. + This should be sent after creating the first record batch so in case of an error, it can + be sent back to the JVM before the Arrow stream starts. + """ + + def init_stream_yield_batches(): + should_write_start_length = True + for batch in iterator: + if should_write_start_length: + write_int(SpecialLengths.START_ARROW_STREAM, stream) + should_write_start_length = False + + yield batch + + return self.arrow_dump_stream(init_stream_yield_batches(), stream) diff --git a/python/sedona/spark/worker/udf_info.py b/python/sedona/spark/worker/udf_info.py new file mode 100644 index 0000000000..d354bcea7e --- /dev/null +++ b/python/sedona/spark/worker/udf_info.py @@ -0,0 +1,34 @@ +from dataclasses import dataclass + +from sedona.spark import GeometryType + + +@dataclass +class UDFInfo: + arg_offsets: list + geom_offsets: dict + function: object + return_type: object + name: str + + def get_function_call_sql(self, table_name: str) -> str: + arg_offset_str = ", ".join([f"_{el}" for el in self.arg_offsets]) + function_expr = f"{self.name}({arg_offset_str})" + if isinstance(self.return_type, GeometryType): + return f"SELECT ST_GeomToSedonaSpark({function_expr}) AS _0 FROM {table_name}" + + return f"SELECT {function_expr} AS _0 FROM {table_name}" + + def sedona_db_transformation_expr(self, table_name: str) -> str: + fields = [] + for arg in self.arg_offsets: + if arg in self.geom_offsets: + crs = self.geom_offsets[arg] + fields.append(f"ST_GeomFromSedonaSpark(_{arg}, 'EPSG:{crs}') AS _{arg}") + continue + + fields.append(f"_{arg}") + + + fields_expr = ", ".join(fields) + return f"SELECT {fields_expr} FROM {table_name}" diff --git a/python/sedona/spark/worker/worker.py b/python/sedona/spark/worker/worker.py new file mode 100644 index 0000000000..74a61b02ee --- /dev/null +++ b/python/sedona/spark/worker/worker.py @@ -0,0 +1,295 @@ +import importlib +import os +import sys +import time + +import sedonadb +from pyspark import TaskContext, shuffle, SparkFiles +from pyspark.errors import PySparkRuntimeError +from pyspark.java_gateway import local_connect_and_auth +from pyspark.resource import ResourceInformation +from pyspark.serializers import read_int, UTF8Deserializer, read_bool, read_long, CPickleSerializer, write_int, \ + write_long, SpecialLengths + +from sedona.spark.worker.serde import SedonaDBSerializer 
+from sedona.spark.worker.udf_info import UDFInfo + + +def apply_iterator(db, iterator, udf_info: UDFInfo): + i = 0 + for df in iterator: + i+=1 + table_name = f"output_table_{i}" + df.to_view(table_name) + + function_call_sql = udf_info.get_function_call_sql(table_name) + + df_out = db.sql(function_call_sql) + df_out.to_view(f"view_{i}") + at = df_out.to_arrow_table() + batches = at.combine_chunks().to_batches() + + for batch in batches: + yield batch + + +def check_python_version(utf_serde: UTF8Deserializer, infile) -> str: + version = utf_serde.loads(infile) + + python_major, python_minor = sys.version_info[:2] + + if version != f"{python_major}.{python_minor}": + raise PySparkRuntimeError( + error_class="PYTHON_VERSION_MISMATCH", + message_parameters={ + "worker_version": str(sys.version_info[:2]), + "driver_version": str(version), + }, + ) + + return version + +def check_barrier_flag(infile): + is_barrier = read_bool(infile) + bound_port = read_int(infile) + secret = UTF8Deserializer().loads(infile) + + if is_barrier: + raise PySparkRuntimeError( + error_class="BARRIER_MODE_NOT_SUPPORTED", + message_parameters={ + "worker_version": str(sys.version_info[:2]), + "message": "Barrier mode is not supported by SedonaDB vectorized functions.", + }, + ) + + return is_barrier + +def assign_task_context(utf_serde: UTF8Deserializer, infile): + stage_id = read_int(infile) + partition_id = read_int(infile) + attempt_number = read_long(infile) + task_attempt_id = read_int(infile) + cpus = read_int(infile) + + task_context = TaskContext._getOrCreate() + task_context._stage_id = stage_id + task_context._partition_id = partition_id + task_context._attempt_number = attempt_number + task_context._task_attempt_id = task_attempt_id + task_context._cpus = cpus + + for r in range(read_int(infile)): + key = utf_serde.loads(infile) + name = utf_serde.loads(infile) + addresses = [] + task_context._resources = {} + for a in range(read_int(infile)): + addresses.append(utf_serde.loads(infile)) + task_context._resources[key] = ResourceInformation(name, addresses) + + task_context._localProperties = dict() + for i in range(read_int(infile)): + k = utf_serde.loads(infile) + v = utf_serde.loads(infile) + task_context._localProperties[k] = v + + return task_context + +def resolve_python_path(utf_serde: UTF8Deserializer, infile): + def add_path(path: str): + # worker can be used, so do not add path multiple times + if path not in sys.path: + # overwrite system packages + sys.path.insert(1, path) + + spark_files_dir = utf_serde.loads(infile) + # _accumulatorRegistry.clear() + + SparkFiles._root_directory = spark_files_dir + SparkFiles._is_running_on_worker = True + + add_path(spark_files_dir) # *.py files that were added will be copied here + num_python_includes = read_int(infile) + for _ in range(num_python_includes): + filename = utf_serde.loads(infile) + add_path(os.path.join(spark_files_dir, filename)) + + importlib.invalidate_caches() + + +def check_broadcast_variables(infile): + needs_broadcast_decryption_server = read_bool(infile) + num_broadcast_variables = read_int(infile) + + if needs_broadcast_decryption_server or num_broadcast_variables > 0: + raise PySparkRuntimeError( + error_class="BROADCAST_VARS_NOT_SUPPORTED", + message_parameters={ + "worker_version": str(sys.version_info[:2]), + "message": "Broadcast variables are not supported by SedonaDB vectorized functions.", + }, + ) + +def get_runner_conf(utf_serde: UTF8Deserializer, infile): + runner_conf = {} + num_conf = read_int(infile) + for i in 
range(num_conf): + k = utf_serde.loads(infile) + v = utf_serde.loads(infile) + runner_conf[k] = v + return runner_conf + + +def read_command(serializer, infile): + command = serializer._read_with_length(infile) + return command + +def read_udf(infile, pickle_ser) -> UDFInfo: + num_arg = read_int(infile) + arg_offsets = [read_int(infile) for i in range(num_arg)] + + function = None + return_type = None + + for i in range(read_int(infile)): + function, return_type = read_command(pickle_ser, infile) + + sedona_db_udf_expression = function() + + return UDFInfo( + arg_offsets=arg_offsets, + function=sedona_db_udf_expression, + return_type=return_type, + name=sedona_db_udf_expression._name, + geom_offsets=[0] + ) + +# def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): +# num_arg = read_int(infile) +# arg_offsets = [read_int(infile) for i in range(num_arg)] +# chained_func = None +# for i in range(read_int(infile)): +# f, return_type = read_command(pickleSer, infile) +# if chained_func is None: +# chained_func = f +# else: +# chained_func = chain(chained_func, f) +# +# func = chained_func +# +# # the last returnType will be the return type of UDF +# if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: +# return arg_offsets, func, return_type +# else: +# raise ValueError("Unknown eval type: {}".format(eval_type)) +# + +def register_sedona_db_udf(infile, pickle_ser) -> UDFInfo: + num_udfs = read_int(infile) + + udf = None + for _ in range(num_udfs): + udf = read_udf(infile, pickle_ser) + # Here we would register the UDF with SedonaDB's internal context + + + return udf + + +def report_times(outfile, boot, init, finish): + write_int(SpecialLengths.TIMING_DATA, outfile) + write_long(int(1000 * boot), outfile) + write_long(int(1000 * init), outfile) + write_long(int(1000 * finish), outfile) + + +def write_statistics(infile, outfile, boot_time, init_time) -> None: + TaskContext._setTaskContext(None) + finish_time = time.time() + report_times(outfile, boot_time, init_time, finish_time) + write_long(shuffle.MemoryBytesSpilled, outfile) + write_long(shuffle.DiskBytesSpilled, outfile) + + # Mark the beginning of the accumulators section of the output + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + # write_int(len(_accumulatorRegistry), outfile) + # for (aid, accum) in _accumulatorRegistry.items(): + # pickleSer._write_with_length((aid, accum._value), outfile) + + if read_int(infile) == SpecialLengths.END_OF_STREAM: + write_int(SpecialLengths.END_OF_STREAM, outfile) + outfile.flush() + else: + write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) + outfile.flush() + sys.exit(-1) + + +def main(infile, outfile): + boot_time = time.time() + sedona_db = sedonadb.connect() + # + utf8_deserializer = UTF8Deserializer() + pickle_ser = CPickleSerializer() + + split_index = read_int(infile) + # + check_python_version(utf8_deserializer, infile) + # + check_barrier_flag(infile) + + task_context = assign_task_context(utf_serde=utf8_deserializer, infile=infile) + shuffle.MemoryBytesSpilled = 0 + shuffle.DiskBytesSpilled = 0 + + resolve_python_path(utf8_deserializer, infile) + # + check_broadcast_variables(infile) + + eval_type = read_int(infile) + + runner_conf = get_runner_conf(utf8_deserializer, infile) + + udf = register_sedona_db_udf(infile, pickle_ser) + + sedona_db.register_udf(udf.function) + init_time = time.time() + + serde = SedonaDBSerializer( + timezone=runner_conf.get("spark.sql.session.timeZone", "UTC"), + safecheck=False, + db=sedona_db, + udf_info=udf + ) + + 
number_of_geometries = read_int(infile) + geom_offsets = {} + for i in range(number_of_geometries): + geom_index = read_int(infile) + geom_srid = read_int(infile) + + geom_offsets[geom_index] = geom_srid + + udf.geom_offsets = geom_offsets + + iterator = serde.load_stream(infile) + out_iterator = apply_iterator(db=sedona_db, iterator=iterator, udf_info=udf) + + serde.dump_stream(out_iterator, outfile) + + write_statistics( + infile, outfile, boot_time=boot_time, init_time=init_time + ) + + +if __name__ == "__main__": + # add file handler + auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] + java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) + (sock_file, sc) = local_connect_and_auth(java_port, auth_secret) + + write_int(os.getpid(), sock_file) + sock_file.flush() + + main(sock_file, sock_file) diff --git a/python/tests/test_base.py b/python/tests/test_base.py index cc2b09e422..a6dbae6597 100644 --- a/python/tests/test_base.py +++ b/python/tests/test_base.py @@ -22,7 +22,7 @@ from typing import Iterable, Union import pyspark from pyspark.sql import DataFrame -from sedona.spark import * +from sedona.spark import SedonaContext from sedona.spark.utils.decorators import classproperty SPARK_REMOTE = os.getenv("SPARK_REMOTE") diff --git a/python/tests/utils/test_sedona_db_vectorized_udf.py b/python/tests/utils/test_sedona_db_vectorized_udf.py new file mode 100644 index 0000000000..749d45420e --- /dev/null +++ b/python/tests/utils/test_sedona_db_vectorized_udf.py @@ -0,0 +1,94 @@ +from sedona.spark.sql.functions import sedona_db_vectorized_udf +from tests.test_base import TestBase +import pyarrow as pa +import shapely +from sedona.sql import GeometryType +from pyspark.sql.functions import expr, lit +from pyspark.sql.types import DoubleType, IntegerType + + +class TestSedonaDBArrowFunction(TestBase): + def test_vectorized_udf(self): + @sedona_db_vectorized_udf(return_type=GeometryType(), input_types=[GeometryType(), IntegerType()]) + def my_own_function(geom, distance): + geom_wkb = pa.array(geom.storage.to_array()) + distance = pa.array(distance.to_array()) + geom = shapely.from_wkb(geom_wkb) + + result_shapely = shapely.centroid(geom) + + return pa.array(shapely.to_wkb(result_shapely)) + + df = self.spark.createDataFrame( + [ + (1, "POINT (1 1)"), + (2, "POINT (2 2)"), + (3, "POINT (3 3)"), + ], + ["id", "wkt"], + ).withColumn("wkt", expr("ST_GeomFromWKT(wkt)")) + + df.select(my_own_function(df.wkt, lit(100)).alias("geom")).show() + + def test_geometry_to_double(self): + @sedona_db_vectorized_udf(return_type=DoubleType(), input_types=[GeometryType()]) + def geometry_to_non_geometry_udf(geom): + geom_wkb = pa.array(geom.storage.to_array()) + geom = shapely.from_wkb(geom_wkb) + + result_shapely = shapely.get_x(shapely.centroid(geom)) + + return pa.array(result_shapely, pa.float64()) + + df = self.spark.createDataFrame( + [(1, "POINT (1 1)"), (2, "POINT (2 2)"), (3, "POINT (3 3)")], + ["id", "wkt"], + ).withColumn("wkt", expr("ST_GeomFromWKT(wkt)")) + + values = df.select(geometry_to_non_geometry_udf(df.wkt).alias("x_coord")) \ + .collect() + + values_list = [row["x_coord"] for row in values] + + assert values_list == [1.0, 2.0, 3.0] + + def test_geometry_to_int(self): + @sedona_db_vectorized_udf(return_type=IntegerType(), input_types=[GeometryType()]) + def geometry_to_int(geom): + geom_wkb = pa.array(geom.storage.to_array()) + geom = shapely.from_wkb(geom_wkb) + + result_shapely = shapely.get_num_points(geom) + + return pa.array(result_shapely, pa.int32()) + + df = 
self.spark.createDataFrame( + [(1, "POINT (1 1)"), (2, "POINT (2 2)"), (3, "POINT (3 3)")], + ["id", "wkt"], + ).withColumn("wkt", expr("ST_GeomFromWKT(wkt)")) + + values = df.select(geometry_to_int(df.wkt)) \ + .collect() + + values_list = [row[0] for row in values] + + assert values_list == [0, 0, 0] + + def test_geometry_crs_preservation(self): + @sedona_db_vectorized_udf(return_type=GeometryType(), input_types=[GeometryType()]) + def return_same_geometry(geom): + geom_wkb = pa.array(geom.storage.to_array()) + geom = shapely.from_wkb(geom_wkb) + + return pa.array(shapely.to_wkb(geom)) + + df = self.spark.createDataFrame( + [(1, "POINT (1 1)"), (2, "POINT (2 2)"), (3, "POINT (3 3)")], + ["id", "wkt"], + ).withColumn("wkt", expr("ST_SetSRID(ST_GeomFromWKT(wkt), 3857)")) + + result_df = df.select(return_same_geometry(df.wkt).alias("geom")) + + crs_list = result_df.selectExpr("ST_SRID(geom)").rdd.flatMap(lambda x: x).collect() + + assert crs_list == [3857, 3857, 3857] diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeometrySerializer.scala b/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeometrySerializer.scala index 02f7f3157d..a75a88f7ba 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeometrySerializer.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/utils/GeometrySerializer.scala @@ -35,8 +35,7 @@ object GeometrySerializer { * Array of bites represents this geometry */ def serialize(geometry: Geometry): Array[Byte] = { - val serialized = geometrySerde.GeometrySerializer.serialize(geometry) - serialized + geometrySerde.GeometrySerializer.serialize(geometry) } /** diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala index 3bb93fe62e..ecb7d90231 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala @@ -37,12 +37,14 @@ class SedonaArrowPythonRunner( protected override val largeVarTypes: Boolean, protected override val workerConf: Map[String, String], val pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) + jobArtifactUUID: Option[String], + geometryFields: Seq[(Int, Int)]) extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( funcs, evalType, argOffsets, - jobArtifactUUID) + jobArtifactUUID, + geometryFields) with SedonaBasicPythonArrowInput with SedonaBasicPythonArrowOutput { diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala index fa6dee3728..268100ef44 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala @@ -23,13 +23,14 @@ import org.apache.sedona.sql.UDF.PythonEvalType.{SQL_SCALAR_SEDONA_DB_UDF, SQL_S import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} +import org.apache.spark.sql.catalyst.expressions.{Attribute, GenericInternalRow, PythonUDF} import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.StructType import org.apache.spark.sql.udf.SedonaArrowEvalPython import org.apache.spark.{JobArtifactSet, TaskContext} import org.apache.spark.sql.internal.SQLConf +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import scala.collection.JavaConverters.asScalaIteratorConverter @@ -58,21 +59,65 @@ case class SedonaArrowEvalPythonExec( private val batchSize = conf.arrowMaxRecordsPerBatch private val sessionLocalTimeZone = conf.sessionLocalTimeZone private val largeVarTypes = conf.arrowUseLargeVarTypes - private val pythonRunnerConf = Map[String, String]( - SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone - ) + private val pythonRunnerConf = + Map[String, String](SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + private def inferCRS(iterator: Iterator[InternalRow], schema: StructType): Seq[(Int, Int)] = { + // this triggers the iterator + if (!iterator.hasNext) { + return Seq.empty + } + + // TODO think of running it multiple times if the needed geometry is null, it's worth + // considering making this parameter based + val row = iterator.next() + + val rowMatched = row match { + case generic: GenericInternalRow => + Some(row.asInstanceOf[GenericInternalRow]) + case _ => None + } + + schema.zipWithIndex + .filter { case (field, index) => + field.dataType == GeometryUDT + } + .map { case (field, index) => + if (rowMatched.isEmpty || rowMatched.get.values(index) == null) (index, 0) + else { + val geom = rowMatched.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] + val preambleByte = geom(0) & 0xff + val hasSrid = (preambleByte & 0x01) != 0 + + var srid = 0 + if (hasSrid) { + val srid2 = (geom(1) & 0xff) << 16 + val srid1 = (geom(2) & 0xff) << 8 + val srid0 = geom(3) & 0xff + srid = srid2 | srid1 | srid0 + } + + (index, srid) + } + } + } + protected override def evaluate( funcs: Seq[ChainedPythonFunctions], argOffsets: Array[Array[Int]], iter: Iterator[InternalRow], schema: StructType, context: TaskContext): Iterator[InternalRow] = { + val (probe, full) = iter.duplicate + + val geometryFields = inferCRS(probe, schema) +// val geometryFields = Seq() + // val firstRow = probe.buffered val outputTypes = output.drop(child.output.length).map(_.dataType) - val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) + val batchIter = if (batchSize > 0) new BatchIterator(full, batchSize) else Iterator(full) evalType match { case SQL_SCALAR_SEDONA_DB_UDF => @@ -85,7 +130,8 @@ case class SedonaArrowEvalPythonExec( largeVarTypes, pythonRunnerConf, pythonMetrics, - jobArtifactUUID).compute(batchIter, context.partitionId(), context) + jobArtifactUUID, + geometryFields).compute(batchIter, context.partitionId(), context) val result = columnarBatchIter.flatMap { batch => batch.rowIterator.asScala @@ -116,4 +162,3 @@ case class SedonaArrowEvalPythonExec( override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) } - diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala index 06ae60bbf7..ffe52912da 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala +++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala @@ -50,7 +50,8 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( funcs: Seq[ChainedPythonFunctions], evalType: Int, argOffsets: Array[Array[Int]], - jobArtifactUUID: Option[String]) + jobArtifactUUID: Option[String], + val geometryFields: Seq[(Int, Int)] = Seq.empty) extends BasePythonRunner[IN, OUT](funcs, evalType, argOffsets, jobArtifactUUID) with Logging { @@ -112,7 +113,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( val releasedOrClosed = new AtomicBoolean(false) // Start a thread to feed the process input from our parent's iterator - val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener[Unit] { _ => writerThread.shutdownOnTaskCompletion() diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala index bfcc8ee2cc..7c3697b8b5 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala @@ -36,7 +36,8 @@ class SedonaDBWorkerFactory(pythonExec: String, envVars: Map[String, String]) extends PythonWorkerFactory(pythonExec, envVars) { self => - private val sedonaWorkerModule = "sedonaworker.reader" + private val sedonaWorkerModule = "sedona.spark.worker.worker" +// private val sedonaWorkerModule = "sedonaworker.reader" private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]() private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala index a0d40121a7..beebce581f 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -63,63 +63,25 @@ private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { - val dataOutFile = s"/tmp/sedona_python_arrow_input_${context.taskAttemptId()}.bin" - val dataOutStream = new FileOutputStream(new File(dataOutFile)) - val dataOut2 = new DataOutputStream(dataOutStream) - protected override def writeCommand(dataOut: DataOutputStream): Unit = { handleMetadataBeforeExec(dataOut) writeUDF(dataOut, funcs, argOffsets) -// val toReadCRS = inputIterator.buffered.headOption.flatMap(el => -// el.asInstanceOf[Iterator[IN]].buffered.headOption) - -// val row = toReadCRS match { -// case Some(value) => -// value match { -// case row: GenericInternalRow => -// Some(row) -// } -// case None => None -// } - -// val geometryFields = schema.zipWithIndex -// .filter { case (field, index) => -// field.dataType == GeometryUDT -// } -// .map { case (field, index) => -// if (row.isEmpty || row.get.values(index) == null) (index, 0) -// else { -// val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] -// val preambleByte = geom(0) & 0xff -// val hasSrid = (preambleByte & 0x01) 
!= 0 -// -// var srid = 0 -// if (hasSrid) { -// val srid2 = (geom(1) & 0xff) << 16 -// val srid1 = (geom(2) & 0xff) << 8 -// val srid0 = geom(3) & 0xff -// srid = srid2 | srid1 | srid0 -// } -// (index, srid) -// } -// } - - // write number of geometry fields -// dataOut.writeInt(geometryFields.length) - dataOut.writeInt(0) + dataOut.writeInt(self.geometryFields.length) // write geometry field indices and their SRIDs -// geometryFields.foreach { case (index, srid) => -// dataOut.writeInt(index) -// dataOut.writeInt(srid) -// } + geometryFields.foreach { case (index, srid) => + dataOut.writeInt(index) + dataOut.writeInt(srid) + } } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema( - schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + val arrowSchema = + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for $pythonExec", 0, Long.MaxValue) + s"stdout writer for $pythonExec", + 0, + Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { @@ -161,8 +123,6 @@ private[python] trait SedonaBasicPythonArrowInput dataOut: DataOutputStream, inputIterator: Iterator[Iterator[InternalRow]]): Unit = { val arrowWriter = ArrowWriter.create(root) - var record = 0 - while (inputIterator.hasNext) { val startData = dataOut.size() val nextBatch = inputIterator.next() @@ -176,8 +136,6 @@ private[python] trait SedonaBasicPythonArrowInput arrowWriter.reset() val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData - record += 1 - println("Written batch number: " + record) } } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala index 8b2d94a1c9..55f05748e6 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala @@ -1,21 +1,21 @@ - /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ - package org.apache.spark.sql.execution.python import com.univocity.parsers.common.input.EOFException @@ -33,30 +33,38 @@ import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} - private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => protected def pythonMetrics: Map[String, SQLMetric] - protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } + protected def handleMetadataAfterExec(stream: DataInputStream): Unit = {} protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[OUT] = { + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[OUT] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, + writerThread, + startTime, + env, + worker, + pid, + releasedOrClosed, + context) { private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) + s"stdin reader for $pythonExec", + 0, + Long.MaxValue) private var reader: ArrowStreamReader = _ private var root: VectorSchemaRoot = _ @@ -74,10 +82,8 @@ private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonR private var batchLoaded = true - def handleEndOfDataSectionSedona (): Unit = { - if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - - } + def handleEndOfDataSectionSedona(): Unit = { + if (stream.readInt() == SpecialLengths.END_OF_STREAM) {} eos = true } @@ -106,8 +112,6 @@ private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonR } } - - protected override def read(): OUT = { if (writerThread.exception.isDefined) { throw writerThread.exception.get @@ -138,9 +142,13 @@ private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonR reader = new ArrowStreamReader(stream, allocator) root = reader.getVectorSchemaRoot() schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] + vectors = root + .getFieldVectors() + .asScala + .map { vector => + new ArrowColumnVector(vector) + } + .toArray[ColumnVector] read() case SpecialLengths.TIMING_DATA => @@ -159,11 +167,11 @@ private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonR } } -private[python] trait SedonaBasicPythonArrowOutput extends SedonaPythonArrowOutput[ColumnarBatch] { +private[python] trait SedonaBasicPythonArrowOutput + extends SedonaPythonArrowOutput[ColumnarBatch] { self: BasePythonRunner[_, ColumnarBatch] => protected def deserializeColumnarBatch( - batch: ColumnarBatch, - schema: StructType): ColumnarBatch = batch + batch: ColumnarBatch, + schema: StructType): 
ColumnarBatch = batch } - diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala index 3aa12467e4..30aec984cb 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala @@ -39,9 +39,11 @@ object WorkerContext { synchronized { // worker.close() val key = (pythonExec, envVars) - pythonWorkers.get(key).foreach(workerFactory => { - workerFactory.stopWorker(worker) - }) + pythonWorkers + .get(key) + .foreach(workerFactory => { + workerFactory.stopWorker(worker) + }) } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala index 54f2cc0ea9..6d9f41bf4e 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoPackageReaderTest.scala @@ -277,99 +277,6 @@ class GeoPackageReaderTest extends TestBaseScala with Matchers { } - describe("check the binary format") { - it("should read binary format correctly") { -// sparkSession.sql( -// "SELECT ST_SetSRID(ST_GeomFromText('POINT (1 1)'), 4326) AS geometry" -// ).show - -// sparkSession.sql( -// "SELECT ST_GeomFromText('POINT EMPTY') AS geometry" -// ).show - -// sparkSession.sql( -// "SELECT ST_GeomFromText('GEOMETRYCOLLECTION EMPTY') AS geometry" -// ).show - - sparkSession.sql( - """SELECT ST_GeomFromText(' - |POLYGON ( - | ( - | 12.345678901234 45.678901234567, - | 23.456789012345 67.890123456789, - | 34.567890123456 56.789012345678, - | 45.678901234567 34.567890123456, - | 29.876543210987 22.345678901234, - | 12.345678901234 45.678901234567 - | ), - | ( - | 25.123456789012 45.987654321098, - | 30.987654321098 50.123456789012, - | 35.456789012345 45.456789012345, - | 30.234567890123 40.987654321098, - | 25.123456789012 45.987654321098 - | ) - |) - |' - |)""".stripMargin - ).show - -// sparkSession.sql( -// "SELECT ST_GeomFromText('MULTILINESTRING((1 1, 2 2), (4 5, 6 7))') AS geometry" -// ).show -// -// sparkSession.sql( -// """ -// |SELECT ST_GeomFromText(' -// |MULTIPOLYGON ( -// | ( -// | (1 1, 10 1, 10 10, 1 10, 1 1), -// | (2 2, 4 2, 4 4, 2 4, 2 2), -// | (6 6, 8 6, 8 8, 6 8, 6 6) -// | ), -// | ( -// | (12 1, 20 1, 20 9, 12 9, 12 1), -// | (13 2, 15 2, 15 4, 13 4, 13 2), -// | (17 5, 19 5, 19 7, 17 7, 17 5) -// | ) -// |) -// |') -// | -// |""".stripMargin -// ).show - -// sparkSession.sql( -// "SELECT ST_GeomFromText('GEOMETRYCOLLECTION(POINT(4 6),LINESTRING(4 6,7 10), POLYGON((4 6,7 10,4 10,4 6)), MULTIPOINT((1 2),(3 4)))') AS geometry" -// ).show() - -// sparkSession.sql( -// """ -// |SELECT ST_GeomFromText(' -// |GEOMETRYCOLLECTION ( -// | POINT (1 1), -// | GEOMETRYCOLLECTION ( -// | LINESTRING (0 0, 1 1), -// | POLYGON ((0 0, 2 0, 2 2, 0 2, 0 0)) -// | ) -// |)') -// | -// | -// |""".stripMargin -// ).show() - - -// sparkSession.sql( -// "SELECT ST_GeomFromText('MULTIPOINT((1 1), (2 2), (4 5))') AS geometry" -// ).show -// -// sparkSession.sql( -// "SELECT ST_GeomFromText('LINESTRING (0 0, 1 1, 2 2)') AS geometry" -// ).show - - } - - } - describe("Reading from S3") { it("should be able to read files from S3") { val container = new MinIOContainer("minio/minio:latest") diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala 
b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index 32d2f06fc8..28943ff11d 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -38,7 +38,6 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { val keyParserExtension = "spark.sedona.enableParserExtension" val warehouseLocation = System.getProperty("user.dir") + "/target/" -// 4425302.491982245 val sparkSession = SedonaContext .builder() .master("local[*]") diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index 426ed18429..f58e58783d 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -19,16 +19,12 @@ package org.apache.spark.sql.udf import org.apache.sedona.sql.TestBaseScala -import org.apache.spark.SparkEnv -import org.apache.spark.security.SocketAuthHelper import org.apache.spark.sql.SparkSession import org.apache.spark.sql.functions.{col, expr, lit} import org.apache.spark.sql.udf.ScalarUDF.{geoPandasScalaFunction, nonGeometryVectorizedUDF, sedonaDBGeometryToGeometryFunction} import org.locationtech.jts.io.WKTReader import org.scalatest.matchers.should.Matchers -import java.net.{InetAddress, ServerSocket, Socket} - class StrategySuite extends TestBaseScala with Matchers { val wktReader = new WKTReader() @@ -40,172 +36,37 @@ class StrategySuite extends TestBaseScala with Matchers { import spark.implicits._ it("sedona geospatial UDF - geopandas") { - val df = spark.read - .format("geoparquet") - .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") - .withColumn("geom_buffer", geoPandasScalaFunction(col("geometry")) ) - - df.printSchema() - - df.show() -// -// val df = Seq( -// (1, "value", wktReader.read("POINT(21 52)")), -// (2, "value1", wktReader.read("POINT(20 50)")), -// (3, "value2", wktReader.read("POINT(20 49)")), -// (4, "value3", wktReader.read("POINT(20 48)")), -// (5, "value4", wktReader.read("POINT(20 47)"))) -// .toDF("id", "value", "geom") -// .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) - -// df.count shouldEqual 5 -// -// df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") -// .as[String] -// .collect() should contain theSameElementsAs Seq( -// "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", -// "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", -// "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", -// "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", -// "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") + val df = Seq( + (1, "value", wktReader.read("POINT(21 52)")), + (2, "value1", wktReader.read("POINT(20 50)")), + (3, "value2", wktReader.read("POINT(20 49)")), + (4, "value3", wktReader.read("POINT(20 48)")), + (5, "value4", wktReader.read("POINT(20 47)"))) + .toDF("id", "value", "geom") + .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) + + df.count shouldEqual 5 + + df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") + .as[String] + .collect() should contain theSameElementsAs Seq( + "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", + "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", + "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", + "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", + "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 
46))") } it("sedona geospatial UDF - sedona db") { -// val df = Seq( -// (1, "value", wktReader.read("POINT(21 52)")), -// (2, "value1", wktReader.read("POINT(20 50)")), -// (3, "value2", wktReader.read("POINT(20 49)")), -// (4, "value3", wktReader.read("POINT(20 48)")), -// (5, "value4", wktReader.read("POINT(20 47)"))) -// .toDF("id", "value", "geometry") - -// df.cache() -// df.count() -// .select( -// sedonaDBGeometryToGeometryFunction(col("geometry")).alias("geom"), -// nonGeometryVectorizedUDF(col("id")).alias("id_increased"), -// ) - -// spark.read -// .format("geoparquet") -// .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") -// .limit(10000) -// .write.format("geoparquet") -// .save("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings_2") val df = spark.read .format("geoparquet") .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") .select("geometry") -// df.cache() -// df.count() -// .limit(100) - - -// println(df.count()) - -// df.cache() -// -// df.count() - val dfVectorized = df .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')")) - .select( -// col("id"), -// col("version"), -// col("bbox"), - sedonaDBGeometryToGeometryFunction(col("geometry"), lit(100)).alias("geom"), -// nonGeometryVectorizedUDF(col("id")).alias("id_increased"), - ) + .select(sedonaDBGeometryToGeometryFunction(col("geometry"), lit(100)).alias("geom")) -// dfVectorized.show() dfVectorized.selectExpr("ST_X(ST_Centroid(geom)) AS x").selectExpr("sum(x)").show() -// dfVectorized.selectExpr("ST_X(ST_Centroid(geom)) AS x").selectExpr("sum(x)").show() -// val processingContext = df.queryExecution.explainString(mode = ExplainMode.fromString("extended")) - -// println(processingContext) - } - - it("should properly start socket server") { - val authHelper = new SocketAuthHelper(SparkEnv.get.conf) - val serverSocket = new ServerSocket(5356, 1, InetAddress.getLoopbackAddress()) - -// serverSocket.setSoTimeout(15000) - println(serverSocket.getLocalPort) - val socket = serverSocket.accept() - - println("socket accepted") - -// -// val acceptThread = new Thread(() => { -// println("Waiting for client...") -// val socket = serverSocket.accept() // BLOCKS HERE -// println("Client connected!") -// }, "accept-thread") -// -// acceptThread.start() -// -// println("Main thread continues immediately") - -// val t = new Thread() { -// override def run(): Unit = { -// println("starting client") -// socket = serverSocket.accept() -// println("client connected") -// -// } -// } - -// t.start() -// val socket = serverSocket.accept() -// authHelper.authClient(socket) - - println(authHelper.secret) -// Thread.sleep(10000) -// authHelper.authClient(socket) - -// var socket: Socket = null - -// new Thread() { -// socket = serverSocket.accept() -// } -// -// val t2 = new Thread() { -// override def run(): Unit = { -// println("accepted connection") -// val socket = serverSocket.accept() -// val in = socket.getInputStream -// val buffer = new Array[Byte](1024) -// println("accepted connection") -// var bytesRead = in.read(buffer) -// while (bytesRead != -1) { -// val received = new String(buffer, 0, bytesRead) -// println(s"Received: $received") -// bytesRead = in.read(buffer) -// } -// in.close() -// socket.close() -// } -// } -// -// t2.start() -// -// val thread2 = new Thread() { -// override def run(): Unit = { -// println("starting client") -// serverSocket.close() 
-// } -// } -// -// val t = new Thread(() => { -// println("hello from thread") -// }) -// -// t.start() -// t.join() // <-- this IS valid -// t2.join() - -// Thread.sleep(30000) - } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index e941adcff4..d759d8e37f 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -136,14 +136,14 @@ object ScalarUDF { binaryPandasFunc } - val geopandasNonGeometryToGeometryFunction: Array[Byte] = { - var binaryPandasFunc: Array[Byte] = null - withTempPath { path => - Process( - Seq( - pythonExec, - "-c", - f""" + val geopandasNonGeometryToGeometryFunction: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + f""" |from sedona.sql.types import GeometryType |from shapely.wkt import loads |from pyspark.serializers import CloudPickleSerializer @@ -152,13 +152,13 @@ object ScalarUDF { | return x.apply(lambda wkt: loads(wkt).buffer(1)) |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType()))) |""".stripMargin), - None, - "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! - binaryPandasFunc = Files.readAllBytes(path.toPath) - } - assert(binaryPandasFunc != null) - binaryPandasFunc + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! + binaryPandasFunc = Files.readAllBytes(path.toPath) } + assert(binaryPandasFunc != null) + binaryPandasFunc + } private val workerEnv = new java.util.HashMap[String, String]() @@ -218,12 +218,10 @@ object ScalarUDF { pythonExec = pythonExec, pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, - accumulator = null - ), + accumulator = null), dataType = FloatType, pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, - udfDeterministic = false - ) + udfDeterministic = false) val sedonaDBGeometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( name = "geospatial_udf",
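For reference, a minimal usage sketch of the sedona_db_vectorized_udf decorator added in python/sedona/spark/sql/functions.py, modeled on test_vectorized_udf in the new python/tests/utils/test_sedona_db_vectorized_udf.py; the Sedona-enabled SparkSession bound to `spark` is assumed to already exist and is not part of this patch.

    # Sketch following the new test module; assumes `spark` was created via SedonaContext.
    import pyarrow as pa
    import shapely
    from pyspark.sql.functions import expr, lit
    from pyspark.sql.types import IntegerType

    from sedona.spark.sql.functions import sedona_db_vectorized_udf
    from sedona.sql import GeometryType

    @sedona_db_vectorized_udf(return_type=GeometryType(), input_types=[GeometryType(), IntegerType()])
    def centroid_udf(geom, distance):
        # Arguments arrive as Arrow arrays; geometries are WKB inside a geoarrow storage array.
        # `distance` is accepted but unused here, as in the test.
        geoms = shapely.from_wkb(pa.array(geom.storage.to_array()))
        return pa.array(shapely.to_wkb(shapely.centroid(geoms)))

    df = (
        spark.createDataFrame([(1, "POINT (1 1)"), (2, "POINT (2 2)")], ["id", "wkt"])
        .withColumn("wkt", expr("ST_GeomFromWKT(wkt)"))
    )
    df.select(centroid_udf(df.wkt, lit(100)).alias("geom")).show()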
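The worker drives SedonaDB with generated SQL rather than calling the UDF directly; the snippet below shows what the UDFInfo helpers from python/sedona/spark/worker/udf_info.py produce for an illustrative UDF (the offsets, SRID, table names, and UDF name are made up for the example, not taken from the patch).

    from sedona.spark import GeometryType
    from sedona.spark.worker.udf_info import UDFInfo

    # Illustrative values: a UDF over (geometry, int) at offsets 0 and 1,
    # where the geometry input carries SRID 4326 and the result is a geometry.
    udf = UDFInfo(arg_offsets=[0, 1], geom_offsets={0: 4326}, function=None,
                  return_type=GeometryType(), name="my_udf")

    udf.sedona_db_transformation_expr("my_table_0")
    # -> SELECT ST_GeomFromSedonaSpark(_0, 'EPSG:4326') AS _0, _1 FROM my_table_0

    udf.get_function_call_sql("output_table_1")
    # -> SELECT ST_GeomToSedonaSpark(my_udf(_0, _1)) AS _0 FROM output_table_1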
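inferCRS in SedonaArrowStrategy.scala probes the first serialized geometry of each partition to recover the SRID that is forwarded to the Python worker. Below is a Python rendering of that byte-level check, assuming the same preamble layout the Scala code relies on (bit 0x01 of byte 0 flags an SRID, bytes 1-3 hold it most-significant-byte first); the helper name is hypothetical.

    def decode_preamble_srid(serialized_geom: bytes) -> int:
        # Mirrors the check in inferCRS: returns 0 when the SRID flag is not set.
        if not serialized_geom:
            return 0
        preamble = serialized_geom[0] & 0xFF
        if preamble & 0x01 == 0:
            return 0
        # Bytes 1..3 carry the SRID, big-endian.
        return (
            ((serialized_geom[1] & 0xFF) << 16)
            | ((serialized_geom[2] & 0xFF) << 8)
            | (serialized_geom[3] & 0xFF)
        )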
