This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 172c9412e4bfb1d9b5d017851485bd1263c493af
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Jul 27 00:26:37 2025 +0200

    SEDONA-738 Fix unit tests.
---
 pom.xml                                            |  12 +-
 sedonaworker/__init__.py                           |   0
 sedonaworker/worker.py                             | 643 ++++++++++++++++
 .../scala/org/apache/spark/SedonaSparkEnv.scala    | 495 +++++++++++++
 .../spark/api/python/SedonaPythonRunner.scala      | 811 +++++++++++++++++++++
 .../execution/python/SedonaArrowPythonRunner.scala |  70 ++
 .../execution/python/SedonaPythonArrowInput.scala  | 148 ++++
 .../execution/python/SedonaPythonArrowOutput.scala | 135 ++++
 .../execution/python/SedonaPythonUDFRunner.scala   | 147 ++++
 .../apache/spark/sql/udf/SedonaArrowStrategy.scala |   4 +-
 .../apache/spark/sql/udf/TestScalarPandasUDF.scala |  21 +-
 11 files changed, 2475 insertions(+), 11 deletions(-)

diff --git a/pom.xml b/pom.xml
index 44a1dcb16a..c8ae8b50e6 100644
--- a/pom.xml
+++ b/pom.xml
@@ -19,12 +19,12 @@
 
 <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
     <modelVersion>4.0.0</modelVersion>
-    <parent>
-        <groupId>org.apache</groupId>
-        <artifactId>apache</artifactId>
-        <version>23</version>
-        <relativePath />
-    </parent>
+<!--    <parent>-->
+<!--        <groupId>org.apache</groupId>-->
+<!--        <artifactId>apache</artifactId>-->
+<!--        <version>23</version>-->
+<!--        <relativePath />-->
+<!--    </parent>-->
     <groupId>org.apache.sedona</groupId>
     <artifactId>sedona-parent</artifactId>
     <version>1.8.1-SNAPSHOT</version>
diff --git a/sedonaworker/__init__.py b/sedonaworker/__init__.py
new file mode 100644
index 0000000000..e69de29bb2
diff --git a/sedonaworker/worker.py b/sedonaworker/worker.py
new file mode 100644
index 0000000000..42fb20beb3
--- /dev/null
+++ b/sedonaworker/worker.py
@@ -0,0 +1,643 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one or more
+# contributor license agreements.  See the NOTICE file distributed with
+# this work for additional information regarding copyright ownership.
+# The ASF licenses this file to You under the Apache License, Version 2.0
+# (the "License"); you may not use this file except in compliance with
+# the License.  You may obtain a copy of the License at
+#
+#    http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing, software
+# distributed under the License is distributed on an "AS IS" BASIS,
+# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+# See the License for the specific language governing permissions and
+# limitations under the License.
+#
+
+"""
+Worker that receives input from Piped RDD.
+"""
+import os
+import sys
+import time
+from inspect import currentframe, getframeinfo, getfullargspec
+import importlib
+import json
+from typing import Any, Iterable, Iterator
+
+# 'resource' is a Unix specific module.
+has_resource_module = True
+try:
+    import resource
+except ImportError:
+    has_resource_module = False
+import traceback
+import warnings
+import faulthandler
+
+from pyspark.accumulators import _accumulatorRegistry
+from pyspark.broadcast import Broadcast, _broadcastRegistry
+from pyspark.java_gateway import local_connect_and_auth
+from pyspark.taskcontext import BarrierTaskContext, TaskContext
+from pyspark.files import SparkFiles
+from pyspark.resource import ResourceInformation
+from pyspark.rdd import PythonEvalType
+from pyspark.serializers import (
+    write_with_length,
+    write_int,
+    read_long,
+    read_bool,
+    write_long,
+    read_int,
+    SpecialLengths,
+    UTF8Deserializer,
+    CPickleSerializer,
+    BatchedSerializer,
+)
+from pyspark.sql.pandas.serializers import (
+    ArrowStreamPandasUDFSerializer,
+    ArrowStreamPandasUDTFSerializer,
+    CogroupUDFSerializer,
+    ArrowStreamUDFSerializer,
+    ApplyInPandasWithStateSerializer,
+)
+from pyspark.sql.pandas.types import to_arrow_type
+from pyspark.sql.types import BinaryType, StringType, StructType, _parse_datatype_json_string
+from pyspark.util import fail_on_stopiteration, try_simplify_traceback
+from pyspark import shuffle
+from pyspark.errors import PySparkRuntimeError, PySparkTypeError
+
+pickleSer = CPickleSerializer()
+utf8_deserializer = UTF8Deserializer()
+
+
+def report_times(outfile, boot, init, finish):
+    write_int(SpecialLengths.TIMING_DATA, outfile)
+    write_long(int(1000 * boot), outfile)
+    write_long(int(1000 * init), outfile)
+    write_long(int(1000 * finish), outfile)
+
+
+def add_path(path):
+    # worker can be used, so do not add path multiple times
+    if path not in sys.path:
+        # overwrite system packages
+        sys.path.insert(1, path)
+
+
+def read_command(serializer, file):
+    command = serializer._read_with_length(file)
+    if isinstance(command, Broadcast):
+        command = serializer.loads(command.value)
+    return command
+
+
+def chain(f, g):
+    """chain two functions together"""
+    return lambda *a: g(f(*a))
+
+
+# def wrap_udf(f, return_type):
+#     if return_type.needConversion():
+#         toInternal = return_type.toInternal
+#         return lambda *a: toInternal(f(*a))
+#     else:
+#         return lambda *a: f(*a)
+
+
+def wrap_scalar_pandas_udf(f, return_type):
+    arrow_return_type = to_arrow_type(return_type)
+
+    def verify_result_type(result):
+        if not hasattr(result, "__len__"):
+            pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series"
+            raise PySparkTypeError(
+                error_class="UDF_RETURN_TYPE",
+                message_parameters={
+                    "expected": pd_type,
+                    "actual": type(result).__name__,
+                },
+            )
+        return result
+
+    def verify_result_length(result, length):
+        if len(result) != length:
+            raise PySparkRuntimeError(
+                error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF",
+                message_parameters={
+                    "expected": str(length),
+                    "actual": str(len(result)),
+                },
+            )
+        return result
+
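+    # The wrapper pairs each validated result with its Arrow return type so the Arrow
+    # serializer can write the batch back to the JVM.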
+    return lambda *a: (
+        verify_result_length(verify_result_type(f(*a)), len(a[0])),
+        arrow_return_type,
+    )
+
+
+def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index):
+    num_arg = read_int(infile)
+    arg_offsets = [read_int(infile) for i in range(num_arg)]
+    chained_func = None
+    for i in range(read_int(infile)):
+        f, return_type = read_command(pickleSer, infile)
+        if chained_func is None:
+            chained_func = f
+        else:
+            chained_func = chain(chained_func, f)
+
+    if eval_type in (
+            PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF,
+            PythonEvalType.SQL_ARROW_BATCHED_UDF,
+    ):
+        func = chained_func
+    else:
+        # make sure StopIteration's raised in the user code are not ignored
+        # when they are processed in a for loop, raise them as RuntimeError's instead
+        func = fail_on_stopiteration(chained_func)
+
+    # the last returnType will be the return type of UDF
+    if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF:
+        return arg_offsets, wrap_scalar_pandas_udf(func, return_type)
+    else:
+        raise ValueError("Unknown eval type: {}".format(eval_type))
+
+
+# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when
+# returning StructType
+def assign_cols_by_name(runner_conf):
+    return (
+            runner_conf.get(
+                "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true"
+            ).lower()
+            == "true"
+    )
+
+
+def read_udfs(pickleSer, infile, eval_type):
+    runner_conf = {}
+
+    if eval_type in (
+            PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+    ):
+
+        # Load conf used for pandas_udf evaluation
+        num_conf = read_int(infile)
+        for i in range(num_conf):
+            k = utf8_deserializer.loads(infile)
+            v = utf8_deserializer.loads(infile)
+            runner_conf[k] = v
+
+        state_object_schema = None
+        if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+            state_object_schema = StructType.fromJson(json.loads(utf8_deserializer.loads(infile)))
+
+        # NOTE: if timezone is set here, that implies respectSessionTimeZone is True
+        timezone = runner_conf.get("spark.sql.session.timeZone", None)
+        safecheck = (
+                runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower()
+                == "true"
+        )
+
+        if eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+            ser = CogroupUDFSerializer(timezone, safecheck, assign_cols_by_name(runner_conf))
+        elif eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF:
+            ser = ArrowStreamUDFSerializer()
+        else:
+            # Scalar Pandas UDF handles struct type arguments as pandas DataFrames instead of
+            # pandas Series. See SPARK-27240.
+            df_for_struct = (
+                    eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF
+                    or eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+                    or eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+            )
+            # Arrow-optimized Python UDF takes a struct type argument as a Row
+            struct_in_pandas = (
+                "row" if eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF else "dict"
+            )
+            ndarray_as_list = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+            # Arrow-optimized Python UDF uses explicit Arrow cast for type coercion
+            arrow_cast = eval_type == PythonEvalType.SQL_ARROW_BATCHED_UDF
+            ser = ArrowStreamPandasUDFSerializer(
+                timezone,
+                safecheck,
+                assign_cols_by_name(runner_conf),
+                df_for_struct,
+                struct_in_pandas,
+                ndarray_as_list,
+                arrow_cast,
+            )
+    else:
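+        # Non-Arrow eval types fall back to pickling rows in batches of 100.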
+        ser = BatchedSerializer(CPickleSerializer(), 100)
+
+    num_udfs = read_int(infile)
+
+    is_scalar_iter = eval_type == PythonEvalType.SQL_SCALAR_PANDAS_ITER_UDF
+    is_map_pandas_iter = eval_type == PythonEvalType.SQL_MAP_PANDAS_ITER_UDF
+    is_map_arrow_iter = eval_type == PythonEvalType.SQL_MAP_ARROW_ITER_UDF
+
+    if is_scalar_iter or is_map_pandas_iter or is_map_arrow_iter:
+        if is_scalar_iter:
+            assert num_udfs == 1, "One SCALAR_ITER UDF expected here."
+        if is_map_pandas_iter:
+            assert num_udfs == 1, "One MAP_PANDAS_ITER UDF expected here."
+        if is_map_arrow_iter:
+            assert num_udfs == 1, "One MAP_ARROW_ITER UDF expected here."
+
+        arg_offsets, udf = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+
+        def func(_, iterator):
+            num_input_rows = 0
+
+            def map_batch(batch):
+                nonlocal num_input_rows
+
+                udf_args = [batch[offset] for offset in arg_offsets]
+                num_input_rows += len(udf_args[0])
+                if len(udf_args) == 1:
+                    return udf_args[0]
+                else:
+                    return tuple(udf_args)
+
+            iterator = map(map_batch, iterator)
+            result_iter = udf(iterator)
+
+            num_output_rows = 0
+            for result_batch, result_type in result_iter:
+                num_output_rows += len(result_batch)
+                # This assert is for Scalar Iterator UDF to fail fast.
+                # The length of the entire input can only be explicitly known
+                # by consuming the input iterator in user side. Therefore,
+                # it's very unlikely the output length is higher than
+                # input length.
+                assert (
+                        is_map_pandas_iter or is_map_arrow_iter or num_output_rows <= num_input_rows
+                ), "Pandas SCALAR_ITER UDF outputted more rows than input rows."
+                yield (result_batch, result_type)
+
+            if is_scalar_iter:
+                try:
+                    next(iterator)
+                except StopIteration:
+                    pass
+                else:
+                    raise PySparkRuntimeError(
+                        error_class="STOP_ITERATION_OCCURRED_FROM_SCALAR_ITER_PANDAS_UDF",
+                        message_parameters={},
+                    )
+
+                if num_output_rows != num_input_rows:
+                    raise PySparkRuntimeError(
+                        error_class="RESULT_LENGTH_MISMATCH_FOR_SCALAR_ITER_PANDAS_UDF",
+                        message_parameters={
+                            "output_length": str(num_output_rows),
+                            "input_length": str(num_input_rows),
+                        },
+                    )
+
+        # profiling is not supported for UDF
+        return func, None, ser, ser
+
+    def extract_key_value_indexes(grouped_arg_offsets):
+        """
+        Helper function to extract the key and value indexes from arg_offsets for the grouped and
+        cogrouped pandas udfs. See BasePandasGroupExec.resolveArgOffsets for equivalent scala code.
+
+        Parameters
+        ----------
+        grouped_arg_offsets:  list
+            List containing the key and value indexes of columns of the
+            DataFrames to be passed to the udf. It consists of n repeating groups where n is the
+            number of DataFrames.  Each group has the following format:
+                group[0]: length of group
+                group[1]: length of key indexes
+                group[2.. group[1] +2]: key attributes
+                group[group[1] +3 group[0]]: value attributes
+        """
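+        # For example, grouped_arg_offsets = [4, 1, 0, 1, 2] parses to [[[0], [1, 2]]]:
+        # a single group of 4 offsets, of which 1 is a key index (0) and the remaining
+        # offsets (1, 2) are value indexes.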
+        parsed = []
+        idx = 0
+        while idx < len(grouped_arg_offsets):
+            offsets_len = grouped_arg_offsets[idx]
+            idx += 1
+            offsets = grouped_arg_offsets[idx : idx + offsets_len]
+            split_index = offsets[0] + 1
+            offset_keys = offsets[1:split_index]
+            offset_values = offsets[split_index:]
+            parsed.append([offset_keys, offset_values])
+            idx += offsets_len
+        return parsed
+
+    if eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF:
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandasExec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        # Create function like this:
+        #   mapper a: f([a[0]], [a[0], a[1]])
+        def mapper(a):
+            keys = [a[o] for o in parsed_offsets[0][0]]
+            vals = [a[o] for o in parsed_offsets[0][1]]
+            return f(keys, vals)
+
+    elif eval_type == PythonEvalType.SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE:
+        # We assume there is only one UDF here because grouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+
+        # See FlatMapGroupsInPandas(WithState)Exec for how arg_offsets are used to
+        # distinguish between grouping attributes and data attributes
+        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        def mapper(a):
+            """
+            The function receives (iterator of data, state) and performs extraction of key and
+            value from the data, with retaining lazy evaluation.
+
+            See `load_stream` in `ApplyInPandasWithStateSerializer` for more details on the input
+            and see `wrap_grouped_map_pandas_udf_with_state` for more details on how output will
+            be used.
+            """
+            from itertools import tee
+
+            state = a[1]
+            data_gen = (x[0] for x in a[0])
+
+            # We know there should be at least one item in the iterator/generator.
+            # We want to peek the first element to construct the key, hence applying
+            # tee to construct the key while we retain another iterator/generator
+            # for values.
+            keys_gen, values_gen = tee(data_gen)
+            keys_elem = next(keys_gen)
+            keys = [keys_elem[o] for o in parsed_offsets[0][0]]
+
+            # This must be generator comprehension - do not materialize.
+            vals = ([x[o] for o in parsed_offsets[0][1]] for x in values_gen)
+
+            return f(keys, vals, state)
+
+    elif eval_type == PythonEvalType.SQL_COGROUPED_MAP_PANDAS_UDF:
+        # We assume there is only one UDF here because cogrouped map doesn't
+        # support combining multiple UDFs.
+        assert num_udfs == 1
+        arg_offsets, f = read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=0)
+
+        parsed_offsets = extract_key_value_indexes(arg_offsets)
+
+        def mapper(a):
+            df1_keys = [a[0][o] for o in parsed_offsets[0][0]]
+            df1_vals = [a[0][o] for o in parsed_offsets[0][1]]
+            df2_keys = [a[1][o] for o in parsed_offsets[1][0]]
+            df2_vals = [a[1][o] for o in parsed_offsets[1][1]]
+            return f(df1_keys, df1_vals, df2_keys, df2_vals)
+
+    else:
+        udfs = []
+        for i in range(num_udfs):
+            udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i))
+
+        def mapper(a):
+            result = tuple(f(*[a[o] for o in arg_offsets]) for (arg_offsets, f) in udfs)
+            # In the special case of a single UDF this will return a single result rather
+            # than a tuple of results; this is the format that the JVM side expects.
+            if len(result) == 1:
+                return result[0]
+            else:
+                return result
+
+    def func(_, it):
+        return map(mapper, it)
+
+    # profiling is not supported for UDF
+    return func, None, ser, ser
+
+
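+# main() drives a single task in the Sedona Python worker: it reads the task metadata
+# (split index, Python version check, barrier/TaskContext info, Spark files and
+# broadcasts), builds the UDF pipeline via read_udfs/read_command, streams the
+# partition through it, and writes timing, shuffle metrics and accumulator updates
+# back to the JVM.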
+def main(infile, outfile):
+    faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None)
+    try:
+        if faulthandler_log_path:
+            faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid()))
+            faulthandler_log_file = open(faulthandler_log_path, "w")
+            faulthandler.enable(file=faulthandler_log_file)
+
+        boot_time = time.time()
+        split_index = read_int(infile)
+        if split_index == -1:  # for unit tests
+            sys.exit(-1)
+
+        version = utf8_deserializer.loads(infile)
+        if version != "%d.%d" % sys.version_info[:2]:
+            raise PySparkRuntimeError(
+                error_class="PYTHON_VERSION_MISMATCH",
+                message_parameters={
+                    "worker_version": str(sys.version_info[:2]),
+                    "driver_version": str(version),
+                },
+            )
+
+        # read inputs only for a barrier task
+        isBarrier = read_bool(infile)
+        boundPort = read_int(infile)
+        secret = UTF8Deserializer().loads(infile)
+
+        # set up memory limits
+        memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1"))
+        if memory_limit_mb > 0 and has_resource_module:
+            total_memory = resource.RLIMIT_AS
+            try:
+                (soft_limit, hard_limit) = resource.getrlimit(total_memory)
+                msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit)
+                print(msg, file=sys.stderr)
+
+                # convert to bytes
+                new_limit = memory_limit_mb * 1024 * 1024
+
+                if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit:
+                    msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit)
+                    print(msg, file=sys.stderr)
+                    resource.setrlimit(total_memory, (new_limit, new_limit))
+
+            except (resource.error, OSError, ValueError) as e:
+                # not all systems support resource limits, so warn instead of failing
+                lineno = (
+                    getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0
+                )
+                if "__file__" in globals():
+                    print(
+                        warnings.formatwarning(
+                            "Failed to set memory limit: {0}".format(e),
+                            ResourceWarning,
+                            __file__,
+                            lineno,
+                        ),
+                        file=sys.stderr,
+                    )
+
+        # initialize global state
+        taskContext = None
+        if isBarrier:
+            taskContext = BarrierTaskContext._getOrCreate()
+            BarrierTaskContext._initialize(boundPort, secret)
+            # Set the task context instance here, so we can get it by TaskContext.get for
+            # both TaskContext and BarrierTaskContext
+            TaskContext._setTaskContext(taskContext)
+        else:
+            taskContext = TaskContext._getOrCreate()
+        # read inputs for TaskContext info
+        taskContext._stageId = read_int(infile)
+        taskContext._partitionId = read_int(infile)
+        taskContext._attemptNumber = read_int(infile)
+        taskContext._taskAttemptId = read_long(infile)
+        taskContext._cpus = read_int(infile)
+        taskContext._resources = {}
+        for r in range(read_int(infile)):
+            key = utf8_deserializer.loads(infile)
+            name = utf8_deserializer.loads(infile)
+            addresses = []
+            for a in range(read_int(infile)):
+                addresses.append(utf8_deserializer.loads(infile))
+            taskContext._resources[key] = ResourceInformation(name, addresses)
+
+        taskContext._localProperties = dict()
+        for i in range(read_int(infile)):
+            k = utf8_deserializer.loads(infile)
+            v = utf8_deserializer.loads(infile)
+            taskContext._localProperties[k] = v
+
+        shuffle.MemoryBytesSpilled = 0
+        shuffle.DiskBytesSpilled = 0
+        _accumulatorRegistry.clear()
+
+        # fetch name of workdir
+        spark_files_dir = utf8_deserializer.loads(infile)
+        SparkFiles._root_directory = spark_files_dir
+        SparkFiles._is_running_on_worker = True
+
+        # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH
+        add_path(spark_files_dir)  # *.py files that were added will be copied here
+        num_python_includes = read_int(infile)
+        for _ in range(num_python_includes):
+            filename = utf8_deserializer.loads(infile)
+            add_path(os.path.join(spark_files_dir, filename))
+
+        importlib.invalidate_caches()
+
+        # fetch names and values of broadcast variables
+        needs_broadcast_decryption_server = read_bool(infile)
+        num_broadcast_variables = read_int(infile)
+        if needs_broadcast_decryption_server:
+            # read the decrypted data from a server in the jvm
+            port = read_int(infile)
+            auth_secret = utf8_deserializer.loads(infile)
+            (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret)
+
+        for _ in range(num_broadcast_variables):
+            bid = read_long(infile)
+            if bid >= 0:
+                if needs_broadcast_decryption_server:
+                    read_bid = read_long(broadcast_sock_file)
+                    assert read_bid == bid
+                    _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file)
+                else:
+                    path = utf8_deserializer.loads(infile)
+                    _broadcastRegistry[bid] = Broadcast(path=path)
+
+            else:
+                bid = -bid - 1
+                _broadcastRegistry.pop(bid)
+
+        if needs_broadcast_decryption_server:
+            broadcast_sock_file.write(b"1")
+            broadcast_sock_file.close()
+
+        _accumulatorRegistry.clear()
+        eval_type = read_int(infile)
+        if eval_type == PythonEvalType.NON_UDF:
+            func, profiler, deserializer, serializer = read_command(pickleSer, infile)
+        else:
+            func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type)
+
+        init_time = time.time()
+
+        def process():
+            iterator = deserializer.load_stream(infile)
+            out_iter = func(split_index, iterator)
+            try:
+                serializer.dump_stream(out_iter, outfile)
+            finally:
+                if hasattr(out_iter, "close"):
+                    out_iter.close()
+
+        if profiler:
+            profiler.profile(process)
+        else:
+            process()
+
+        # Reset the task context to None. This is a guard to avoid residual context when the
+        # worker is reused.
+        TaskContext._setTaskContext(None)
+        BarrierTaskContext._setTaskContext(None)
+    except BaseException as e:
+        try:
+            exc_info = None
+            if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False):
+                tb = try_simplify_traceback(sys.exc_info()[-1])
+                if tb is not None:
+                    e.__cause__ = None
+                    exc_info = "".join(traceback.format_exception(type(e), e, tb))
+            if exc_info is None:
+                exc_info = traceback.format_exc()
+
+            write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile)
+            write_with_length(exc_info.encode("utf-8"), outfile)
+        except IOError:
+            # the JVM closed the socket
+            pass
+        except BaseException:
+            # Write the error to stderr if it happened while serializing
+            print("PySpark worker failed with exception:", file=sys.stderr)
+            print(traceback.format_exc(), file=sys.stderr)
+        sys.exit(-1)
+    finally:
+        if faulthandler_log_path:
+            faulthandler.disable()
+            faulthandler_log_file.close()
+            os.remove(faulthandler_log_path)
+    finish_time = time.time()
+    report_times(outfile, boot_time, init_time, finish_time)
+    write_long(shuffle.MemoryBytesSpilled, outfile)
+    write_long(shuffle.DiskBytesSpilled, outfile)
+
+    # Mark the beginning of the accumulators section of the output
+    write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+    write_int(len(_accumulatorRegistry), outfile)
+    for (aid, accum) in _accumulatorRegistry.items():
+        pickleSer._write_with_length((aid, accum._value), outfile)
+
+    # check end of stream
+    if read_int(infile) == SpecialLengths.END_OF_STREAM:
+        write_int(SpecialLengths.END_OF_STREAM, outfile)
+    else:
+        # write a different value to tell JVM to not reuse this worker
+        write_int(SpecialLengths.END_OF_DATA_SECTION, outfile)
+        sys.exit(-1)
+
+
+if __name__ == "__main__":
+    # Read information about how to connect back to the JVM from the environment.
+    java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"])
+    auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"]
+    (sock_file, _) = local_connect_and_auth(java_port, auth_secret)
+    # TODO: Remove the following two lines and use `Process.pid()` when we drop JDK 8.
+    write_int(os.getpid(), sock_file)
+    sock_file.flush()
+    main(sock_file, sock_file)
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala
new file mode 100644
index 0000000000..9449a291f5
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala
@@ -0,0 +1,495 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark
+
+import java.io.File
+import java.net.Socket
+import java.util.Locale
+
+import scala.collection.JavaConverters._
+import scala.collection.concurrent
+import scala.collection.mutable
+import scala.util.Properties
+
+import com.google.common.cache.CacheBuilder
+import org.apache.hadoop.conf.Configuration
+
+import org.apache.spark.annotation.DeveloperApi
+import org.apache.spark.api.python.PythonWorkerFactory
+import org.apache.spark.broadcast.BroadcastManager
+import org.apache.spark.executor.ExecutorBackend
+import org.apache.spark.internal.{config, Logging}
+import org.apache.spark.internal.config._
+import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager}
+import org.apache.spark.metrics.{MetricsSystem, MetricsSystemInstances}
+import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf}
+import org.apache.spark.network.shuffle.ExternalBlockStoreClient
+import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv}
+import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator}
+import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint
+import org.apache.spark.security.CryptoStreamUtils
+import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager}
+import org.apache.spark.shuffle.ShuffleManager
+import org.apache.spark.storage._
+import org.apache.spark.util.{RpcUtils, Utils}
+
+/**
+ * :: DeveloperApi ::
+ * Holds all the runtime environment objects for a running Spark instance (either master or worker),
+ * including the serializer, RpcEnv, block manager, map output tracker, etc. Currently
+ * Spark code finds the SparkEnv through a global variable, so all the threads can access the same
+ * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext).
+ */
+@DeveloperApi
+class SedonaSparkEnv (
+                 val executorId: String,
+                 private[spark] val rpcEnv: RpcEnv,
+                 val serializer: Serializer,
+                 val closureSerializer: Serializer,
+                 val serializerManager: SerializerManager,
+                 val mapOutputTracker: MapOutputTracker,
+                 val shuffleManager: ShuffleManager,
+                 val broadcastManager: BroadcastManager,
+                 val blockManager: BlockManager,
+                 val securityManager: SecurityManager,
+                 val metricsSystem: MetricsSystem,
+                 val memoryManager: MemoryManager,
+                 val outputCommitCoordinator: OutputCommitCoordinator,
+                 val conf: SparkConf) extends Logging {
+
+  @volatile private[spark] var isStopped = false
+  private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]()
+
+  // A general, soft-reference map for metadata needed during HadoopRDD split computation
+  // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats).
+  private[spark] val hadoopJobMetadata =
+    CacheBuilder.newBuilder().maximumSize(1000).softValues().build[String, AnyRef]().asMap()
+
+  private[spark] var driverTmpDir: Option[String] = None
+
+  private[spark] var executorBackend: Option[ExecutorBackend] = None
+
+  private[spark] def stop(): Unit = {
+
+    if (!isStopped) {
+      isStopped = true
+      pythonWorkers.values.foreach(_.stop())
+      mapOutputTracker.stop()
+      shuffleManager.stop()
+      broadcastManager.stop()
+      blockManager.stop()
+      blockManager.master.stop()
+      metricsSystem.stop()
+      outputCommitCoordinator.stop()
+      rpcEnv.shutdown()
+      rpcEnv.awaitTermination()
+
+      // If we only stop the SparkContext but the driver process still runs as a service, we need
+      // to delete the tmp dir; otherwise it will create too many tmp dirs.
+      // We only need to delete the tmp dir created by the driver.
+      driverTmpDir match {
+        case Some(path) =>
+          try {
+            Utils.deleteRecursively(new File(path))
+          } catch {
+            case e: Exception =>
+              logWarning(s"Exception while deleting Spark temp dir: $path", e)
+          }
+        case None => // We just need to delete tmp dir created by driver, so do nothing on executor
+      }
+    }
+  }
+
+  private[spark]
+  def createPythonWorker(
+                          pythonExec: String,
+                          envVars: Map[String, String]): (java.net.Socket, Option[Int]) = {
+    synchronized {
+      val key = (pythonExec, envVars)
+      pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create()
+    }
+  }
+
+  private[spark]
+  def destroyPythonWorker(pythonExec: String,
+                          envVars: Map[String, String], worker: Socket): Unit = {
+    synchronized {
+      val key = (pythonExec, envVars)
+      pythonWorkers.get(key).foreach(_.stopWorker(worker))
+    }
+  }
+
+  private[spark]
+  def releasePythonWorker(pythonExec: String,
+                          envVars: Map[String, String], worker: Socket): Unit = {
+    synchronized {
+      val key = (pythonExec, envVars)
+      pythonWorkers.get(key).foreach(_.releaseWorker(worker))
+    }
+  }
+}
+
+object SedonaSparkEnv extends Logging {
+  @volatile private var env: SedonaSparkEnv = _
+
+  private[spark] val driverSystemName = "sparkDriver"
+  private[spark] val executorSystemName = "sparkExecutor"
+
+  def set(e: SedonaSparkEnv): Unit = {
+    env = e
+  }
+
+  /**
+   * Returns the SparkEnv.
+   */
+  def get: SedonaSparkEnv = {
+    env
+  }
+
+  /**
+   * Create a SparkEnv for the driver.
+   */
+  private[spark] def createDriverEnv(
+                                      conf: SparkConf,
+                                      isLocal: Boolean,
+                                      listenerBus: LiveListenerBus,
+                                      numCores: Int,
+                                      sparkContext: SparkContext,
+                                      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
+    assert(conf.contains(DRIVER_HOST_ADDRESS),
+      s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!")
+    assert(conf.contains(DRIVER_PORT), s"${DRIVER_PORT.key} is not set on the driver!")
+    val bindAddress = conf.get(DRIVER_BIND_ADDRESS)
+    val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS)
+    val port = conf.get(DRIVER_PORT)
+    val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) {
+      Some(CryptoStreamUtils.createKey(conf))
+    } else {
+      None
+    }
+    create(
+      conf,
+      SparkContext.DRIVER_IDENTIFIER,
+      bindAddress,
+      advertiseAddress,
+      Option(port),
+      isLocal,
+      numCores,
+      ioEncryptionKey,
+      listenerBus = listenerBus,
+      Option(sparkContext),
+      mockOutputCommitCoordinator = mockOutputCommitCoordinator
+    )
+  }
+
+  /**
+   * Create a SparkEnv for an executor.
+   * In coarse-grained mode, the executor provides an RpcEnv that is already instantiated.
+   */
+  private[spark] def createExecutorEnv(
+                                        conf: SparkConf,
+                                        executorId: String,
+                                        bindAddress: String,
+                                        hostname: String,
+                                        numCores: Int,
+                                        ioEncryptionKey: Option[Array[Byte]],
+                                        isLocal: Boolean): SparkEnv = {
+    val env = create(
+      conf,
+      executorId,
+      bindAddress,
+      hostname,
+      None,
+      isLocal,
+      numCores,
+      ioEncryptionKey
+    )
+    SparkEnv.set(env)
+    env
+  }
+
+  private[spark] def createExecutorEnv(
+                                        conf: SparkConf,
+                                        executorId: String,
+                                        hostname: String,
+                                        numCores: Int,
+                                        ioEncryptionKey: Option[Array[Byte]],
+                                        isLocal: Boolean): SparkEnv = {
+    createExecutorEnv(conf, executorId, hostname,
+      hostname, numCores, ioEncryptionKey, isLocal)
+  }
+
+  /**
+   * Helper method to create a SparkEnv for a driver or an executor.
+   */
+  // scalastyle:off argcount
+  private def create(
+                      conf: SparkConf,
+                      executorId: String,
+                      bindAddress: String,
+                      advertiseAddress: String,
+                      port: Option[Int],
+                      isLocal: Boolean,
+                      numUsableCores: Int,
+                      ioEncryptionKey: Option[Array[Byte]],
+                      listenerBus: LiveListenerBus = null,
+                      sc: Option[SparkContext] = None,
+                      mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = {
+    // scalastyle:on argcount
+
+    val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER
+
+    // Listener bus is only used on the driver
+    if (isDriver) {
+      assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!")
+    }
+    val authSecretFileConf = if (isDriver) AUTH_SECRET_FILE_DRIVER else AUTH_SECRET_FILE_EXECUTOR
+    val securityManager = new SecurityManager(conf, ioEncryptionKey, authSecretFileConf)
+    if (isDriver) {
+      securityManager.initializeAuth()
+    }
+
+    ioEncryptionKey.foreach { _ =>
+      if (!securityManager.isEncryptionEnabled()) {
+        logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " +
+          "wire.")
+      }
+    }
+
+    val systemName = if (isDriver) driverSystemName else executorSystemName
+    val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf,
+      securityManager, numUsableCores, !isDriver)
+
+    // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied.
+    if (isDriver) {
+      conf.set(DRIVER_PORT, rpcEnv.address.port)
+    }
+
+    val serializer = Utils.instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver)
+    logDebug(s"Using serializer: ${serializer.getClass}")
+
+    val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey)
+
+    val closureSerializer = new JavaSerializer(conf)
+
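+    // On the driver this registers the endpoint with the local RpcEnv; on executors it
+    // only looks up a reference to the driver-side endpoint by name.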
+    def registerOrLookupEndpoint(
+                                  name: String, endpointCreator: => RpcEndpoint):
+    RpcEndpointRef = {
+      if (isDriver) {
+        logInfo("Registering " + name)
+        rpcEnv.setupEndpoint(name, endpointCreator)
+      } else {
+        RpcUtils.makeDriverRef(name, conf, rpcEnv)
+      }
+    }
+
+    val broadcastManager = new BroadcastManager(isDriver, conf)
+
+    val mapOutputTracker = if (isDriver) {
+      new MapOutputTrackerMaster(conf, broadcastManager, isLocal)
+    } else {
+      new MapOutputTrackerWorker(conf)
+    }
+
+    // Have to assign trackerEndpoint after initialization as MapOutputTrackerEndpoint
+    // requires the MapOutputTracker itself
+    mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME,
+      new MapOutputTrackerMasterEndpoint(
+        rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf))
+
+    // Let the user specify short names for shuffle managers
+    val shortShuffleMgrNames = Map(
+      "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName,
+      "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName)
+    val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER)
+    val shuffleMgrClass =
+      shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName)
+    val shuffleManager = Utils.instantiateSerializerOrShuffleManager[ShuffleManager](
+      shuffleMgrClass, conf, isDriver)
+
+    val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores)
+
+    val blockManagerPort = if (isDriver) {
+      conf.get(DRIVER_BLOCK_MANAGER_PORT)
+    } else {
+      conf.get(BLOCK_MANAGER_PORT)
+    }
+
+    val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) {
+      val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores)
+      Some(new ExternalBlockStoreClient(transConf, securityManager,
+        securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT)))
+    } else {
+      None
+    }
+
+    // Mapping from block manager id to the block manager's information.
+    val blockManagerInfo = new concurrent.TrieMap[BlockManagerId, BlockManagerInfo]()
+    val blockManagerMaster = new BlockManagerMaster(
+      registerOrLookupEndpoint(
+        BlockManagerMaster.DRIVER_ENDPOINT_NAME,
+        new BlockManagerMasterEndpoint(
+          rpcEnv,
+          isLocal,
+          conf,
+          listenerBus,
+          if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) {
+            externalShuffleClient
+          } else {
+            None
+          }, blockManagerInfo,
+          mapOutputTracker.asInstanceOf[MapOutputTrackerMaster],
+          shuffleManager,
+          isDriver)),
+      registerOrLookupEndpoint(
+        BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME,
+        new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)),
+      conf,
+      isDriver)
+
+    val blockTransferService =
+      new NettyBlockTransferService(conf, securityManager, serializerManager, bindAddress,
+        advertiseAddress, blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint)
+
+    // NB: blockManager is not valid until initialize() is called later.
+    val blockManager = new BlockManager(
+      executorId,
+      rpcEnv,
+      blockManagerMaster,
+      serializerManager,
+      conf,
+      memoryManager,
+      mapOutputTracker,
+      shuffleManager,
+      blockTransferService,
+      securityManager,
+      externalShuffleClient)
+
+    val metricsSystem = if (isDriver) {
+      // Don't start metrics system right now for Driver.
+      // We need to wait for the task scheduler to give us an app ID.
+      // Then we can start the metrics system.
+      MetricsSystem.createMetricsSystem(MetricsSystemInstances.DRIVER, conf)
+    } else {
+      // We need to set the executor ID before the MetricsSystem is created because sources and
+      // sinks specified in the metrics configuration file will want to incorporate this executor's
+      // ID into the metrics they report.
+      conf.set(EXECUTOR_ID, executorId)
+      val ms = MetricsSystem.createMetricsSystem(MetricsSystemInstances.EXECUTOR, conf)
+      ms.start(conf.get(METRICS_STATIC_SOURCES_ENABLED))
+      ms
+    }
+
+    val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse {
+      if (isDriver) {
+        new OutputCommitCoordinator(conf, isDriver, sc)
+      } else {
+        new OutputCommitCoordinator(conf, isDriver)
+      }
+    }
+    val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator",
+      new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator))
+    outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef)
+
+    val envInstance = new SparkEnv(
+      executorId,
+      rpcEnv,
+      serializer,
+      closureSerializer,
+      serializerManager,
+      mapOutputTracker,
+      shuffleManager,
+      broadcastManager,
+      blockManager,
+      securityManager,
+      metricsSystem,
+      memoryManager,
+      outputCommitCoordinator,
+      conf)
+
+    // Add a reference to the tmp dir created by the driver; we will delete this tmp dir when
+    // stop() is called, and we only need to do it for the driver. The driver may run as a
+    // service, and if we don't delete this tmp dir when the SparkContext is stopped, too many
+    // tmp dirs will be created.
+    if (isDriver) {
+      val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath
+      envInstance.driverTmpDir = Some(sparkFilesDir)
+    }
+
+    envInstance
+  }
+
+  /**
+   * Return a map representation of jvm information, Spark properties, system properties, and
+   * class paths. Map keys define the category, and map values represent the corresponding
+   * attributes as a sequence of KV pairs. This is used mainly for SparkListenerEnvironmentUpdate.
+   */
+  private[spark] def environmentDetails(
+                                         conf: SparkConf,
+                                         hadoopConf: Configuration,
+                                         schedulingMode: String,
+                                         addedJars: Seq[String],
+                                         addedFiles: Seq[String],
+                                         addedArchives: Seq[String],
+                                         metricsProperties: Map[String, String]): Map[String, Seq[(String, String)]] = {
+
+    import Properties._
+    val jvmInformation = Seq(
+      ("Java Version", s"$javaVersion ($javaVendor)"),
+      ("Java Home", javaHome),
+      ("Scala Version", versionString)
+    ).sorted
+
+    // Spark properties
+    // This includes the scheduling mode whether or not it is configured (used by SparkUI)
+    val schedulerMode =
+      if (!conf.contains(SCHEDULER_MODE)) {
+        Seq((SCHEDULER_MODE.key, schedulingMode))
+      } else {
+        Seq.empty[(String, String)]
+      }
+    val sparkProperties = (conf.getAll ++ schedulerMode).sorted
+
+    // System properties that are not java classpaths
+    val systemProperties = Utils.getSystemProperties.toSeq
+    val otherProperties = systemProperties.filter { case (k, _) =>
+      k != "java.class.path" && !k.startsWith("spark.")
+    }.sorted
+
+    // Class paths including all added jars and files
+    val classPathEntries = javaClassPath
+      .split(File.pathSeparator)
+      .filterNot(_.isEmpty)
+      .map((_, "System Classpath"))
+    val addedJarsAndFiles = (addedJars ++ addedFiles ++ addedArchives).map((_, "Added By User"))
+    val classPaths = (addedJarsAndFiles ++ classPathEntries).sorted
+
+    // Add Hadoop properties; these do not exclude configs also set in Spark. Some Spark
+    // conf entries starting with "spark.hadoop" may overwrite them.
+    val hadoopProperties = hadoopConf.asScala
+      .map(entry => (entry.getKey, entry.getValue)).toSeq.sorted
+    Map[String, Seq[(String, String)]](
+      "JVM Information" -> jvmInformation,
+      "Spark Properties" -> sparkProperties,
+      "Hadoop Properties" -> hadoopProperties,
+      "System Properties" -> otherProperties,
+      "Classpath Entries" -> classPaths,
+      "Metrics Properties" -> metricsProperties.toSeq.sorted)
+  }
+}
+
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
new file mode 100644
index 0000000000..026518272c
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
@@ -0,0 +1,811 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+package org.apache.spark.api.python
+
+import org.apache.spark._
+import org.apache.spark.SedonaSparkEnv
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Python._
+import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
+import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
+import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.util._
+
+import java.io._
+import java.net._
+import java.nio.charset.StandardCharsets
+import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.{Path, Files => JavaFiles}
+import java.util.concurrent.ConcurrentHashMap
+import java.util.concurrent.atomic.AtomicBoolean
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+
+
+/**
+ * Enumerate the type of command that will be sent to the Python worker
+ */
+private[spark] object PythonEvalType {
+  val NON_UDF = 0
+
+  val SQL_BATCHED_UDF = 100
+  val SQL_ARROW_BATCHED_UDF = 101
+
+  val SQL_SCALAR_PANDAS_UDF = 200
+  val SQL_GROUPED_MAP_PANDAS_UDF = 201
+  val SQL_GROUPED_AGG_PANDAS_UDF = 202
+  val SQL_WINDOW_AGG_PANDAS_UDF = 203
+  val SQL_SCALAR_PANDAS_ITER_UDF = 204
+  val SQL_MAP_PANDAS_ITER_UDF = 205
+  val SQL_COGROUPED_MAP_PANDAS_UDF = 206
+  val SQL_MAP_ARROW_ITER_UDF = 207
+  val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
+
+  val SQL_TABLE_UDF = 300
+  val SQL_ARROW_TABLE_UDF = 301
+
+  def toString(pythonEvalType: Int): String = pythonEvalType match {
+    case NON_UDF => "NON_UDF"
+    case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
+    case SQL_ARROW_BATCHED_UDF => "SQL_ARROW_BATCHED_UDF"
+    case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
+    case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
+    case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
+    case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
+    case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
+    case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
+    case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
+    case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
+    case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
+    case SQL_TABLE_UDF => "SQL_TABLE_UDF"
+    case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF"
+  }
+}
+
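+// Shared location for Python worker faulthandler dumps; each crashed worker's log is
+// written to a file named after its pid (exposed to workers via PYTHON_FAULTHANDLER_DIR).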
+private object SedonaBasePythonRunner {
+
+  private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
+
+  private def faultHandlerLogPath(pid: Int): Path = {
+    new File(faultHandlerLogDir, pid.toString).toPath
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition/UDFs in Spark.
+ *
+ * funcs is a list of independent Python functions, each one of them is a list of chained Python
+ * functions (from bottom to top).
+ */
+private[spark] abstract class SedonaBasePythonRunner[IN, OUT](
+                                                         protected val funcs: Seq[ChainedPythonFunctions],
+                                                         protected val evalType: Int,
+                                                         protected val argOffsets: Array[Array[Int]],
+                                                         protected val jobArtifactUUID: Option[String])
+  extends Logging {
+
+  require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
+
+  private val conf = SparkEnv.get.conf
+  protected val bufferSize: Int = conf.get(BUFFER_SIZE)
+  protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT)
+  private val reuseWorker = conf.get(PYTHON_WORKER_REUSE)
+  private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED)
+  protected val simplifiedTraceback: Boolean = false
+
+  // All the Python functions should have the same exec, version and envvars.
+  protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars
+  protected val pythonExec: String = funcs.head.funcs.head.pythonExec
+  protected val pythonVer: String = funcs.head.funcs.head.pythonVer
+
+  // TODO: support accumulator in multiple UDF
+  protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator
+
+  // Python accumulator is always set in production except in tests. See SPARK-27893
+  private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator)
+
+  // Expose a ServerSocket to support method calls via socket from Python side.
+  private[spark] var serverSocket: Option[ServerSocket] = None
+
+  // Authentication helper used when serving method calls via socket from Python side.
+  private lazy val authHelper = new SocketAuthHelper(conf)
+
+  // each python worker gets an equal part of the allocation. the worker pool will grow to the
+  // number of concurrent tasks, which is determined by the number of cores in this executor.
+  private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = {
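+    // e.g. 4096 MB of pyspark memory shared across 4 executor cores gives each worker 1024 MB.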
+    mem.map(_ / cores)
+  }
+
+  def compute(
+               inputIterator: Iterator[IN],
+               partitionIndex: Int,
+               context: TaskContext): Iterator[OUT] = {
+    val startTime = System.currentTimeMillis
+    val sedonaEnv = SedonaSparkEnv.get
+    val env = SparkEnv.get
+
+    // Get the executor cores and pyspark memory, they are passed via the local properties when
+    // the user specified them in a ResourceProfile.
+    val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY))
+    val memoryMb = Option(context.getLocalProperty(PYSPARK_MEMORY_LOCAL_PROPERTY)).map(_.toLong)
+    val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",")
+    // If OMP_NUM_THREADS is not explicitly set, override it with the number of task cpus.
+    // See SPARK-42613 for details.
+    if (conf.getOption("spark.executorEnv.OMP_NUM_THREADS").isEmpty) {
+      envVars.put("OMP_NUM_THREADS", conf.get("spark.task.cpus", "1"))
+    }
+    envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread
+    if (reuseWorker) {
+      envVars.put("SPARK_REUSE_WORKER", "1")
+    }
+    if (simplifiedTraceback) {
+      envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1")
+    }
+    // SPARK-30299 this could be wrong with standalone mode when executor
+    // cores might not be correct because it defaults to all cores on the box.
+    val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES))
+    val workerMemoryMb = getWorkerMemoryMb(memoryMb, execCores)
+    if (workerMemoryMb.isDefined) {
+      envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString)
+    }
+    envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString)
+    envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString)
+    if (faultHandlerEnabled) {
+      envVars.put("PYTHON_FAULTHANDLER_DIR", SedonaBasePythonRunner.faultHandlerLogDir.toString)
+    }
+
+    envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default"))
+
+    val (worker: Socket, pid: Option[Int]) = env.createPythonWorker(
+      pythonExec, envVars.asScala.toMap)
+    // Whether the worker is released into the idle pool or closed. When any code tries to release
+    // or close a worker, it should use `releasedOrClosed.compareAndSet` to flip the state to make
+    // sure there is only one winner that is going to release or close the worker.
+    val releasedOrClosed = new AtomicBoolean(false)
+
+    // Start a thread to feed the process input from our parent's iterator
+    val writerThread = newWriterThread(env, worker, inputIterator, 
partitionIndex, context)
+
+    context.addTaskCompletionListener[Unit] { _ =>
+      writerThread.shutdownOnTaskCompletion()
+      if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) {
+        try {
+          worker.close()
+        } catch {
+          case e: Exception =>
+            logWarning("Failed to close worker socket", e)
+        }
+      }
+    }
+
+    writerThread.start()
+    new WriterMonitorThread(SparkEnv.get, worker, writerThread, 
context).start()
+    if (reuseWorker) {
+      val key = (worker, context.taskAttemptId)
+      // SPARK-35009: avoid creating multiple monitor threads for the same 
python worker
+      // and task context
+      if (PythonRunner.runningMonitorThreads.add(key)) {
+        new MonitorThread(SparkEnv.get, worker, context).start()
+      }
+    } else {
+      new MonitorThread(SparkEnv.get, worker, context).start()
+    }
+
+    // Return an iterator that reads output data from the process's stdout
+    val stream = new DataInputStream(new 
BufferedInputStream(worker.getInputStream, bufferSize))
+
+    val stdoutIterator = newReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context)
+    new InterruptibleIterator(context, stdoutIterator)
+  }
+
+  protected def newWriterThread(
+                                 env: SparkEnv,
+                                 worker: Socket,
+                                 inputIterator: Iterator[IN],
+                                 partitionIndex: Int,
+                                 context: TaskContext): WriterThread
+
+  protected def newReaderIterator(
+                                   stream: DataInputStream,
+                                   writerThread: WriterThread,
+                                   startTime: Long,
+                                   env: SparkEnv,
+                                   worker: Socket,
+                                   pid: Option[Int],
+                                   releasedOrClosed: AtomicBoolean,
+                                   context: TaskContext): Iterator[OUT]
+
+  /**
+   * The thread responsible for writing the data from the PythonRDD's parent 
iterator to the
+   * Python process.
+   */
+  abstract class WriterThread(
+                               env: SparkEnv,
+                               worker: Socket,
+                               inputIterator: Iterator[IN],
+                               partitionIndex: Int,
+                               context: TaskContext)
+    extends Thread(s"stdout writer for $pythonExec") {
+
+    @volatile private var _exception: Throwable = null
+
+    private val pythonIncludes = 
funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+    private val broadcastVars = 
funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
+
+    setDaemon(true)
+
+    /** Contains the throwable thrown while writing the parent iterator to the 
Python process. */
+    def exception: Option[Throwable] = Option(_exception)
+
+    /**
+     * Terminates the writer thread and waits for it to exit, ignoring any 
exceptions that may occur
+     * due to cleanup.
+     */
+    def shutdownOnTaskCompletion(): Unit = {
+      assert(context.isCompleted)
+      this.interrupt()
+      // Task completion listeners that run after this method returns may 
invalidate
+      // `inputIterator`. For example, when `inputIterator` was generated by 
the off-heap vectorized
+      // reader, a task completion listener will free the underlying off-heap 
buffers. If the writer
+      // thread is still running when `inputIterator` is invalidated, it can 
cause a use-after-free
+      // bug that crashes the executor (SPARK-33277). Therefore this method 
must wait for the writer
+      // thread to exit before returning.
+      this.join()
+    }
+
+    /**
+     * Writes a command section to the stream connected to the Python worker.
+     */
+    protected def writeCommand(dataOut: DataOutputStream): Unit
+
+    /**
+     * Writes input data to the stream connected to the Python worker.
+     */
+    protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
+
+    override def run(): Unit = Utils.logUncaughtExceptions {
+      try {
+        TaskContext.setTaskContext(context)
+        val stream = new BufferedOutputStream(worker.getOutputStream, 
bufferSize)
+        val dataOut = new DataOutputStream(stream)
+        // Partition index
+        dataOut.writeInt(partitionIndex)
+        // Python version of driver
+        PythonRDD.writeUTF(pythonVer, dataOut)
+        // Init a ServerSocket to accept method calls from Python side.
+        val isBarrier = context.isInstanceOf[BarrierTaskContext]
+        if (isBarrier) {
+          serverSocket = Some(new ServerSocket(/* port */ 0,
+            /* backlog */ 1,
+            InetAddress.getByName("localhost")))
+          // A call to accept() for ServerSocket shall block infinitely.
+          serverSocket.foreach(_.setSoTimeout(0))
+          new Thread("accept-connections") {
+            setDaemon(true)
+
+            override def run(): Unit = {
+              while (!serverSocket.get.isClosed()) {
+                var sock: Socket = null
+                try {
+                  sock = serverSocket.get.accept()
+                  // Wait for function call from python side.
+                  sock.setSoTimeout(10000)
+                  authHelper.authClient(sock)
+                  val input = new DataInputStream(sock.getInputStream())
+                  val requestMethod = input.readInt()
+                  // The BarrierTaskContext function may wait infinitely, 
socket shall not timeout
+                  // before the function finishes.
+                  sock.setSoTimeout(0)
+                  requestMethod match {
+                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+                      barrierAndServe(requestMethod, sock)
+                    case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION 
=>
+                      val length = input.readInt()
+                      val message = new Array[Byte](length)
+                      input.readFully(message)
+                      barrierAndServe(requestMethod, sock, new String(message, 
UTF_8))
+                    case _ =>
+                      val out = new DataOutputStream(new BufferedOutputStream(
+                        sock.getOutputStream))
+                      
writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
+                  }
+                } catch {
+                  case e: SocketException if e.getMessage.contains("Socket 
closed") =>
+                  // The ServerSocket may still be open while the underlying
+                  // native socket has already been closed; catch and silently
+                  // ignore this case.
+                } finally {
+                  if (sock != null) {
+                    sock.close()
+                  }
+                }
+              }
+            }
+          }.start()
+        }
+        val secret = if (isBarrier) {
+          authHelper.secret
+        } else {
+          ""
+        }
+        // Close ServerSocket on task completion.
+        serverSocket.foreach { server =>
+          context.addTaskCompletionListener[Unit](_ => server.close())
+        }
+        val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
+        if (boundPort == -1) {
+          val message = "ServerSocket failed to bind to Java side."
+          logError(message)
+          throw new SparkException(message)
+        } else if (isBarrier) {
+          logDebug(s"Started ServerSocket on port $boundPort.")
+        }
+        // Write out the TaskContextInfo
+        dataOut.writeBoolean(isBarrier)
+        dataOut.writeInt(boundPort)
+        val secretBytes = secret.getBytes(UTF_8)
+        dataOut.writeInt(secretBytes.length)
+        dataOut.write(secretBytes, 0, secretBytes.length)
+        dataOut.writeInt(context.stageId())
+        dataOut.writeInt(context.partitionId())
+        dataOut.writeInt(context.attemptNumber())
+        dataOut.writeLong(context.taskAttemptId())
+        dataOut.writeInt(context.cpus())
+        val resources = context.resources()
+        dataOut.writeInt(resources.size)
+        resources.foreach { case (k, v) =>
+          PythonRDD.writeUTF(k, dataOut)
+          PythonRDD.writeUTF(v.name, dataOut)
+          dataOut.writeInt(v.addresses.size)
+          v.addresses.foreach { case addr =>
+            PythonRDD.writeUTF(addr, dataOut)
+          }
+        }
+        val localProps = context.getLocalProperties.asScala
+        dataOut.writeInt(localProps.size)
+        localProps.foreach { case (k, v) =>
+          PythonRDD.writeUTF(k, dataOut)
+          PythonRDD.writeUTF(v, dataOut)
+        }
+
+        // sparkFilesDir
+        val root = jobArtifactUUID.map { uuid =>
+          new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
+        }.getOrElse(SparkFiles.getRootDirectory())
+        PythonRDD.writeUTF(root, dataOut)
+        // Python includes (*.zip and *.egg files)
+        dataOut.writeInt(pythonIncludes.size)
+        for (include <- pythonIncludes) {
+          PythonRDD.writeUTF(include, dataOut)
+        }
+        // Broadcast variables
+        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+        val newBids = broadcastVars.map(_.id).toSet
+        // number of different broadcasts
+        val toRemove = oldBids.diff(newBids)
+        val addedBids = newBids.diff(oldBids)
+        val cnt = toRemove.size + addedBids.size
+        val needsDecryptionServer = env.serializerManager.encryptionEnabled && 
addedBids.nonEmpty
+        dataOut.writeBoolean(needsDecryptionServer)
+        dataOut.writeInt(cnt)
+        def sendBidsToRemove(): Unit = {
+          for (bid <- toRemove) {
+            // remove the broadcast from worker
+            dataOut.writeLong(-bid - 1) // bid >= 0
+            oldBids.remove(bid)
+          }
+        }
+        if (needsDecryptionServer) {
+          // If encryption is enabled, we set up a server that reads the
+          // encrypted files and sends the decrypted data to Python.
+          val idsAndFiles = broadcastVars.flatMap { broadcast =>
+            if (!oldBids.contains(broadcast.id)) {
+              oldBids.add(broadcast.id)
+              Some((broadcast.id, broadcast.value.path))
+            } else {
+              None
+            }
+          }
+          val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+          dataOut.writeInt(server.port)
+          logTrace(s"broadcast decryption server setup on ${server.port}")
+          PythonRDD.writeUTF(server.secret, dataOut)
+          sendBidsToRemove()
+          idsAndFiles.foreach { case (id, _) =>
+            // send new broadcast
+            dataOut.writeLong(id)
+          }
+          dataOut.flush()
+          logTrace("waiting for python to read decrypted broadcast data from 
server")
+          server.waitTillBroadcastDataSent()
+          logTrace("done sending decrypted data to python")
+        } else {
+          sendBidsToRemove()
+          for (broadcast <- broadcastVars) {
+            if (!oldBids.contains(broadcast.id)) {
+              // send new broadcast
+              dataOut.writeLong(broadcast.id)
+              PythonRDD.writeUTF(broadcast.value.path, dataOut)
+              oldBids.add(broadcast.id)
+            }
+          }
+        }
+        dataOut.flush()
+
+        dataOut.writeInt(evalType)
+        writeCommand(dataOut)
+        writeIteratorToStream(dataOut)
+
+        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+        dataOut.flush()
+      } catch {
+        case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
+          if (context.isCompleted || context.isInterrupted) {
+            logDebug("Exception/NonFatal Error thrown after task completion 
(likely due to " +
+              "cleanup)", t)
+            if (!worker.isClosed) {
+              Utils.tryLog(worker.shutdownOutput())
+            }
+          } else {
+            // We must avoid throwing exceptions/NonFatals here, because the 
thread uncaught
+            // exception handler will kill the whole executor (see
+            // org.apache.spark.executor.Executor).
+            _exception = t
+            if (!worker.isClosed) {
+              Utils.tryLog(worker.shutdownOutput())
+            }
+          }
+      }
+    }
+
+    /**
+     * Gateway to call BarrierTaskContext methods.
+     */
+    def barrierAndServe(requestMethod: Int, sock: Socket, message: String = 
""): Unit = {
+      require(
+        serverSocket.isDefined,
+        "No available ServerSocket to redirect the BarrierTaskContext method 
call."
+      )
+      val out = new DataOutputStream(new 
BufferedOutputStream(sock.getOutputStream))
+      try {
+        val messages = requestMethod match {
+          case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+            context.asInstanceOf[BarrierTaskContext].barrier()
+            Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
+          case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
+            context.asInstanceOf[BarrierTaskContext].allGather(message)
+        }
+        out.writeInt(messages.length)
+        messages.foreach(writeUTF(_, out))
+      } catch {
+        case e: SparkException =>
+          writeUTF(e.getMessage, out)
+      } finally {
+        out.close()
+      }
+    }
+
+    def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
+      val bytes = str.getBytes(UTF_8)
+      dataOut.writeInt(bytes.length)
+      dataOut.write(bytes)
+    }
+  }
+
+  abstract class ReaderIterator(
+                                 stream: DataInputStream,
+                                 writerThread: WriterThread,
+                                 startTime: Long,
+                                 env: SparkEnv,
+                                 worker: Socket,
+                                 pid: Option[Int],
+                                 releasedOrClosed: AtomicBoolean,
+                                 context: TaskContext)
+    extends Iterator[OUT] {
+
+    private var nextObj: OUT = _
+    private var eos = false
+
+    override def hasNext: Boolean = nextObj != null || {
+      if (!eos) {
+        nextObj = read()
+        hasNext
+      } else {
+        false
+      }
+    }
+
+    override def next(): OUT = {
+      if (hasNext) {
+        val obj = nextObj
+        nextObj = null.asInstanceOf[OUT]
+        obj
+      } else {
+        Iterator.empty.next()
+      }
+    }
+
+    /**
+     * Reads the next object from the stream.
+     * When the stream reaches the end of the data section, this method processes
+     * the remaining control sections and then returns null.
+     */
+    protected def read(): OUT
+
+    protected def handleTimingData(): Unit = {
+      // Timing data from worker
+      val bootTime = stream.readLong()
+      val initTime = stream.readLong()
+      val finishTime = stream.readLong()
+      val boot = bootTime - startTime
+      val init = initTime - bootTime
+      val finish = finishTime - initTime
+      val total = finishTime - startTime
+      logInfo("Times: total = %s, boot = %s, init = %s, finish = 
%s".format(total, boot,
+        init, finish))
+      val memoryBytesSpilled = stream.readLong()
+      val diskBytesSpilled = stream.readLong()
+      context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled)
+      context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled)
+    }
+
+    protected def handlePythonException(): PythonException = {
+      // Signals that an exception has been thrown in python
+      val exLength = stream.readInt()
+      val obj = new Array[Byte](exLength)
+      stream.readFully(obj)
+      new PythonException(new String(obj, StandardCharsets.UTF_8),
+        writerThread.exception.orNull)
+    }
+
+    protected def handleEndOfDataSection(): Unit = {
+      // We've finished the data section of the output, but we can still
+      // read some accumulator updates:
+      val numAccumulatorUpdates = stream.readInt()
+      (1 to numAccumulatorUpdates).foreach { _ =>
+        val updateLen = stream.readInt()
+        val update = new Array[Byte](updateLen)
+        stream.readFully(update)
+        maybeAccumulator.foreach(_.add(update))
+      }
+      // Check whether the worker is ready to be re-used.
+      if (stream.readInt() == SpecialLengths.END_OF_STREAM) {
+        if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) {
+          env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker)
+        }
+      }
+      eos = true
+    }
+
+    protected val handleException: PartialFunction[Throwable, OUT] = {
+      case e: Exception if context.isInterrupted =>
+        logDebug("Exception thrown after task interruption", e)
+        throw new 
TaskKilledException(context.getKillReason().getOrElse("unknown reason"))
+
+      case e: Exception if writerThread.exception.isDefined =>
+        logError("Python worker exited unexpectedly (crashed)", e)
+        logError("This may have been caused by a prior exception:", 
writerThread.exception.get)
+        throw writerThread.exception.get
+
+      case eof: EOFException if faultHandlerEnabled && pid.isDefined &&
+        JavaFiles.exists(SedonaBasePythonRunner.faultHandlerLogPath(pid.get)) 
=>
+        val path = SedonaBasePythonRunner.faultHandlerLogPath(pid.get)
+        val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n"
+        JavaFiles.deleteIfExists(path)
+        throw new SparkException(s"Python worker exited unexpectedly 
(crashed): $error", eof)
+
+      case eof: EOFException =>
+        throw new SparkException("Python worker exited unexpectedly 
(crashed)", eof)
+    }
+  }
+
+  /**
+   * It is necessary to have a monitor thread for python workers if the user 
cancels with
+   * interrupts disabled. In that case we will need to explicitly kill the 
worker, otherwise the
+   * threads can block indefinitely.
+   */
+  class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext)
+    extends Thread(s"Worker Monitor for $pythonExec") {
+
+    /** How long to wait before killing the python worker if a task cannot be 
interrupted. */
+    private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT)
+
+    setDaemon(true)
+
+    private def monitorWorker(): Unit = {
+      // Kill the worker if it is interrupted, checking until task completion.
+      // TODO: This has a race condition if interruption occurs, as completed 
may still become true.
+      while (!context.isInterrupted && !context.isCompleted) {
+        Thread.sleep(2000)
+      }
+      if (!context.isCompleted) {
+        Thread.sleep(taskKillTimeout)
+        if (!context.isCompleted) {
+          try {
+            // Mimic the task name used in `Executor` to help the user find 
out the task to blame.
+            val taskName = s"${context.partitionId}.${context.attemptNumber} " 
+
+              s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
+            logWarning(s"Incomplete task $taskName interrupted: Attempting to 
kill Python Worker")
+            env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
+          } catch {
+            case e: Exception =>
+              logError("Exception when trying to kill worker", e)
+          }
+        }
+      }
+    }
+
+    override def run(): Unit = {
+      try {
+        monitorWorker()
+      } finally {
+        if (reuseWorker) {
+          val key = (worker, context.taskAttemptId)
+          PythonRunner.runningMonitorThreads.remove(key)
+        }
+      }
+    }
+  }
+
+  /**
+   * This thread monitors the WriterThread and kills it in case of deadlock.
+   *
+   * A deadlock can arise if the task completes while the writer thread is 
sending input to the
+   * Python process (e.g. due to the use of `take()`), and the Python process 
is still producing
+   * output. When the inputs are sufficiently large, this can result in a 
deadlock due to the use of
+   * blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the 
socket.
+   */
+  class WriterMonitorThread(
+                             env: SparkEnv, worker: Socket, writerThread: 
WriterThread, context: TaskContext)
+    extends Thread(s"Writer Monitor for $pythonExec (writer thread id 
${writerThread.getId})") {
+
+    /**
+     * How long to wait before closing the socket if the writer thread has not 
exited after the task
+     * ends.
+     */
+    private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT)
+
+    setDaemon(true)
+
+    override def run(): Unit = {
+      // Wait until the task is completed (or the writer thread exits, in 
which case this thread has
+      // nothing to do).
+      while (!context.isCompleted && writerThread.isAlive) {
+        Thread.sleep(2000)
+      }
+      if (writerThread.isAlive) {
+        Thread.sleep(taskKillTimeout)
+        // If the writer thread continues running, this indicates a deadlock. 
Kill the worker to
+        // resolve the deadlock.
+        if (writerThread.isAlive) {
+          try {
+            // Mimic the task name used in `Executor` to help the user find 
out the task to blame.
+            val taskName = s"${context.partitionId}.${context.attemptNumber} " 
+
+              s"in stage ${context.stageId} (TID ${context.taskAttemptId})"
+            logWarning(
+              s"Detected deadlock while completing task $taskName: " +
+                "Attempting to kill Python Worker")
+            env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker)
+          } catch {
+            case e: Exception =>
+              logError("Exception when trying to kill worker", e)
+          }
+        }
+      }
+    }
+  }
+}
+
+private[spark] object PythonRunner {
+
+  // Monitor threads that are already running, keyed by (worker, task attempt ID).
+  val runningMonitorThreads = ConcurrentHashMap.newKeySet[(Socket, Long)]()
+
+  private val printPythonInfo: AtomicBoolean = new AtomicBoolean(true)
+
+  def apply(func: PythonFunction, jobArtifactUUID: Option[String]): 
PythonRunner = {
+    if (printPythonInfo.compareAndSet(true, false)) {
+      PythonUtils.logPythonInfo(func.pythonExec)
+    }
+    new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), jobArtifactUUID)
+  }
+}
+
+/**
+ * A helper class to run Python mapPartition in Spark.
+ */
+private[spark] class PythonRunner(
+                                   funcs: Seq[ChainedPythonFunctions], 
jobArtifactUUID: Option[String])
+  extends BasePythonRunner[Array[Byte], Array[Byte]](
+    funcs, PythonEvalType.NON_UDF, Array(Array(0)), jobArtifactUUID) {
+
+  protected override def newWriterThread(
+                                          env: SparkEnv,
+                                          worker: Socket,
+                                          inputIterator: Iterator[Array[Byte]],
+                                          partitionIndex: Int,
+                                          context: TaskContext): WriterThread 
= {
+    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        val command = funcs.head.funcs.head.command
+        dataOut.writeInt(command.length)
+        dataOut.write(command.toArray)
+      }
+
+      protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
+        PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+      }
+    }
+  }
+
+  protected override def newReaderIterator(
+                                            stream: DataInputStream,
+                                            writerThread: WriterThread,
+                                            startTime: Long,
+                                            env: SparkEnv,
+                                            worker: Socket,
+                                            pid: Option[Int],
+                                            releasedOrClosed: AtomicBoolean,
+                                            context: TaskContext): 
Iterator[Array[Byte]] = {
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context) {
+
+      protected override def read(): Array[Byte] = {
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
+        }
+        try {
+          stream.readInt() match {
+            case length if length > 0 =>
+              val obj = new Array[Byte](length)
+              stream.readFully(obj)
+              obj
+            case 0 => Array.emptyByteArray
+            case SpecialLengths.TIMING_DATA =>
+              handleTimingData()
+              read()
+            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+              throw handlePythonException()
+            case SpecialLengths.END_OF_DATA_SECTION =>
+              handleEndOfDataSection()
+              null
+          }
+        } catch handleException
+      }
+    }
+  }
+}
+
+private[spark] object SpecialLengths {
+  val END_OF_DATA_SECTION = -1
+  val PYTHON_EXCEPTION_THROWN = -2
+  val TIMING_DATA = -3
+  val END_OF_STREAM = -4
+  val NULL = -5
+  val START_ARROW_STREAM = -6
+  val END_OF_MICRO_BATCH = -7
+}
+
+private[spark] object BarrierTaskContextMessageProtocol {
+  val BARRIER_FUNCTION = 1
+  val ALL_GATHER_FUNCTION = 2
+  val BARRIER_RESULT_SUCCESS = "success"
+  val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python 
side."
+}
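
For reference, the protocol that `ReaderIterator.read()` decodes above is a simple
length-prefixed framing: a positive int announces a payload of that many bytes, while
the negative `SpecialLengths` values act as control markers (timing data, exceptions,
end of data, end of stream). A minimal standalone sketch of that framing, not part of
this patch (the object name and the sample payload are illustrative):

    import java.io.{ByteArrayInputStream, ByteArrayOutputStream, DataInputStream, DataOutputStream}

    object FramingSketch {
      // Sentinels mirroring SpecialLengths in the runner above.
      val END_OF_DATA_SECTION = -1
      val END_OF_STREAM = -4

      def main(args: Array[String]): Unit = {
        val bytes = new ByteArrayOutputStream()
        val out = new DataOutputStream(bytes)

        // A frame is a 4-byte length followed by that many payload bytes;
        // negative "lengths" are control markers rather than payload sizes.
        val payload = "hello".getBytes("UTF-8")
        out.writeInt(payload.length)
        out.write(payload)
        out.writeInt(END_OF_DATA_SECTION)
        out.writeInt(END_OF_STREAM)
        out.flush()

        val in = new DataInputStream(new ByteArrayInputStream(bytes.toByteArray))
        var done = false
        while (!done) {
          in.readInt() match {
            case length if length > 0 =>
              val buf = new Array[Byte](length)
              in.readFully(buf)
              println(s"payload: ${new String(buf, "UTF-8")}")
            case END_OF_DATA_SECTION => println("end of data section")
            case END_OF_STREAM => println("end of stream"); done = true
          }
        }
      }
    }
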
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
new file mode 100644
index 0000000000..27e4b851ee
--- /dev/null
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
@@ -0,0 +1,70 @@
+package org.apache.spark.sql.execution.python
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.spark.api.python._
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types._
+import org.apache.spark.sql.vectorized.ColumnarBatch
+
+/**
+ * Similar to `PythonUDFRunner`, but exchanges data with the Python worker via
+ * an Arrow stream.
+ */
+class SedonaArrowPythonRunner(
+                         funcs: Seq[ChainedPythonFunctions],
+                         evalType: Int,
+                         argOffsets: Array[Array[Int]],
+                         protected override val schema: StructType,
+                         protected override val timeZoneId: String,
+                         protected override val largeVarTypes: Boolean,
+                         protected override val workerConf: Map[String, 
String],
+                         val pythonMetrics: Map[String, SQLMetric],
+                         jobArtifactUUID: Option[String])
+  extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch](
+    funcs, evalType, argOffsets, jobArtifactUUID)
+    with SedonaBasicPythonArrowInput
+    with SedonaBasicPythonArrowOutput {
+
+  override val pythonExec: String =
+    SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
+      funcs.head.funcs.head.pythonExec)
+
+  override val errorOnDuplicatedFieldNames: Boolean = true
+
+  override val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback
+
+  override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize
+  require(
+    bufferSize >= 4,
+    "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
+      s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
+}
+
+object SedonaArrowPythonRunner {
+  /** Returns a Map of conf settings to be used in SedonaArrowPythonRunner. */
+  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
+    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> 
conf.sessionLocalTimeZone)
+    val pandasColsByName = 
Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
+      conf.pandasGroupedMapAssignColumnsByName.toString)
+    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key 
->
+      conf.arrowSafeTypeConversion.toString)
+    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
+  }
+}
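
`getPythonRunnerConfMap` snapshots the session-local time zone and the two Pandas/Arrow
flags into plain strings so they can be written to the worker as part of the command
metadata. A hedged usage sketch (assumes an active local session; the object and
variable names are illustrative):

    import org.apache.spark.sql.SparkSession
    import org.apache.spark.sql.internal.SQLConf
    import org.apache.spark.sql.execution.python.SedonaArrowPythonRunner

    object RunnerConfSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder().master("local[*]").getOrCreate()
        // Snapshot the SQL conf entries the Python worker needs to see.
        val workerConf: Map[String, String] =
          SedonaArrowPythonRunner.getPythonRunnerConfMap(SQLConf.get)
        workerConf.foreach { case (k, v) => println(s"$k -> $v") }
        spark.stop()
      }
    }
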
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
new file mode 100644
index 0000000000..d2c390282c
--- /dev/null
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
@@ -0,0 +1,148 @@
+package org.apache.spark.sql.execution.python
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.spark.api.python
+import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, 
PythonRDD, SedonaBasePythonRunner}
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.execution.arrow.ArrowWriter
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.util.Utils
+import org.apache.spark.{SparkEnv, TaskContext}
+
+import java.io.DataOutputStream
+import java.net.Socket
+
+/**
+ * A trait that can be mixed in with [[SedonaBasePythonRunner]]. It implements
+ * the logic for sending data from the JVM (an iterator of internal rows plus
+ * additional data, if required) to Python over an Arrow stream.
+ */
+private[python] trait SedonaPythonArrowInput[IN] { self: 
SedonaBasePythonRunner[IN, _] =>
+  protected val workerConf: Map[String, String]
+
+  protected val schema: StructType
+
+  protected val timeZoneId: String
+
+  protected val errorOnDuplicatedFieldNames: Boolean
+
+  protected val largeVarTypes: Boolean
+
+  protected def pythonMetrics: Map[String, SQLMetric]
+
+  protected def writeIteratorToArrowStream(
+                                            root: VectorSchemaRoot,
+                                            writer: ArrowStreamWriter,
+                                            dataOut: DataOutputStream,
+                                            inputIterator: Iterator[IN]): Unit
+
+  protected def writeUDF(
+                          dataOut: DataOutputStream,
+                          funcs: Seq[ChainedPythonFunctions],
+                          argOffsets: Array[Array[Int]]): Unit =
+    SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+
+  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
+    // Write config for the worker as a number of key -> value pairs of strings
+    stream.writeInt(workerConf.size)
+    for ((k, v) <- workerConf) {
+      PythonRDD.writeUTF(k, stream)
+      PythonRDD.writeUTF(v, stream)
+    }
+  }
+
+  protected override def newWriterThread(
+                                          env: SparkEnv,
+                                          worker: Socket,
+                                          inputIterator: Iterator[IN],
+                                          partitionIndex: Int,
+                                          context: TaskContext): WriterThread 
= {
+    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        handleMetadataBeforeExec(dataOut)
+        writeUDF(dataOut, funcs, argOffsets)
+      }
+
+      protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
+        val arrowSchema = ArrowUtils.toArrowSchema(
+          schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes)
+        val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+          s"stdout writer for $pythonExec", 0, Long.MaxValue)
+        val root = VectorSchemaRoot.create(arrowSchema, allocator)
+
+        Utils.tryWithSafeFinally {
+          val writer = new ArrowStreamWriter(root, null, dataOut)
+          writer.start()
+
+          writeIteratorToArrowStream(root, writer, dataOut, inputIterator)
+
+          // `end` writes the footer to the output stream and doesn't clean up
+          // any resources. It can throw an exception if the output stream is
+          // already closed, so it should stay inside the try block.
+          writer.end()
+        } {
+          // If we closed root and allocator in a TaskCompletionListener, there
+          // could be a race condition where the writer thread keeps writing to
+          // the VectorSchemaRoot while the listener is closing it.
+          // Closing root and allocator here is cleaner because they are owned
+          // by the writer thread and are only visible to it.
+          //
+          // If the writer thread is interrupted by the TaskCompletionListener,
+          // it is either (1) in the try block, in which case it gets an
+          // InterruptedException while performing I/O and falls through to this
+          // finally block, or (2) already in this finally block, in which case
+          // it ignores the interruption and closes the resources.
+          root.close()
+          allocator.close()
+        }
+      }
+    }
+  }
+}
+
+
+private[python] trait SedonaBasicPythonArrowInput extends 
SedonaPythonArrowInput[Iterator[InternalRow]] {
+  self: SedonaBasePythonRunner[Iterator[InternalRow], _] =>
+
+  protected def writeIteratorToArrowStream(
+                                            root: VectorSchemaRoot,
+                                            writer: ArrowStreamWriter,
+                                            dataOut: DataOutputStream,
+                                            inputIterator: 
Iterator[Iterator[InternalRow]]): Unit = {
+    val arrowWriter = ArrowWriter.create(root)
+
+    while (inputIterator.hasNext) {
+      val startData = dataOut.size()
+      val nextBatch = inputIterator.next()
+
+      while (nextBatch.hasNext) {
+        arrowWriter.write(nextBatch.next())
+      }
+
+      arrowWriter.finish()
+      writer.writeBatch()
+      arrowWriter.reset()
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
+    }
+  }
+}
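
The writer above follows the standard Arrow IPC stream pattern: fill the
`VectorSchemaRoot`, call `writer.writeBatch()`, then reset and repeat. A self-contained
sketch of that pattern against the Arrow Java API alone, independent of Spark (the
object name and sample data are illustrative):

    import java.io.ByteArrayOutputStream
    import java.util.Collections

    import org.apache.arrow.memory.RootAllocator
    import org.apache.arrow.vector.{BigIntVector, VectorSchemaRoot}
    import org.apache.arrow.vector.ipc.ArrowStreamWriter
    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}

    object ArrowWriteSketch {
      def main(args: Array[String]): Unit = {
        val allocator = new RootAllocator(Long.MaxValue)
        val field = new Field("id", FieldType.nullable(new ArrowType.Int(64, true)), null)
        val schema = new Schema(Collections.singletonList(field))
        val root = VectorSchemaRoot.create(schema, allocator)
        val out = new ByteArrayOutputStream()
        val writer = new ArrowStreamWriter(root, null, out)

        writer.start()
        val vector = root.getVector("id").asInstanceOf[BigIntVector]
        // Fill one batch, flush it, then clear the root before the next batch
        // (analogous to arrowWriter.finish()/writer.writeBatch()/arrowWriter.reset()).
        (0 until 3).foreach(i => vector.setSafe(i, i.toLong))
        root.setRowCount(3)
        writer.writeBatch()
        root.clear()
        writer.end()

        println(s"wrote ${out.size()} bytes of Arrow stream data")
        root.close()
        allocator.close()
      }
    }
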
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
new file mode 100644
index 0000000000..91e840da58
--- /dev/null
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
@@ -0,0 +1,135 @@
+package org.apache.spark.sql.execution.python
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamReader
+import org.apache.spark.api.python.{BasePythonRunner, SedonaBasePythonRunner, 
SpecialLengths}
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, 
ColumnarBatch}
+import org.apache.spark.{SparkEnv, TaskContext}
+
+import java.io.DataInputStream
+import java.net.Socket
+import java.util.concurrent.atomic.AtomicBoolean
+import scala.collection.JavaConverters._
+
+/**
+ * A trait that can be mixed in with [[SedonaBasePythonRunner]]. It implements
+ * the logic for reading data from Python (Arrow) back to the JVM (the output
+ * type is deserialized from a ColumnarBatch).
+ */
+private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: 
SedonaBasePythonRunner[_, OUT] =>
+
+  protected def pythonMetrics: Map[String, SQLMetric]
+
+  protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { }
+
+  protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: 
StructType): OUT
+
+  protected def newReaderIterator(
+                                   stream: DataInputStream,
+                                   writerThread: WriterThread,
+                                   startTime: Long,
+                                   env: SparkEnv,
+                                   worker: Socket,
+                                   pid: Option[Int],
+                                   releasedOrClosed: AtomicBoolean,
+                                   context: TaskContext): Iterator[OUT] = {
+
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context) {
+
+      private val allocator = ArrowUtils.rootAllocator.newChildAllocator(
+        s"stdin reader for $pythonExec", 0, Long.MaxValue)
+
+      private var reader: ArrowStreamReader = _
+      private var root: VectorSchemaRoot = _
+      private var schema: StructType = _
+      private var vectors: Array[ColumnVector] = _
+
+      context.addTaskCompletionListener[Unit] { _ =>
+        if (reader != null) {
+          reader.close(false)
+        }
+        allocator.close()
+      }
+
+      private var batchLoaded = true
+
+      protected override def handleEndOfDataSection(): Unit = {
+        handleMetadataAfterExec(stream)
+        super.handleEndOfDataSection()
+      }
+
+      protected override def read(): OUT = {
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
+        }
+        try {
+          if (reader != null && batchLoaded) {
+            val bytesReadStart = reader.bytesRead()
+            batchLoaded = reader.loadNextBatch()
+            if (batchLoaded) {
+              val batch = new ColumnarBatch(vectors)
+              val rowCount = root.getRowCount
+              batch.setNumRows(root.getRowCount)
+              val bytesReadEnd = reader.bytesRead()
+              pythonMetrics("pythonNumRowsReceived") += rowCount
+              pythonMetrics("pythonDataReceived") += bytesReadEnd - 
bytesReadStart
+              deserializeColumnarBatch(batch, schema)
+            } else {
+              reader.close(false)
+              allocator.close()
+              // Reach end of stream. Call `read()` again to read control data.
+              read()
+            }
+          } else {
+            stream.readInt() match {
+              case SpecialLengths.START_ARROW_STREAM =>
+                reader = new ArrowStreamReader(stream, allocator)
+                root = reader.getVectorSchemaRoot()
+                schema = ArrowUtils.fromArrowSchema(root.getSchema())
+                vectors = root.getFieldVectors().asScala.map { vector =>
+                  new ArrowColumnVector(vector)
+                }.toArray[ColumnVector]
+                read()
+              case SpecialLengths.TIMING_DATA =>
+                handleTimingData()
+                read()
+              case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+                throw handlePythonException()
+              case SpecialLengths.END_OF_DATA_SECTION =>
+                handleEndOfDataSection()
+                null.asInstanceOf[OUT]
+            }
+          }
+        } catch handleException
+      }
+    }
+  }
+}
+
+private[python] trait SedonaBasicPythonArrowOutput extends 
SedonaPythonArrowOutput[ColumnarBatch] {
+  self: SedonaBasePythonRunner[_, ColumnarBatch] =>
+
+  protected def deserializeColumnarBatch(
+                                          batch: ColumnarBatch,
+                                          schema: StructType): ColumnarBatch = 
batch
+}
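
On the read side the pattern is the inverse: load each Arrow batch into the
`VectorSchemaRoot`, wrap the field vectors in `ArrowColumnVector`, and expose them as a
`ColumnarBatch`, which is what the iterator above does per batch. A hedged round-trip
sketch (it assumes `bytes` holds an Arrow IPC stream, e.g. one produced by the previous
sketch; names are illustrative):

    import java.io.ByteArrayInputStream

    import scala.collection.JavaConverters._

    import org.apache.arrow.memory.RootAllocator
    import org.apache.arrow.vector.ipc.ArrowStreamReader
    import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}

    object ArrowReadSketch {
      def readBatches(bytes: Array[Byte]): Unit = {
        val allocator = new RootAllocator(Long.MaxValue)
        val reader = new ArrowStreamReader(new ByteArrayInputStream(bytes), allocator)
        val root = reader.getVectorSchemaRoot
        while (reader.loadNextBatch()) {
          // Wrap the Arrow vectors so Spark can read them as a columnar batch.
          val vectors = root.getFieldVectors.asScala
            .map(v => new ArrowColumnVector(v): ColumnVector).toArray
          val batch = new ColumnarBatch(vectors)
          batch.setNumRows(root.getRowCount)
          println(s"read a batch with ${batch.numRows()} rows")
        }
        reader.close()
        allocator.close()
      }
    }
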
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
new file mode 100644
index 0000000000..ced32cf801
--- /dev/null
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
@@ -0,0 +1,147 @@
+package org.apache.spark.sql.execution.python
+
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one or more
+ * contributor license agreements.  See the NOTICE file distributed with
+ * this work for additional information regarding copyright ownership.
+ * The ASF licenses this file to You under the Apache License, Version 2.0
+ * (the "License"); you may not use this file except in compliance with
+ * the License.  You may obtain a copy of the License at
+ *
+ *    http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing, software
+ * distributed under the License is distributed on an "AS IS" BASIS,
+ * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ * See the License for the specific language governing permissions and
+ * limitations under the License.
+ */
+
+import java.io._
+import java.net._
+import java.util.concurrent.atomic.AtomicBoolean
+import org.apache.spark._
+import org.apache.spark.api.python._
+import org.apache.spark.sql.execution.metric.SQLMetric
+import org.apache.spark.sql.internal.SQLConf
+
+/**
+ * A helper class to run Python UDFs in Spark.
+ */
+abstract class SedonaBasePythonUDFRunner(
+                                    funcs: Seq[ChainedPythonFunctions],
+                                    evalType: Int,
+                                    argOffsets: Array[Array[Int]],
+                                    pythonMetrics: Map[String, SQLMetric],
+                                    jobArtifactUUID: Option[String])
+  extends SedonaBasePythonRunner[Array[Byte], Array[Byte]](
+    funcs, evalType, argOffsets, jobArtifactUUID) {
+
+  override val pythonExec: String =
+    SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
+      funcs.head.funcs.head.pythonExec)
+
+  override val simplifiedTraceback: Boolean = 
SQLConf.get.pysparkSimplifiedTraceback
+
+  abstract class SedonaPythonUDFWriterThread(
+                                        env: SparkEnv,
+                                        worker: Socket,
+                                        inputIterator: Iterator[Array[Byte]],
+                                        partitionIndex: Int,
+                                        context: TaskContext)
+    extends WriterThread(env, worker, inputIterator, partitionIndex, context) {
+
+    protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
+      val startData = dataOut.size()
+
+      PythonRDD.writeIteratorToStream(inputIterator, dataOut)
+      dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
+
+      val deltaData = dataOut.size() - startData
+      pythonMetrics("pythonDataSent") += deltaData
+    }
+  }
+
+  protected override def newReaderIterator(
+                                            stream: DataInputStream,
+                                            writerThread: WriterThread,
+                                            startTime: Long,
+                                            env: SparkEnv,
+                                            worker: Socket,
+                                            pid: Option[Int],
+                                            releasedOrClosed: AtomicBoolean,
+                                            context: TaskContext): 
Iterator[Array[Byte]] = {
+    new ReaderIterator(
+      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context) {
+
+      protected override def read(): Array[Byte] = {
+        if (writerThread.exception.isDefined) {
+          throw writerThread.exception.get
+        }
+        try {
+          stream.readInt() match {
+            case length if length > 0 =>
+              val obj = new Array[Byte](length)
+              stream.readFully(obj)
+              pythonMetrics("pythonDataReceived") += length
+              obj
+            case 0 => Array.emptyByteArray
+            case SpecialLengths.TIMING_DATA =>
+              handleTimingData()
+              read()
+            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
+              throw handlePythonException()
+            case SpecialLengths.END_OF_DATA_SECTION =>
+              handleEndOfDataSection()
+              null
+          }
+        } catch handleException
+      }
+    }
+  }
+}
+
+class SedonaPythonUDFRunner(
+                       funcs: Seq[ChainedPythonFunctions],
+                       evalType: Int,
+                       argOffsets: Array[Array[Int]],
+                       pythonMetrics: Map[String, SQLMetric],
+                       jobArtifactUUID: Option[String])
+  extends SedonaBasePythonUDFRunner(funcs, evalType, argOffsets, 
pythonMetrics, jobArtifactUUID) {
+
+  protected override def newWriterThread(
+                                          env: SparkEnv,
+                                          worker: Socket,
+                                          inputIterator: Iterator[Array[Byte]],
+                                          partitionIndex: Int,
+                                          context: TaskContext): WriterThread 
= {
+    new SedonaPythonUDFWriterThread(env, worker, inputIterator, 
partitionIndex, context) {
+
+      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
+        SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
+      }
+
+    }
+  }
+}
+
+object SedonaPythonUDFRunner {
+
+  def writeUDFs(
+                 dataOut: DataOutputStream,
+                 funcs: Seq[ChainedPythonFunctions],
+                 argOffsets: Array[Array[Int]]): Unit = {
+    dataOut.writeInt(funcs.length)
+    funcs.zip(argOffsets).foreach { case (chained, offsets) =>
+      dataOut.writeInt(offsets.length)
+      offsets.foreach { offset =>
+        dataOut.writeInt(offset)
+      }
+      dataOut.writeInt(chained.funcs.length)
+      chained.funcs.foreach { f =>
+        dataOut.writeInt(f.command.length)
+        dataOut.write(f.command.toArray)
+      }
+    }
+  }
+}
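
`writeUDFs` serializes the chained functions in the same layout the upstream
`PythonUDFRunner` uses: the number of UDFs, then for each UDF its argument offsets
followed by the chain of pickled commands, each length-prefixed. The sketch below
writes that layout for a single one-argument UDF so the framing is easy to inspect
(the dummy command bytes are made up, standing in for a pickled function):

    import java.io.{ByteArrayOutputStream, DataOutputStream}

    object UdfLayoutSketch {
      def main(args: Array[String]): Unit = {
        val bytes = new ByteArrayOutputStream()
        val dataOut = new DataOutputStream(bytes)

        val command: Array[Byte] = Array[Byte](1, 2, 3) // stand-in for a pickled function
        val argOffsets: Array[Int] = Array(0)           // the UDF reads input column 0

        dataOut.writeInt(1)                   // number of UDFs
        dataOut.writeInt(argOffsets.length)   // number of argument offsets for this UDF
        argOffsets.foreach(dataOut.writeInt)  // the offsets themselves
        dataOut.writeInt(1)                   // number of chained functions
        dataOut.writeInt(command.length)      // length of the pickled command
        dataOut.write(command)                // the command bytes
        dataOut.flush()

        println(s"serialized ${bytes.size()} bytes")
      }
    }
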
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
index a403fa6b9e..5883fd905d 100644
--- 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
@@ -26,7 +26,7 @@ import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
 import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.python.{ArrowPythonRunner, 
BatchIterator, EvalPythonExec, PythonSQLMetrics}
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, 
BatchIterator, EvalPythonExec, PythonSQLMetrics, SedonaArrowPythonRunner}
 import org.apache.spark.sql.types.StructType
 
 import scala.collection.JavaConverters.asScalaIteratorConverter
@@ -68,7 +68,7 @@ case class SedonaArrowEvalPythonExec(
 
     val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else 
Iterator(iter)
 
-    val columnarBatchIter = new ArrowPythonRunner(
+    val columnarBatchIter = new SedonaArrowPythonRunner(
       funcs,
       evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT,
       argOffsets,
diff --git 
a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
 
b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
index c0a2d8f260..80a4c64106 100644
--- 
a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
+++ 
b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
@@ -19,9 +19,10 @@
 package org.apache.spark.sql.udf
 
 import org.apache.sedona.sql.UDF
-import org.apache.spark.TestUtils
+import org.apache.spark.{SparkEnv, TestUtils}
 import org.apache.spark.api.python._
 import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.internal.config.Python.{PYTHON_DAEMON_MODULE, 
PYTHON_USE_DAEMON, PYTHON_WORKER_MODULE}
 import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.util.Utils
@@ -70,10 +71,11 @@ object ScalarUDF {
     finally Utils.deleteRecursively(path)
   }
 
+  val additionalModule = 
"spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf"
+
   val pandasFunc: Array[Byte] = {
     var binaryPandasFunc: Array[Byte] = null
     withTempPath { path =>
-      println(path)
       Process(
         Seq(
           pythonExec,
@@ -85,6 +87,17 @@ object ScalarUDF {
             |from pyspark.serializers import CloudPickleSerializer
             |from sedona.utils import geometry_serde
             |from shapely import box
+            |import logging
+            |logging.basicConfig(level=logging.INFO)
+            |logger = logging.getLogger(__name__)
+            |logger.info("Loading Sedona Python UDF")
+            |import os
+            |logger.info(os.getcwd())
+            |import sys
+            |print("boring stuff")
+            |sys.path.append('$additionalModule')
+            |logger.info(sys.path)
             |f = open('$path', 'wb');
             |def w(x):
             |    def apply_function(w):
@@ -104,7 +117,9 @@ object ScalarUDF {
   }
 
   private val workerEnv = new java.util.HashMap[String, String]()
-  workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
+  workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
+  SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.worker")
+  SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false)
 
   val geoPandasScalaFunction: UserDefinedPythonFunction = 
UserDefinedPythonFunction(
     name = "geospatial_udf",

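
The test now routes UDF execution through the new `sedonaworker.worker` module by
setting `PYTHON_WORKER_MODULE` and disabling the Python daemon. Outside of the test
suite, the same routing can presumably be done through the string keys behind those
config constants when building the session; a hedged sketch (the object name is
illustrative):

    import org.apache.spark.sql.SparkSession

    object SedonaWorkerSessionSketch {
      def main(args: Array[String]): Unit = {
        val spark = SparkSession.builder()
          .master("local[*]")
          // Launch the Sedona worker module directly instead of pyspark's
          // default worker, skipping the daemon.
          .config("spark.python.use.daemon", "false")
          .config("spark.python.worker.module", "sedonaworker.worker")
          .getOrCreate()

        // ... register and invoke the geospatial UDFs here ...

        spark.stop()
      }
    }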