This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 3b66bb5346336659f03be0880ce50259fc0f7775
Author: pawelkocinski <[email protected]>
AuthorDate: Wed Nov 12 13:17:20 2025 +0100

    SEDONA-748 Fix issue with no optimization for weighting function.
---
 sedonaworker/__init__.py              |   0
 sedonaworker/serializer/__init__.py   |   1 -
 sedonaworker/serializer/serializer.py | 100 --------
 sedonaworker/worker.py                | 428 ----------------------------------
 4 files changed, 529 deletions(-)

diff --git a/sedonaworker/__init__.py b/sedonaworker/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/sedonaworker/serializer/__init__.py b/sedonaworker/serializer/__init__.py
deleted file mode 100644
index 7a0599eac4..0000000000
--- a/sedonaworker/serializer/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .serializer import SedonaArrowStreamPandasUDFSerializer
diff --git a/sedonaworker/serializer/serializer.py b/sedonaworker/serializer/serializer.py
deleted file mode 100644
index 229c06ce04..0000000000
--- a/sedonaworker/serializer/serializer.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer, ArrowStreamSerializer
-from pyspark.errors import PySparkTypeError, PySparkValueError
-import struct
-
-from pyspark.serializers import write_int
-
-def write_int(value, stream):
-    stream.write(struct.pack("!i", value))
-
-class SpecialLengths:
-    START_ARROW_STREAM = -6
-
-
-class SedonaArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
-    def __init__(self, timezone, safecheck, assign_cols_by_name):
-        super(SedonaArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck)
-        self._assign_cols_by_name = assign_cols_by_name
-
-    def load_stream(self, stream):
-        import geoarrow.pyarrow as ga
-        import pyarrow as pa
-
-        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
-        for batch in batches:
-            table = pa.Table.from_batches(batches=[batch])
-            data = []
-
-            for c in table.itercolumns():
-                meta = table.schema.field(c._name).metadata
-                if meta and meta[b"ARROW:extension:name"] == b'geoarrow.wkb':
-                    data.append(ga.to_geopandas(c))
-                    continue
-
-                data.append(self.arrow_to_pandas(c))
-
-            yield data
-
-    def _create_batch(self, series):
-        import pyarrow as pa
-
-        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
-
-        arrs = []
-        for s, t in series:
-            arrs.append(self._create_array(s, t))
-
-        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
-
-    def _create_array(self, series, arrow_type):
-        import pyarrow as pa
-        import geopandas as gpd
-
-        if hasattr(series.array, "__arrow_array__"):
-            mask = None
-        else:
-            mask = series.isnull()
-
-        try:
-            if isinstance(series, gpd.GeoSeries):
-                import geoarrow.pyarrow as ga
-                # If the series is a GeoSeries, convert it to an Arrow array using geoarrow
-                return ga.array(series)
-
-            array = pa.Array.from_pandas(series, mask=mask, type=arrow_type)
-            return array
-        except TypeError as e:
-            error_msg = (
-                "Exception thrown when converting pandas.Series (%s) "
-                "with name '%s' to Arrow Array (%s)."
-            )
-            raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e
-        except ValueError as e:
-            error_msg = (
-                "Exception thrown when converting pandas.Series (%s) "
-                "with name '%s' to Arrow Array (%s)."
-            )
-            raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e
-
-    def dump_stream(self, iterator, stream):
-        """
-        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
-        This should be sent after creating the first record batch so in case of an error, it can
-        be sent back to the JVM before the Arrow stream starts.
-        """
-
-        def init_stream_yield_batches():
-            should_write_start_length = True
-            for series in iterator:
-                batch = self._create_batch(series)
-                if should_write_start_length:
-                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
-                    should_write_start_length = False
-                yield batch
-
-        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
-
-    def __repr__(self):
-        return "ArrowStreamPandasUDFSerializer"
-
-
diff --git a/sedonaworker/worker.py b/sedonaworker/worker.py
deleted file mode 100644
index c98b34b1d8..0000000000
--- a/sedonaworker/worker.py
+++ /dev/null
@@ -1,428 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements. See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License. You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Worker that receives input from Piped RDD.
-"""
-import os
-import sys
-import time
-from inspect import currentframe, getframeinfo
-import importlib
-
-from serializer import SedonaArrowStreamPandasUDFSerializer
-
-has_resource_module = True
-try:
-    import resource
-except ImportError:
-    has_resource_module = False
-import traceback
-import warnings
-import faulthandler
-
-from pyspark.accumulators import _accumulatorRegistry
-from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import local_connect_and_auth
-from pyspark.taskcontext import BarrierTaskContext, TaskContext
-from pyspark.files import SparkFiles
-from pyspark.resource import ResourceInformation
-from pyspark.rdd import PythonEvalType
-from pyspark.serializers import (
-    write_with_length,
-    write_int,
-    read_long,
-    read_bool,
-    write_long,
-    read_int,
-    SpecialLengths,
-    UTF8Deserializer,
-    CPickleSerializer,
-)
-
-from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import StructType
-from pyspark.util import fail_on_stopiteration, try_simplify_traceback
-from pyspark import shuffle
-from pyspark.errors import PySparkRuntimeError, PySparkTypeError
-
-pickleSer = CPickleSerializer()
-utf8_deserializer = UTF8Deserializer()
-
-
-def report_times(outfile, boot, init, finish):
-    write_int(SpecialLengths.TIMING_DATA, outfile)
-    write_long(int(1000 * boot), outfile)
-    write_long(int(1000 * init), outfile)
-    write_long(int(1000 * finish), outfile)
-
-
-def add_path(path):
-    # worker can be used, so do not add path multiple times
-    if path not in sys.path:
-        # overwrite system packages
-        sys.path.insert(1, path)
-
-
-def read_command(serializer, file):
-    command = serializer._read_with_length(file)
-    if isinstance(command, Broadcast):
-        command = serializer.loads(command.value)
-    return command
-
-
-def chain(f, g):
-    """chain two functions together"""
-    return lambda *a: g(f(*a))
-
-def wrap_scalar_pandas_udf(f, return_type):
-    arrow_return_type = to_arrow_type(return_type)
-
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
-            raise PySparkTypeError(
-                error_class="UDF_RETURN_TYPE",
-                message_parameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
-
-    def verify_result_length(result, length):
-        if len(result) != length:
-            raise PySparkRuntimeError(
-                error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
-                message_parameters={
-                    "expected": str(length),
-                    "actual": str(len(result)),
-                },
-            )
-        return result
-
-    return lambda *a: (
-        verify_result_length(verify_result_type(f(*a)), len(a[0])),
-        arrow_return_type,
-    )
-
-
-def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
-    num_arg = read_int(infile)
-    arg_offsets = [read_int(infile) for i in range(num_arg)]
-    chained_func = None
-    for i in range(read_int(infile)):
-        f, return_type = read_command(pickleSer, infile)
-        if chained_func is None:
-            chained_func = f
-        else:
-            chained_func = chain(chained_func, f)
-
-    func = fail_on_stopiteration(chained_func)
-
-    # the last returnType will be the return type of UDF
-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(func, return_type), return_type
-    else:
-        raise ValueError("Unknown eval type: {}".format(eval_type))
-
-
-# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when
-# returning StructType
-def assign_cols_by_name(runner_conf):
-    return (
-        runner_conf.get(
-            "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
-        ).lower()
-        == "true"
-    )
-
-
-def read_udfs(pickleSer, infile, eval_type):
-    from sedona.sql.types import GeometryType
-    import geopandas as gpd
-    runner_conf = {}
-
-    # Load conf used for pandas_udf evaluation
-    num_conf = read_int(infile)
-    for i in range(num_conf):
-        k = utf8_deserializer.loads(infile)
-        v = utf8_deserializer.loads(infile)
-        runner_conf[k] = v
-
-    timezone = runner_conf.get("spark.sql.session.timeZone", None)
-    safecheck = (
-        runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
-        == "true"
-    )
-
-    num_udfs = read_int(infile)
-
-    udfs = []
-    for i in range(num_udfs):
-        udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))
-
-    def mapper(a):
-        results = []
-
-        for (arg_offsets, f, return_type) in udfs:
-            result = f(*[a[o] for o in arg_offsets])
-            if isinstance(return_type, GeometryType):
-                results.append((
-                    gpd.GeoSeries(result[0]),
-                    result[1],
-                ))
-
-                continue
-
-            results.append(result)
-
-        if len(results) == 1:
-            return results[0]
-        else:
-            return results
-
-    def func(_, it):
-        return map(mapper, it)
-
-    ser = SedonaArrowStreamPandasUDFSerializer(
-        timezone,
-        safecheck,
-        assign_cols_by_name(runner_conf),
-    )
-
-    # profiling is not supported for UDF
-    return func, None, ser, ser
-
-
-def main(infile, outfile):
-    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
-    try:
-        if faulthandler_log_path:
-            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
-            faulthandler_log_file = open(faulthandler_log_path, "w")
-            faulthandler.enable(file=faulthandler_log_file)
-
-        boot_time = time.time()
-        split_index = read_int(infile)
-        if split_index == -1:  # for unit tests
-            sys.exit(-1)
-
-        version = utf8_deserializer.loads(infile)
-
-        if version != "%d.%d" % sys.version_info[:2]:
-            raise PySparkRuntimeError(
-                error_class="PYTHON_VERSION_MISMATCH",
-                message_parameters={
-                    "worker_version": str(sys.version_info[:2]),
-                    "driver_version": str(version),
-                },
-            )
-
-        # read inputs only for a barrier task
-        isBarrier = read_bool(infile)
-        boundPort = read_int(infile)
-        secret = UTF8Deserializer().loads(infile)
-
-        # set up memory limits
-        memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
-        if memory_limit_mb > 0 and has_resource_module:
-            total_memory = resource.RLIMIT_AS
-            try:
-                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
-                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
-                print(msg, file=sys.stderr)
-
-                # convert to bytes
-                new_limit = memory_limit_mb * 1024 * 1024
-
-                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
-                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
-                    print(msg, file=sys.stderr)
-                    resource.setrlimit(total_memory, (new_limit, new_limit))
-
-            except (resource.error, OSError, ValueError) as e:
-                # not all systems support resource limits, so warn instead of failing
-                lineno = (
-                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
-                )
-                if "__file__" in globals():
-                    print(
-                        warnings.formatwarning(
-                            "Failed to set memory limit: {0}".format(e),
-                            ResourceWarning,
-                            __file__,
-                            lineno,
-                        ),
-                        file=sys.stderr,
-                    )
-
-        # initialize global state
-        taskContext = None
-        if isBarrier:
-            taskContext = BarrierTaskContext._getOrCreate()
-            BarrierTaskContext._initialize(boundPort, secret)
-            # Set the task context instance here, so we can get it by TaskContext.get for
-            # both TaskContext and BarrierTaskContext
-            TaskContext._setTaskContext(taskContext)
-        else:
-            taskContext = TaskContext._getOrCreate()
-        # read inputs for TaskContext info
-        taskContext._stageId = read_int(infile)
-        taskContext._partitionId = read_int(infile)
-        taskContext._attemptNumber = read_int(infile)
-        taskContext._taskAttemptId = read_long(infile)
-        taskContext._cpus = read_int(infile)
-        taskContext._resources = {}
-        for r in range(read_int(infile)):
-            key = utf8_deserializer.loads(infile)
-            name = utf8_deserializer.loads(infile)
-            addresses = []
-            taskContext._resources = {}
-            for a in range(read_int(infile)):
-                addresses.append(utf8_deserializer.loads(infile))
-            taskContext._resources[key] = ResourceInformation(name, addresses)
-
-        taskContext._localProperties = dict()
-        for i in range(read_int(infile)):
-            k = utf8_deserializer.loads(infile)
-            v = utf8_deserializer.loads(infile)
-            taskContext._localProperties[k] = v
-
-        shuffle.MemoryBytesSpilled = 0
-        shuffle.DiskBytesSpilled = 0
-        _accumulatorRegistry.clear()
-
-        # fetch name of workdir
-        spark_files_dir = utf8_deserializer.loads(infile)
-        SparkFiles._root_directory = spark_files_dir
-        SparkFiles._is_running_on_worker = True
-
-        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
-        add_path(spark_files_dir)  # *.py files that were added will be copied here
-        num_python_includes = read_int(infile)
-        for _ in range(num_python_includes):
-            filename = utf8_deserializer.loads(infile)
-            add_path(os.path.join(spark_files_dir, filename))
-
-        importlib.invalidate_caches()
-
-        # fetch names and values of broadcast variables
-        needs_broadcast_decryption_server = read_bool(infile)
-        num_broadcast_variables = read_int(infile)
-        if needs_broadcast_decryption_server:
-            # read the decrypted data from a server in the jvm
-            port = read_int(infile)
-            auth_secret = utf8_deserializer.loads(infile)
-            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
-
-        for _ in range(num_broadcast_variables):
-            bid = read_long(infile)
-            if bid >= 0:
-                if needs_broadcast_decryption_server:
-                    read_bid = read_long(broadcast_sock_file)
-                    assert read_bid == bid
-                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
-                else:
-                    path = utf8_deserializer.loads(infile)
-                    _broadcastRegistry[bid] = Broadcast(path=path)
-
-            else:
-                bid = -bid - 1
-                _broadcastRegistry.pop(bid)
-
-        if needs_broadcast_decryption_server:
-            broadcast_sock_file.write(b"1")
-            broadcast_sock_file.close()
-
-        _accumulatorRegistry.clear()
-        eval_type = read_int(infile)
-        func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
-        init_time = time.time()
-
-        def process():
-            iterator = deserializer.load_stream(infile)
-            out_iter = func(split_index, iterator)
-            try:
-                serializer.dump_stream(out_iter, outfile)
-            finally:
-                if hasattr(out_iter, "close"):
-                    out_iter.close()
-
-        if profiler:
-            profiler.profile(process)
-        else:
-            process()
-
-        # Reset task context to None. This is a guard code to avoid residual context when worker
-        # reuse.
-        TaskContext._setTaskContext(None)
-        BarrierTaskContext._setTaskContext(None)
-    except BaseException as e:
-        try:
-            exc_info = None
-            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
-                tb = try_simplify_traceback(sys.exc_info()[-1])
-                if tb is not None:
-                    e.__cause__ = None
-                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
-            if exc_info is None:
-                exc_info = traceback.format_exc()
-
-            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
-            write_with_length(exc_info.encode("utf-8"), outfile)
-        except IOError:
-            # JVM close the socket
-            pass
-        except BaseException:
-            # Write the error to stderr if it happened while serializing
-            print("PySpark worker failed with exception:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-        sys.exit(-1)
-    finally:
-        if faulthandler_log_path:
-            faulthandler.disable()
-            faulthandler_log_file.close()
-            os.remove(faulthandler_log_path)
-
-    finish_time = time.time()
-    report_times(outfile, boot_time, init_time, finish_time)
-    write_long(shuffle.MemoryBytesSpilled, outfile)
-    write_long(shuffle.DiskBytesSpilled, outfile)
-
-    # Mark the beginning of the accumulators section of the output
-    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-    write_int(len(_accumulatorRegistry), outfile)
-    for (aid, accum) in _accumulatorRegistry.items():
-        pickleSer._write_with_length((aid, accum._value), outfile)
-
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
-
-
-if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
