This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 3b66bb5346336659f03be0880ce50259fc0f7775
Author: pawelkocinski <[email protected]>
AuthorDate: Wed Nov 12 13:17:20 2025 +0100

    SEDONA-748 Fix issue with no optimization for weighting function.
---
 sedonaworker/__init__.py              |   0
 sedonaworker/serializer/__init__.py   |   1 -
 sedonaworker/serializer/serializer.py | 100 --------
 sedonaworker/worker.py                | 428 ----------------------------------
 4 files changed, 529 deletions(-)

diff --git a/sedonaworker/__init__.py b/sedonaworker/__init__.py
deleted file mode 100644
index e69de29bb2..0000000000
diff --git a/sedonaworker/serializer/__init__.py b/sedonaworker/serializer/__init__.py
deleted file mode 100644
index 7a0599eac4..0000000000
--- a/sedonaworker/serializer/__init__.py
+++ /dev/null
@@ -1 +0,0 @@
-from .serializer import SedonaArrowStreamPandasUDFSerializer
diff --git a/sedonaworker/serializer/serializer.py b/sedonaworker/serializer/serializer.py
deleted file mode 100644
index 229c06ce04..0000000000
--- a/sedonaworker/serializer/serializer.py
+++ /dev/null
@@ -1,100 +0,0 @@
-from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer, ArrowStreamSerializer
-from pyspark.errors import PySparkTypeError, PySparkValueError
-import struct
-
-from pyspark.serializers import write_int
-
-def write_int(value, stream):
-    stream.write(struct.pack("!i", value))
-
-class SpecialLengths:
-    START_ARROW_STREAM = -6
-
-
-class SedonaArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer):
-    def __init__(self, timezone, safecheck, assign_cols_by_name):
-        super(SedonaArrowStreamPandasUDFSerializer, self).__init__(timezone, 
safecheck)
-        self._assign_cols_by_name = assign_cols_by_name
-
-    def load_stream(self, stream):
-        import geoarrow.pyarrow as ga
-        import pyarrow as pa
-
-        batches = super(ArrowStreamPandasSerializer, self).load_stream(stream)
-        for batch in batches:
-            table = pa.Table.from_batches(batches=[batch])
-            data = []
-
-            for c in table.itercolumns():
-                meta = table.schema.field(c._name).metadata
-                if meta and meta[b"ARROW:extension:name"] == b'geoarrow.wkb':
-                    data.append(ga.to_geopandas(c))
-                    continue
-
-                data.append(self.arrow_to_pandas(c))
-
-            yield data
-
-    def _create_batch(self, series):
-        import pyarrow as pa
-
-        series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series)
-
-        arrs = []
-        for s, t in series:
-            arrs.append(self._create_array(s, t))
-
-        return pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))])
-
-    def _create_array(self, series, arrow_type):
-        import pyarrow as pa
-        import geopandas as gpd
-
-        if hasattr(series.array, "__arrow_array__"):
-            mask = None
-        else:
-            mask = series.isnull()
-
-        try:
-            if isinstance(series, gpd.GeoSeries):
-                import geoarrow.pyarrow as ga
-                # If the series is a GeoSeries, convert it to an Arrow array using geoarrow
-                return ga.array(series)
-
-            array = pa.Array.from_pandas(series, mask=mask, type=arrow_type)
-            return array
-        except TypeError as e:
-            error_msg = (
-                "Exception thrown when converting pandas.Series (%s) "
-                "with name '%s' to Arrow Array (%s)."
-            )
-            raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e
-        except ValueError as e:
-            error_msg = (
-                "Exception thrown when converting pandas.Series (%s) "
-                "with name '%s' to Arrow Array (%s)."
-            )
-            raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e
-
-    def dump_stream(self, iterator, stream):
-        """
-        Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent.
-        This should be sent after creating the first record batch so in case of an error, it can
-        be sent back to the JVM before the Arrow stream starts.
-        """
-
-        def init_stream_yield_batches():
-            should_write_start_length = True
-            for series in iterator:
-                batch = self._create_batch(series)
-                if should_write_start_length:
-                    write_int(SpecialLengths.START_ARROW_STREAM, stream)
-                    should_write_start_length = False
-                yield batch
-
-        return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream)
-
-    def __repr__(self):
-        return "ArrowStreamPandasUDFSerializer"
-
-
diff --git a/sedonaworker/worker.py b/sedonaworker/worker.py
deleted file mode 100644
index c98b34b1d8..0000000000
--- a/sedonaworker/worker.py
+++ /dev/null
@@ -1,428 +0,0 @@
-#
-# Licensed to the Apache Software Foundation (ASF) under one or more
-# contributor license agreements.  See the NOTICE file distributed with
-# this work for additional information regarding copyright ownership.
-# The ASF licenses this file to You under the Apache License, Version 2.0
-# (the "License"); you may not use this file except in compliance with
-# the License.  You may obtain a copy of the License at
-#
-#    http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing, software
-# distributed under the License is distributed on an "AS IS" BASIS,
-# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
-# See the License for the specific language governing permissions and
-# limitations under the License.
-#
-
-"""
-Worker that receives input from Piped RDD.
-"""
-import os
-import sys
-import time
-from inspect import currentframe, getframeinfo
-import importlib
-
-from serializer import SedonaArrowStreamPandasUDFSerializer
-
-has_resource_module = True
-try:
-    import resource
-except ImportError:
-    has_resource_module = False
-import traceback
-import warnings
-import faulthandler
-
-from pyspark.accumulators import _accumulatorRegistry
-from pyspark.broadcast import Broadcast, _broadcastRegistry
-from pyspark.java_gateway import local_connect_and_auth
-from pyspark.taskcontext import BarrierTaskContext, TaskContext
-from pyspark.files import SparkFiles
-from pyspark.resource import ResourceInformation
-from pyspark.rdd import PythonEvalType
-from pyspark.serializers import (
-    write_with_length,
-    write_int,
-    read_long,
-    read_bool,
-    write_long,
-    read_int,
-    SpecialLengths,
-    UTF8Deserializer,
-    CPickleSerializer,
-)
-
-from pyspark.sql.pandas.types import to_arrow_type
-from pyspark.sql.types import StructType
-from pyspark.util import fail_on_stopiteration, try_simplify_traceback
-from pyspark import shuffle
-from pyspark.errors import PySparkRuntimeError, PySparkTypeError
-
-pickleSer = CPickleSerializer()
-utf8_deserializer = UTF8Deserializer()
-
-
-def report_times(outfile, boot, init, finish):
-    write_int(SpecialLengths.TIMING_DATA, outfile)
-    write_long(int(1000 * boot), outfile)
-    write_long(int(1000 * init), outfile)
-    write_long(int(1000 * finish), outfile)
-
-
-def add_path(path):
-    # worker can be used, so do not add path multiple times
-    if path not in sys.path:
-        # overwrite system packages
-        sys.path.insert(1, path)
-
-
-def read_command(serializer, file):
-    command = serializer._read_with_length(file)
-    if isinstance(command, Broadcast):
-        command = serializer.loads(command.value)
-    return command
-
-
-def chain(f, g):
-    """chain two functions together"""
-    return lambda *a: g(f(*a))
-
-def wrap_scalar_pandas_udf(f, return_type):
-    arrow_return_type = to_arrow_type(return_type)
-
-    def verify_result_type(result):
-        if not hasattr(result, "__len__"):
-            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
-            raise PySparkTypeError(
-                error_class="UDF_RETURN_TYPE",
-                message_parameters={
-                    "expected": pd_type,
-                    "actual": type(result).__name__,
-                },
-            )
-        return result
-
-    def verify_result_length(result, length):
-        if len(result) != length:
-            raise PySparkRuntimeError(
-                error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
-                message_parameters={
-                    "expected": str(length),
-                    "actual": str(len(result)),
-                },
-            )
-        return result
-
-    return lambda *a: (
-        verify_result_length(verify_result_type(f(*a)), len(a[0])),
-        arrow_return_type,
-    )
-
-
-def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
-    num_arg = read_int(infile)
-    arg_offsets = [read_int(infile) for i in range(num_arg)]
-    chained_func = None
-    for i in range(read_int(infile)):
-        f, return_type = read_command(pickleSer, infile)
-        if chained_func is None:
-            chained_func = f
-        else:
-            chained_func = chain(chained_func, f)
-
-    func = fail_on_stopiteration(chained_func)
-
-    # the last returnType will be the return type of UDF
-    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
-        return arg_offsets, wrap_scalar_pandas_udf(func, return_type), return_type
-    else:
-        raise ValueError("Unknown eval type: {}".format(eval_type))
-
-
-# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when
-# returning StructType
-def assign_cols_by_name(runner_conf):
-    return (
-            runner_conf.get(
-                "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
-            ).lower()
-            == "true"
-    )
-
-
-def read_udfs(pickleSer, infile, eval_type):
-    from sedona.sql.types import GeometryType
-    import geopandas as gpd
-    runner_conf = {}
-
-    # Load conf used for pandas_udf evaluation
-    num_conf = read_int(infile)
-    for i in range(num_conf):
-        k = utf8_deserializer.loads(infile)
-        v = utf8_deserializer.loads(infile)
-        runner_conf[k] = v
-
-    timezone = runner_conf.get("spark.sql.session.timeZone", None)
-    safecheck = (
-            runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
-            == "true"
-    )
-
-    num_udfs = read_int(infile)
-
-    udfs = []
-    for i in range(num_udfs):
-        udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))
-
-    def mapper(a):
-        results = []
-
-        for (arg_offsets, f, return_type) in udfs:
-            result = f(*[a[o] for o in arg_offsets])
-            if isinstance(return_type, GeometryType):
-                results.append((
-                    gpd.GeoSeries(result[0]),
-                    result[1],
-                ))
-
-                continue
-
-            results.append(result)
-
-        if len(results) == 1:
-            return results[0]
-        else:
-            return results
-
-    def func(_, it):
-        return map(mapper, it)
-
-    ser = SedonaArrowStreamPandasUDFSerializer(
-        timezone,
-        safecheck,
-        assign_cols_by_name(runner_conf),
-    )
-
-    # profiling is not supported for UDF
-    return func, None, ser, ser
-
-
-def main(infile, outfile):
-    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
-    try:
-        if faulthandler_log_path:
-            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
-            faulthandler_log_file = open(faulthandler_log_path, "w")
-            faulthandler.enable(file=faulthandler_log_file)
-
-        boot_time = time.time()
-        split_index = read_int(infile)
-        if split_index == -1:  # for unit tests
-            sys.exit(-1)
-
-        version = utf8_deserializer.loads(infile)
-
-        if version != "%d.%d" % sys.version_info[:2]:
-            raise PySparkRuntimeError(
-                error_class="PYTHON_VERSION_MISMATCH",
-                message_parameters={
-                    "worker_version": str(sys.version_info[:2]),
-                    "driver_version": str(version),
-                },
-            )
-
-        # read inputs only for a barrier task
-        isBarrier = read_bool(infile)
-        boundPort = read_int(infile)
-        secret = UTF8Deserializer().loads(infile)
-
-        # set up memory limits
-        memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
-        if memory_limit_mb > 0 and has_resource_module:
-            total_memory = resource.RLIMIT_AS
-            try:
-                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
-                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
-                print(msg, file=sys.stderr)
-
-                # convert to bytes
-                new_limit = memory_limit_mb * 1024 * 1024
-
-                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
-                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
-                    print(msg, file=sys.stderr)
-                    resource.setrlimit(total_memory, (new_limit, new_limit))
-
-            except (resource.error, OSError, ValueError) as e:
-                # not all systems support resource limits, so warn instead of failing
-                lineno = (
-                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
-                )
-                if "__file__" in globals():
-                    print(
-                        warnings.formatwarning(
-                            "Failed to set memory limit: {0}".format(e),
-                            ResourceWarning,
-                            __file__,
-                            lineno,
-                        ),
-                        file=sys.stderr,
-                    )
-
-        # initialize global state
-        taskContext = None
-        if isBarrier:
-            taskContext = BarrierTaskContext._getOrCreate()
-            BarrierTaskContext._initialize(boundPort, secret)
-            # Set the task context instance here, so we can get it by TaskContext.get for
-            # both TaskContext and BarrierTaskContext
-            TaskContext._setTaskContext(taskContext)
-        else:
-            taskContext = TaskContext._getOrCreate()
-        # read inputs for TaskContext info
-        taskContext._stageId = read_int(infile)
-        taskContext._partitionId = read_int(infile)
-        taskContext._attemptNumber = read_int(infile)
-        taskContext._taskAttemptId = read_long(infile)
-        taskContext._cpus = read_int(infile)
-        taskContext._resources = {}
-        for r in range(read_int(infile)):
-            key = utf8_deserializer.loads(infile)
-            name = utf8_deserializer.loads(infile)
-            addresses = []
-            taskContext._resources = {}
-            for a in range(read_int(infile)):
-                addresses.append(utf8_deserializer.loads(infile))
-            taskContext._resources[key] = ResourceInformation(name, addresses)
-
-        taskContext._localProperties = dict()
-        for i in range(read_int(infile)):
-            k = utf8_deserializer.loads(infile)
-            v = utf8_deserializer.loads(infile)
-            taskContext._localProperties[k] = v
-
-        shuffle.MemoryBytesSpilled = 0
-        shuffle.DiskBytesSpilled = 0
-        _accumulatorRegistry.clear()
-
-        # fetch name of workdir
-        spark_files_dir = utf8_deserializer.loads(infile)
-        SparkFiles._root_directory = spark_files_dir
-        SparkFiles._is_running_on_worker = True
-
-        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
-        add_path(spark_files_dir)  # *.py files that were added will be copied here
-        num_python_includes = read_int(infile)
-        for _ in range(num_python_includes):
-            filename = utf8_deserializer.loads(infile)
-            add_path(os.path.join(spark_files_dir, filename))
-
-        importlib.invalidate_caches()
-
-        # fetch names and values of broadcast variables
-        needs_broadcast_decryption_server = read_bool(infile)
-        num_broadcast_variables = read_int(infile)
-        if needs_broadcast_decryption_server:
-            # read the decrypted data from a server in the jvm
-            port = read_int(infile)
-            auth_secret = utf8_deserializer.loads(infile)
-            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
-
-        for _ in range(num_broadcast_variables):
-            bid = read_long(infile)
-            if bid >= 0:
-                if needs_broadcast_decryption_server:
-                    read_bid = read_long(broadcast_sock_file)
-                    assert read_bid == bid
-                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
-                else:
-                    path = utf8_deserializer.loads(infile)
-                    _broadcastRegistry[bid] = Broadcast(path=path)
-
-            else:
-                bid = -bid - 1
-                _broadcastRegistry.pop(bid)
-
-        if needs_broadcast_decryption_server:
-            broadcast_sock_file.write(b"1")
-            broadcast_sock_file.close()
-
-        _accumulatorRegistry.clear()
-        eval_type = read_int(infile)
-        func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
-        init_time = time.time()
-
-        def process():
-            iterator = deserializer.load_stream(infile)
-            out_iter = func(split_index, iterator)
-            try:
-                serializer.dump_stream(out_iter, outfile)
-            finally:
-                if hasattr(out_iter, "close"):
-                    out_iter.close()
-
-        if profiler:
-            profiler.profile(process)
-        else:
-            process()
-
-        # Reset task context to None. This is a guard code to avoid residual context when worker
-        # reuse.
-        TaskContext._setTaskContext(None)
-        BarrierTaskContext._setTaskContext(None)
-    except BaseException as e:
-        try:
-            exc_info = None
-            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
-                tb = try_simplify_traceback(sys.exc_info()[-1])
-                if tb is not None:
-                    e.__cause__ = None
-                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
-            if exc_info is None:
-                exc_info = traceback.format_exc()
-
-            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
-            write_with_length(exc_info.encode("utf-8"), outfile)
-        except IOError:
-            # JVM close the socket
-            pass
-        except BaseException:
-            # Write the error to stderr if it happened while serializing
-            print("PySpark worker failed with exception:", file=sys.stderr)
-            print(traceback.format_exc(), file=sys.stderr)
-        sys.exit(-1)
-    finally:
-        if faulthandler_log_path:
-            faulthandler.disable()
-            faulthandler_log_file.close()
-            os.remove(faulthandler_log_path)
-    finish_time = time.time()
-    report_times(outfile, boot_time, init_time, finish_time)
-    write_long(shuffle.MemoryBytesSpilled, outfile)
-    write_long(shuffle.DiskBytesSpilled, outfile)
-
-    # Mark the beginning of the accumulators section of the output
-    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-    write_int(len(_accumulatorRegistry), outfile)
-    for (aid, accum) in _accumulatorRegistry.items():
-        pickleSer._write_with_length((aid, accum._value), outfile)
-
-    # check end of stream
-    if read_int(infile) == SpecialLengths.END_OF_STREAM:
-        write_int(SpecialLengths.END_OF_STREAM, outfile)
-    else:
-        # write a different value to tell JVM to not reuse this worker
-        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
-        sys.exit(-1)
-
-
-if __name__ == "__main__":
-    # Read information about how to connect back to the JVM from the environment.
-    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
-    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
-    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
-    write_int(os.getpid(), sock_file)
-    sock_file.flush()
-    main(sock_file, sock_file)
