This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch add-geom-from-mysql-function
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit ea1219fc449d14150843bba8d93778b98aa144cf Author: pawelkocinski <[email protected]> AuthorDate: Fri Aug 15 15:33:11 2025 +0200 SEDONA-743 Add geom from mysql function. --- pom.xml | 12 +- sedonaworker/__init__.py | 0 sedonaworker/serializer.py | 100 --- sedonaworker/worker.py | 428 ------------ .../org/apache/sedona/spark/SedonaContext.scala | 2 +- .../scala/org/apache/spark/SedonaSparkEnv.scala | 515 -------------- .../spark/api/python/SedonaPythonRunner.scala | 742 --------------------- .../execution/python/SedonaArrowPythonRunner.scala | 93 --- .../sql/execution/python/SedonaArrowStrategy.scala | 256 ------- .../sql/execution/python/SedonaArrowUtils.scala | 256 ------- .../execution/python/SedonaPythonArrowInput.scala | 171 ----- .../execution/python/SedonaPythonArrowOutput.scala | 169 ----- .../execution/python/SedonaPythonUDFRunner.scala | 179 ----- .../apache/spark/sql/udf/SedonaArrowStrategy.scala | 89 +++ .../org/apache/spark/sql/udf/StrategySuite.scala | 50 +- .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 107 +-- 16 files changed, 133 insertions(+), 3036 deletions(-) diff --git a/pom.xml b/pom.xml index 5aff90376e..7865db02ed 100644 --- a/pom.xml +++ b/pom.xml @@ -18,12 +18,12 @@ --> <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd"> <modelVersion>4.0.0</modelVersion> -<!-- <parent>--> -<!-- <groupId>org.apache</groupId>--> -<!-- <artifactId>apache</artifactId>--> -<!-- <version>23</version>--> -<!-- <relativePath />--> -<!-- </parent>--> + <parent> + <groupId>org.apache</groupId> + <artifactId>apache</artifactId> + <version>23</version> + <relativePath /> + </parent> <groupId>org.apache.sedona</groupId> <artifactId>sedona-parent</artifactId> <version>1.8.0-SNAPSHOT</version> diff --git a/sedonaworker/__init__.py b/sedonaworker/__init__.py deleted file mode 100644 index e69de29bb2..0000000000 diff --git a/sedonaworker/serializer.py b/sedonaworker/serializer.py deleted file mode 100644 index 229c06ce04..0000000000 --- a/sedonaworker/serializer.py +++ /dev/null @@ -1,100 +0,0 @@ -from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer, ArrowStreamSerializer -from pyspark.errors import PySparkTypeError, PySparkValueError -import struct - -from pyspark.serializers import write_int - -def write_int(value, stream): - stream.write(struct.pack("!i", value)) - -class SpecialLengths: - START_ARROW_STREAM = -6 - - -class SedonaArrowStreamPandasUDFSerializer(ArrowStreamPandasSerializer): - def __init__(self, timezone, safecheck, assign_cols_by_name): - super(SedonaArrowStreamPandasUDFSerializer, self).__init__(timezone, safecheck) - self._assign_cols_by_name = assign_cols_by_name - - def load_stream(self, stream): - import geoarrow.pyarrow as ga - import pyarrow as pa - - batches = super(ArrowStreamPandasSerializer, self).load_stream(stream) - for batch in batches: - table = pa.Table.from_batches(batches=[batch]) - data = [] - - for c in table.itercolumns(): - meta = table.schema.field(c._name).metadata - if meta and meta[b"ARROW:extension:name"] == b'geoarrow.wkb': - data.append(ga.to_geopandas(c)) - continue - - data.append(self.arrow_to_pandas(c)) - - yield data - - def _create_batch(self, series): - import pyarrow as pa - - series = ((s, None) if not isinstance(s, (list, tuple)) else s for s in series) - - arrs = [] - for s, t in series: - arrs.append(self._create_array(s, t)) - - return 
pa.RecordBatch.from_arrays(arrs, ["_%d" % i for i in range(len(arrs))]) - - def _create_array(self, series, arrow_type): - import pyarrow as pa - import geopandas as gpd - - if hasattr(series.array, "__arrow_array__"): - mask = None - else: - mask = series.isnull() - - try: - if isinstance(series, gpd.GeoSeries): - import geoarrow.pyarrow as ga - # If the series is a GeoSeries, convert it to an Arrow array using geoarrow - return ga.array(series) - - array = pa.Array.from_pandas(series, mask=mask, type=arrow_type) - return array - except TypeError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." - ) - raise PySparkTypeError(error_msg % (series.dtype, series.name, arrow_type)) from e - except ValueError as e: - error_msg = ( - "Exception thrown when converting pandas.Series (%s) " - "with name '%s' to Arrow Array (%s)." - ) - raise PySparkValueError(error_msg % (series.dtype, series.name, arrow_type)) from e - - def dump_stream(self, iterator, stream): - """ - Override because Pandas UDFs require a START_ARROW_STREAM before the Arrow stream is sent. - This should be sent after creating the first record batch so in case of an error, it can - be sent back to the JVM before the Arrow stream starts. - """ - - def init_stream_yield_batches(): - should_write_start_length = True - for series in iterator: - batch = self._create_batch(series) - if should_write_start_length: - write_int(SpecialLengths.START_ARROW_STREAM, stream) - should_write_start_length = False - yield batch - - return ArrowStreamSerializer.dump_stream(self, init_stream_yield_batches(), stream) - - def __repr__(self): - return "ArrowStreamPandasUDFSerializer" - - diff --git a/sedonaworker/worker.py b/sedonaworker/worker.py deleted file mode 100644 index 98ee38242b..0000000000 --- a/sedonaworker/worker.py +++ /dev/null @@ -1,428 +0,0 @@ -# -# Licensed to the Apache Software Foundation (ASF) under one or more -# contributor license agreements. See the NOTICE file distributed with -# this work for additional information regarding copyright ownership. -# The ASF licenses this file to You under the Apache License, Version 2.0 -# (the "License"); you may not use this file except in compliance with -# the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# - -""" -Worker that receives input from Piped RDD. 
-""" -import os -import sys -import time -from inspect import currentframe, getframeinfo -import importlib - -from sedonaworker.serializer import SedonaArrowStreamPandasUDFSerializer - -has_resource_module = True -try: - import resource -except ImportError: - has_resource_module = False -import traceback -import warnings -import faulthandler - -from pyspark.accumulators import _accumulatorRegistry -from pyspark.broadcast import Broadcast, _broadcastRegistry -from pyspark.java_gateway import local_connect_and_auth -from pyspark.taskcontext import BarrierTaskContext, TaskContext -from pyspark.files import SparkFiles -from pyspark.resource import ResourceInformation -from pyspark.rdd import PythonEvalType -from pyspark.serializers import ( - write_with_length, - write_int, - read_long, - read_bool, - write_long, - read_int, - SpecialLengths, - UTF8Deserializer, - CPickleSerializer, -) - -from pyspark.sql.pandas.types import to_arrow_type -from pyspark.sql.types import StructType -from pyspark.util import fail_on_stopiteration, try_simplify_traceback -from pyspark import shuffle -from pyspark.errors import PySparkRuntimeError, PySparkTypeError - -pickleSer = CPickleSerializer() -utf8_deserializer = UTF8Deserializer() - - -def report_times(outfile, boot, init, finish): - write_int(SpecialLengths.TIMING_DATA, outfile) - write_long(int(1000 * boot), outfile) - write_long(int(1000 * init), outfile) - write_long(int(1000 * finish), outfile) - - -def add_path(path): - # worker can be used, so do not add path multiple times - if path not in sys.path: - # overwrite system packages - sys.path.insert(1, path) - - -def read_command(serializer, file): - command = serializer._read_with_length(file) - if isinstance(command, Broadcast): - command = serializer.loads(command.value) - return command - - -def chain(f, g): - """chain two functions together""" - return lambda *a: g(f(*a)) - -def wrap_scalar_pandas_udf(f, return_type): - arrow_return_type = to_arrow_type(return_type) - - def verify_result_type(result): - if not hasattr(result, "__len__"): - pd_type = "pandas.DataFrame" if type(return_type) == StructType else "pandas.Series" - raise PySparkTypeError( - error_class="UDF_RETURN_TYPE", - message_parameters={ - "expected": pd_type, - "actual": type(result).__name__, - }, - ) - return result - - def verify_result_length(result, length): - if len(result) != length: - raise PySparkRuntimeError( - error_class="SCHEMA_MISMATCH_FOR_PANDAS_UDF", - message_parameters={ - "expected": str(length), - "actual": str(len(result)), - }, - ) - return result - - return lambda *a: ( - verify_result_length(verify_result_type(f(*a)), len(a[0])), - arrow_return_type, - ) - - -def read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index): - num_arg = read_int(infile) - arg_offsets = [read_int(infile) for i in range(num_arg)] - chained_func = None - for i in range(read_int(infile)): - f, return_type = read_command(pickleSer, infile) - if chained_func is None: - chained_func = f - else: - chained_func = chain(chained_func, f) - - func = fail_on_stopiteration(chained_func) - - # the last returnType will be the return type of UDF - if eval_type == PythonEvalType.SQL_SCALAR_PANDAS_UDF: - return arg_offsets, wrap_scalar_pandas_udf(func, return_type), return_type - else: - raise ValueError("Unknown eval type: {}".format(eval_type)) - - -# Used by SQL_GROUPED_MAP_PANDAS_UDF and SQL_SCALAR_PANDAS_UDF and SQL_ARROW_BATCHED_UDF when -# returning StructType -def assign_cols_by_name(runner_conf): - return ( - runner_conf.get( 
- "spark.sql.legacy.execution.pandas.groupedMap.assignColumnsByName", "true" - ).lower() - == "true" - ) - - -def read_udfs(pickleSer, infile, eval_type): - from sedona.sql.types import GeometryType - import geopandas as gpd - runner_conf = {} - - # Load conf used for pandas_udf evaluation - num_conf = read_int(infile) - for i in range(num_conf): - k = utf8_deserializer.loads(infile) - v = utf8_deserializer.loads(infile) - runner_conf[k] = v - - timezone = runner_conf.get("spark.sql.session.timeZone", None) - safecheck = ( - runner_conf.get("spark.sql.execution.pandas.convertToArrowArraySafely", "false").lower() - == "true" - ) - - num_udfs = read_int(infile) - - udfs = [] - for i in range(num_udfs): - udfs.append(read_single_udf(pickleSer, infile, eval_type, runner_conf, udf_index=i)) - - def mapper(a): - results = [] - - for (arg_offsets, f, return_type) in udfs: - result = f(*[a[o] for o in arg_offsets]) - if isinstance(return_type, GeometryType): - results.append(( - gpd.GeoSeries(result[0]), - result[1], - )) - - continue - - results.append(result) - - if len(results) == 1: - return results[0] - else: - return results - - def func(_, it): - return map(mapper, it) - - ser = SedonaArrowStreamPandasUDFSerializer( - timezone, - safecheck, - assign_cols_by_name(runner_conf), - ) - - # profiling is not supported for UDF - return func, None, ser, ser - - -def main(infile, outfile): - faulthandler_log_path = os.environ.get("PYTHON_FAULTHANDLER_DIR", None) - try: - if faulthandler_log_path: - faulthandler_log_path = os.path.join(faulthandler_log_path, str(os.getpid())) - faulthandler_log_file = open(faulthandler_log_path, "w") - faulthandler.enable(file=faulthandler_log_file) - - boot_time = time.time() - split_index = read_int(infile) - if split_index == -1: # for unit tests - sys.exit(-1) - - version = utf8_deserializer.loads(infile) - - if version != "%d.%d" % sys.version_info[:2]: - raise PySparkRuntimeError( - error_class="PYTHON_VERSION_MISMATCH", - message_parameters={ - "worker_version": str(sys.version_info[:2]), - "driver_version": str(version), - }, - ) - - # read inputs only for a barrier task - isBarrier = read_bool(infile) - boundPort = read_int(infile) - secret = UTF8Deserializer().loads(infile) - - # set up memory limits - memory_limit_mb = int(os.environ.get("PYSPARK_EXECUTOR_MEMORY_MB", "-1")) - if memory_limit_mb > 0 and has_resource_module: - total_memory = resource.RLIMIT_AS - try: - (soft_limit, hard_limit) = resource.getrlimit(total_memory) - msg = "Current mem limits: {0} of max {1}\n".format(soft_limit, hard_limit) - print(msg, file=sys.stderr) - - # convert to bytes - new_limit = memory_limit_mb * 1024 * 1024 - - if soft_limit == resource.RLIM_INFINITY or new_limit < soft_limit: - msg = "Setting mem limits to {0} of max {1}\n".format(new_limit, new_limit) - print(msg, file=sys.stderr) - resource.setrlimit(total_memory, (new_limit, new_limit)) - - except (resource.error, OSError, ValueError) as e: - # not all systems support resource limits, so warn instead of failing - lineno = ( - getframeinfo(currentframe()).lineno + 1 if currentframe() is not None else 0 - ) - if "__file__" in globals(): - print( - warnings.formatwarning( - "Failed to set memory limit: {0}".format(e), - ResourceWarning, - __file__, - lineno, - ), - file=sys.stderr, - ) - - # initialize global state - taskContext = None - if isBarrier: - taskContext = BarrierTaskContext._getOrCreate() - BarrierTaskContext._initialize(boundPort, secret) - # Set the task context instance here, so we can get it by 
TaskContext.get for - # both TaskContext and BarrierTaskContext - TaskContext._setTaskContext(taskContext) - else: - taskContext = TaskContext._getOrCreate() - # read inputs for TaskContext info - taskContext._stageId = read_int(infile) - taskContext._partitionId = read_int(infile) - taskContext._attemptNumber = read_int(infile) - taskContext._taskAttemptId = read_long(infile) - taskContext._cpus = read_int(infile) - taskContext._resources = {} - for r in range(read_int(infile)): - key = utf8_deserializer.loads(infile) - name = utf8_deserializer.loads(infile) - addresses = [] - taskContext._resources = {} - for a in range(read_int(infile)): - addresses.append(utf8_deserializer.loads(infile)) - taskContext._resources[key] = ResourceInformation(name, addresses) - - taskContext._localProperties = dict() - for i in range(read_int(infile)): - k = utf8_deserializer.loads(infile) - v = utf8_deserializer.loads(infile) - taskContext._localProperties[k] = v - - shuffle.MemoryBytesSpilled = 0 - shuffle.DiskBytesSpilled = 0 - _accumulatorRegistry.clear() - - # fetch name of workdir - spark_files_dir = utf8_deserializer.loads(infile) - SparkFiles._root_directory = spark_files_dir - SparkFiles._is_running_on_worker = True - - # fetch names of includes (*.zip and *.egg files) and construct PYTHONPATH - add_path(spark_files_dir) # *.py files that were added will be copied here - num_python_includes = read_int(infile) - for _ in range(num_python_includes): - filename = utf8_deserializer.loads(infile) - add_path(os.path.join(spark_files_dir, filename)) - - importlib.invalidate_caches() - - # fetch names and values of broadcast variables - needs_broadcast_decryption_server = read_bool(infile) - num_broadcast_variables = read_int(infile) - if needs_broadcast_decryption_server: - # read the decrypted data from a server in the jvm - port = read_int(infile) - auth_secret = utf8_deserializer.loads(infile) - (broadcast_sock_file, _) = local_connect_and_auth(port, auth_secret) - - for _ in range(num_broadcast_variables): - bid = read_long(infile) - if bid >= 0: - if needs_broadcast_decryption_server: - read_bid = read_long(broadcast_sock_file) - assert read_bid == bid - _broadcastRegistry[bid] = Broadcast(sock_file=broadcast_sock_file) - else: - path = utf8_deserializer.loads(infile) - _broadcastRegistry[bid] = Broadcast(path=path) - - else: - bid = -bid - 1 - _broadcastRegistry.pop(bid) - - if needs_broadcast_decryption_server: - broadcast_sock_file.write(b"1") - broadcast_sock_file.close() - - _accumulatorRegistry.clear() - eval_type = read_int(infile) - func, profiler, deserializer, serializer = read_udfs(pickleSer, infile, eval_type) - init_time = time.time() - - def process(): - iterator = deserializer.load_stream(infile) - out_iter = func(split_index, iterator) - try: - serializer.dump_stream(out_iter, outfile) - finally: - if hasattr(out_iter, "close"): - out_iter.close() - - if profiler: - profiler.profile(process) - else: - process() - - # Reset task context to None. This is a guard code to avoid residual context when worker - # reuse. 
- TaskContext._setTaskContext(None) - BarrierTaskContext._setTaskContext(None) - except BaseException as e: - try: - exc_info = None - if os.environ.get("SPARK_SIMPLIFIED_TRACEBACK", False): - tb = try_simplify_traceback(sys.exc_info()[-1]) - if tb is not None: - e.__cause__ = None - exc_info = "".join(traceback.format_exception(type(e), e, tb)) - if exc_info is None: - exc_info = traceback.format_exc() - - write_int(SpecialLengths.PYTHON_EXCEPTION_THROWN, outfile) - write_with_length(exc_info.encode("utf-8"), outfile) - except IOError: - # JVM close the socket - pass - except BaseException: - # Write the error to stderr if it happened while serializing - print("PySpark worker failed with exception:", file=sys.stderr) - print(traceback.format_exc(), file=sys.stderr) - sys.exit(-1) - finally: - if faulthandler_log_path: - faulthandler.disable() - faulthandler_log_file.close() - os.remove(faulthandler_log_path) - finish_time = time.time() - report_times(outfile, boot_time, init_time, finish_time) - write_long(shuffle.MemoryBytesSpilled, outfile) - write_long(shuffle.DiskBytesSpilled, outfile) - - # Mark the beginning of the accumulators section of the output - write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - write_int(len(_accumulatorRegistry), outfile) - for (aid, accum) in _accumulatorRegistry.items(): - pickleSer._write_with_length((aid, accum._value), outfile) - - # check end of stream - if read_int(infile) == SpecialLengths.END_OF_STREAM: - write_int(SpecialLengths.END_OF_STREAM, outfile) - else: - # write a different value to tell JVM to not reuse this worker - write_int(SpecialLengths.END_OF_DATA_SECTION, outfile) - sys.exit(-1) - - -if __name__ == "__main__": - # Read information about how to connect back to the JVM from the environment. - java_port = int(os.environ["PYTHON_WORKER_FACTORY_PORT"]) - auth_secret = os.environ["PYTHON_WORKER_FACTORY_SECRET"] - (sock_file, _) = local_connect_and_auth(java_port, auth_secret) - write_int(os.getpid(), sock_file) - sock_file.flush() - main(sock_file, sock_file) diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala index 233d759bc1..9619837691 100644 --- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala +++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala @@ -70,7 +70,7 @@ object SedonaContext { val sedonaArrowStrategy = Try( Class - .forName("org.apache.spark.sql.execution.python.SedonaArrowStrategy") + .forName("org.apache.spark.sql.udf.SedonaArrowStrategy") .getDeclaredConstructor() .newInstance() .asInstanceOf[SparkStrategy]) diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala deleted file mode 100644 index b89fe93890..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala +++ /dev/null @@ -1,515 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark - -import java.io.File -import java.net.Socket -import java.util.Locale - -import scala.collection.JavaConverters._ -import scala.collection.concurrent -import scala.collection.mutable -import scala.util.Properties - -import com.google.common.cache.CacheBuilder -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.PythonWorkerFactory -import org.apache.spark.broadcast.BroadcastManager -import org.apache.spark.executor.ExecutorBackend -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.internal.config._ -import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager} -import org.apache.spark.metrics.{MetricsSystem, MetricsSystemInstances} -import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} -import org.apache.spark.network.shuffle.ExternalBlockStoreClient -import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} -import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint -import org.apache.spark.security.CryptoStreamUtils -import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager -import org.apache.spark.storage._ -import org.apache.spark.util.{RpcUtils, Utils} - -/** - * :: DeveloperApi :: Holds all the runtime environment objects for a running Spark instance - * (either master or worker), including the serializer, RpcEnv, block manager, map output tracker, - * etc. Currently Spark code finds the SparkEnv through a global variable, so all the threads can - * access the same SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a - * SparkContext). - */ -@DeveloperApi -class SedonaSparkEnv( - val executorId: String, - private[spark] val rpcEnv: RpcEnv, - val serializer: Serializer, - val closureSerializer: Serializer, - val serializerManager: SerializerManager, - val mapOutputTracker: MapOutputTracker, - val shuffleManager: ShuffleManager, - val broadcastManager: BroadcastManager, - val blockManager: BlockManager, - val securityManager: SecurityManager, - val metricsSystem: MetricsSystem, - val memoryManager: MemoryManager, - val outputCommitCoordinator: OutputCommitCoordinator, - val conf: SparkConf) - extends Logging { - - @volatile private[spark] var isStopped = false - private val pythonWorkers = - mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() - - // A general, soft-reference map for metadata needed during HadoopRDD split computation - // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). 
- private[spark] val hadoopJobMetadata = - CacheBuilder.newBuilder().maximumSize(1000).softValues().build[String, AnyRef]().asMap() - - private[spark] var driverTmpDir: Option[String] = None - - private[spark] var executorBackend: Option[ExecutorBackend] = None - - private[spark] def stop(): Unit = { - - if (!isStopped) { - isStopped = true - pythonWorkers.values.foreach(_.stop()) - mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - outputCommitCoordinator.stop() - rpcEnv.shutdown() - rpcEnv.awaitTermination() - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs. - // We only need to delete the tmp dir create by driver - driverTmpDir match { - case Some(path) => - try { - Utils.deleteRecursively(new File(path)) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: $path", e) - } - case None => // We just need to delete tmp dir created by driver, so do nothing on executor - } - } - } - - private[spark] def createPythonWorker( - pythonExec: String, - envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { - synchronized { - val key = (pythonExec, envVars) - pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() - } - } - - private[spark] def destroyPythonWorker( - pythonExec: String, - envVars: Map[String, String], - worker: Socket): Unit = { - synchronized { - val key = (pythonExec, envVars) - pythonWorkers.get(key).foreach(_.stopWorker(worker)) - } - } - - private[spark] def releasePythonWorker( - pythonExec: String, - envVars: Map[String, String], - worker: Socket): Unit = { - synchronized { - val key = (pythonExec, envVars) - pythonWorkers.get(key).foreach(_.releaseWorker(worker)) - } - } -} - -object SedonaSparkEnv extends Logging { - @volatile private var env: SedonaSparkEnv = _ - - private[spark] val driverSystemName = "sparkDriver" - private[spark] val executorSystemName = "sparkExecutor" - - def set(e: SedonaSparkEnv): Unit = { - env = e - } - - /** - * Returns the SparkEnv. - */ - def get: SedonaSparkEnv = { - env - } - - /** - * Create a SparkEnv for the driver. - */ - private[spark] def createDriverEnv( - conf: SparkConf, - isLocal: Boolean, - listenerBus: LiveListenerBus, - numCores: Int, - sparkContext: SparkContext, - mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { - assert( - conf.contains(DRIVER_HOST_ADDRESS), - s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!") - assert(conf.contains(DRIVER_PORT), s"${DRIVER_PORT.key} is not set on the driver!") - val bindAddress = conf.get(DRIVER_BIND_ADDRESS) - val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) - val port = conf.get(DRIVER_PORT) - val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { - Some(CryptoStreamUtils.createKey(conf)) - } else { - None - } - create( - conf, - SparkContext.DRIVER_IDENTIFIER, - bindAddress, - advertiseAddress, - Option(port), - isLocal, - numCores, - ioEncryptionKey, - listenerBus = listenerBus, - Option(sparkContext), - mockOutputCommitCoordinator = mockOutputCommitCoordinator) - } - - /** - * Create a SparkEnv for an executor. In coarse-grained mode, the executor provides an RpcEnv - * that is already instantiated. 
- */ - private[spark] def createExecutorEnv( - conf: SparkConf, - executorId: String, - bindAddress: String, - hostname: String, - numCores: Int, - ioEncryptionKey: Option[Array[Byte]], - isLocal: Boolean): SparkEnv = { - val env = - create(conf, executorId, bindAddress, hostname, None, isLocal, numCores, ioEncryptionKey) - SparkEnv.set(env) - env - } - - private[spark] def createExecutorEnv( - conf: SparkConf, - executorId: String, - hostname: String, - numCores: Int, - ioEncryptionKey: Option[Array[Byte]], - isLocal: Boolean): SparkEnv = { - createExecutorEnv(conf, executorId, hostname, hostname, numCores, ioEncryptionKey, isLocal) - } - - /** - * Helper method to create a SparkEnv for a driver or an executor. - */ - // scalastyle:off argcount - private def create( - conf: SparkConf, - executorId: String, - bindAddress: String, - advertiseAddress: String, - port: Option[Int], - isLocal: Boolean, - numUsableCores: Int, - ioEncryptionKey: Option[Array[Byte]], - listenerBus: LiveListenerBus = null, - sc: Option[SparkContext] = None, - mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { - // scalastyle:on argcount - - val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER - - // Listener bus is only used on the driver - if (isDriver) { - assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") - } - val authSecretFileConf = if (isDriver) AUTH_SECRET_FILE_DRIVER else AUTH_SECRET_FILE_EXECUTOR - val securityManager = new SecurityManager(conf, ioEncryptionKey, authSecretFileConf) - if (isDriver) { - securityManager.initializeAuth() - } - - ioEncryptionKey.foreach { _ => - if (!securityManager.isEncryptionEnabled()) { - logWarning( - "I/O encryption enabled without RPC encryption: keys will be visible on the " + - "wire.") - } - } - - val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create( - systemName, - bindAddress, - advertiseAddress, - port.getOrElse(-1), - conf, - securityManager, - numUsableCores, - !isDriver) - - // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. 
- if (isDriver) { - conf.set(DRIVER_PORT, rpcEnv.address.port) - } - - val serializer = Utils.instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver) - logDebug(s"Using serializer: ${serializer.getClass}") - - val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey) - - val closureSerializer = new JavaSerializer(conf) - - def registerOrLookupEndpoint( - name: String, - endpointCreator: => RpcEndpoint): RpcEndpointRef = { - if (isDriver) { - logInfo("Registering " + name) - rpcEnv.setupEndpoint(name, endpointCreator) - } else { - RpcUtils.makeDriverRef(name, conf, rpcEnv) - } - } - - val broadcastManager = new BroadcastManager(isDriver, conf) - - val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf, broadcastManager, isLocal) - } else { - new MapOutputTrackerWorker(conf) - } - - // Have to assign trackerEndpoint after initialization as MapOutputTrackerEndpoint - // requires the MapOutputTracker itself - mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint( - MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint( - rpcEnv, - mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], - conf)) - - // Let the user specify short names for shuffle managers - val shortShuffleMgrNames = Map( - "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, - "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) - val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER) - val shuffleMgrClass = - shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) - val shuffleManager = - Utils.instantiateSerializerOrShuffleManager[ShuffleManager](shuffleMgrClass, conf, isDriver) - - val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores) - - val blockManagerPort = if (isDriver) { - conf.get(DRIVER_BLOCK_MANAGER_PORT) - } else { - conf.get(BLOCK_MANAGER_PORT) - } - - val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - Some( - new ExternalBlockStoreClient( - transConf, - securityManager, - securityManager.isAuthenticationEnabled(), - conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) - } else { - None - } - - // Mapping from block manager id to the block manager's information. - val blockManagerInfo = new concurrent.TrieMap[BlockManagerId, BlockManagerInfo]() - val blockManagerMaster = new BlockManagerMaster( - registerOrLookupEndpoint( - BlockManagerMaster.DRIVER_ENDPOINT_NAME, - new BlockManagerMasterEndpoint( - rpcEnv, - isLocal, - conf, - listenerBus, - if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { - externalShuffleClient - } else { - None - }, - blockManagerInfo, - mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], - shuffleManager, - isDriver)), - registerOrLookupEndpoint( - BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME, - new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)), - conf, - isDriver) - - val blockTransferService = - new NettyBlockTransferService( - conf, - securityManager, - serializerManager, - bindAddress, - advertiseAddress, - blockManagerPort, - numUsableCores, - blockManagerMaster.driverEndpoint) - - // NB: blockManager is not valid until initialize() is called later. 
- val blockManager = new BlockManager( - executorId, - rpcEnv, - blockManagerMaster, - serializerManager, - conf, - memoryManager, - mapOutputTracker, - shuffleManager, - blockTransferService, - securityManager, - externalShuffleClient) - - val metricsSystem = if (isDriver) { - // Don't start metrics system right now for Driver. - // We need to wait for the task scheduler to give us an app ID. - // Then we can start the metrics system. - MetricsSystem.createMetricsSystem(MetricsSystemInstances.DRIVER, conf) - } else { - // We need to set the executor ID before the MetricsSystem is created because sources and - // sinks specified in the metrics configuration file will want to incorporate this executor's - // ID into the metrics they report. - conf.set(EXECUTOR_ID, executorId) - val ms = MetricsSystem.createMetricsSystem(MetricsSystemInstances.EXECUTOR, conf) - ms.start(conf.get(METRICS_STATIC_SOURCES_ENABLED)) - ms - } - - val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { - if (isDriver) { - new OutputCommitCoordinator(conf, isDriver, sc) - } else { - new OutputCommitCoordinator(conf, isDriver) - } - - } - val outputCommitCoordinatorRef = registerOrLookupEndpoint( - "OutputCommitCoordinator", - new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) - outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) - - val envInstance = new SparkEnv( - executorId, - rpcEnv, - serializer, - closureSerializer, - serializerManager, - mapOutputTracker, - shuffleManager, - broadcastManager, - blockManager, - securityManager, - metricsSystem, - memoryManager, - outputCommitCoordinator, - conf) - - // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is - // called, and we only need to do it for driver. Because driver may run as a service, and if we - // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. - if (isDriver) { - val sparkFilesDir = - Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath - envInstance.driverTmpDir = Some(sparkFilesDir) - } - - envInstance - } - - /** - * Return a map representation of jvm information, Spark properties, system properties, and - * class paths. Map keys define the category, and map values represent the corresponding - * attributes as a sequence of KV pairs. This is used mainly for SparkListenerEnvironmentUpdate. 
- */ - private[spark] def environmentDetails( - conf: SparkConf, - hadoopConf: Configuration, - schedulingMode: String, - addedJars: Seq[String], - addedFiles: Seq[String], - addedArchives: Seq[String], - metricsProperties: Map[String, String]): Map[String, Seq[(String, String)]] = { - - import Properties._ - val jvmInformation = Seq( - ("Java Version", s"$javaVersion ($javaVendor)"), - ("Java Home", javaHome), - ("Scala Version", versionString)).sorted - - // Spark properties - // This includes the scheduling mode whether or not it is configured (used by SparkUI) - val schedulerMode = - if (!conf.contains(SCHEDULER_MODE)) { - Seq((SCHEDULER_MODE.key, schedulingMode)) - } else { - Seq.empty[(String, String)] - } - val sparkProperties = (conf.getAll ++ schedulerMode).sorted - - // System properties that are not java classpaths - val systemProperties = Utils.getSystemProperties.toSeq - val otherProperties = systemProperties.filter { case (k, _) => - k != "java.class.path" && !k.startsWith("spark.") - }.sorted - - // Class paths including all added jars and files - val classPathEntries = javaClassPath - .split(File.pathSeparator) - .filterNot(_.isEmpty) - .map((_, "System Classpath")) - val addedJarsAndFiles = (addedJars ++ addedFiles ++ addedArchives).map((_, "Added By User")) - val classPaths = (addedJarsAndFiles ++ classPathEntries).sorted - - // Add Hadoop properties, it will not ignore configs including in Spark. Some spark - // conf starting with "spark.hadoop" may overwrite it. - val hadoopProperties = hadoopConf.asScala - .map(entry => (entry.getKey, entry.getValue)) - .toSeq - .sorted - Map[String, Seq[(String, String)]]( - "JVM Information" -> jvmInformation, - "Spark Properties" -> sparkProperties, - "Hadoop Properties" -> hadoopProperties, - "System Properties" -> otherProperties, - "Classpath Entries" -> classPaths, - "Metrics Properties" -> metricsProperties.toSeq.sorted) - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala deleted file mode 100644 index c510d0cd93..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala +++ /dev/null @@ -1,742 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.api.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import org.apache.spark._ -import org.apache.spark.SedonaSparkEnv -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.Python._ -import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} -import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} -import org.apache.spark.security.SocketAuthHelper -import org.apache.spark.util._ - -import java.io._ -import java.net._ -import java.nio.charset.StandardCharsets -import java.nio.charset.StandardCharsets.UTF_8 -import java.nio.file.{Path, Files => JavaFiles} -import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.JavaConverters._ -import scala.util.control.NonFatal - -/** - * Enumerate the type of command that will be sent to the Python worker - */ -private[spark] object PythonEvalType { - val NON_UDF = 0 - - val SQL_BATCHED_UDF = 100 - val SQL_ARROW_BATCHED_UDF = 101 - - val SQL_SCALAR_PANDAS_UDF = 200 - val SQL_GROUPED_MAP_PANDAS_UDF = 201 - val SQL_GROUPED_AGG_PANDAS_UDF = 202 - val SQL_WINDOW_AGG_PANDAS_UDF = 203 - val SQL_SCALAR_PANDAS_ITER_UDF = 204 - val SQL_MAP_PANDAS_ITER_UDF = 205 - val SQL_COGROUPED_MAP_PANDAS_UDF = 206 - val SQL_MAP_ARROW_ITER_UDF = 207 - val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 - - val SQL_TABLE_UDF = 300 - val SQL_ARROW_TABLE_UDF = 301 - - def toString(pythonEvalType: Int): String = pythonEvalType match { - case NON_UDF => "NON_UDF" - case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" - case SQL_ARROW_BATCHED_UDF => "SQL_ARROW_BATCHED_UDF" - case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" - case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" - case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" - case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" - case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF" - case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" - case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" - case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" - case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" - case SQL_TABLE_UDF => "SQL_TABLE_UDF" - case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF" - } -} - -private object SedonaBasePythonRunner { - - private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler") - - private def faultHandlerLogPath(pid: Int): Path = { - new File(faultHandlerLogDir, pid.toString).toPath - } -} - -/** - * A helper class to run Python mapPartition/UDFs in Spark. - * - * funcs is a list of independent Python functions, each one of them is a list of chained Python - * functions (from bottom to top). 
- */ -private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( - protected val funcs: Seq[ChainedPythonFunctions], - protected val evalType: Int, - protected val argOffsets: Array[Array[Int]], - protected val jobArtifactUUID: Option[String]) - extends Logging { - - require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") - - private val conf = SparkEnv.get.conf - protected val bufferSize: Int = conf.get(BUFFER_SIZE) - protected val authSocketTimeout = conf.get(PYTHON_AUTH_SOCKET_TIMEOUT) - private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) - private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) - protected val simplifiedTraceback: Boolean = false - - // All the Python functions should have the same exec, version and envvars. - protected val envVars: java.util.Map[String, String] = funcs.head.funcs.head.envVars - protected val pythonExec: String = funcs.head.funcs.head.pythonExec - protected val pythonVer: String = funcs.head.funcs.head.pythonVer - - // TODO: support accumulator in multiple UDF - protected val accumulator: PythonAccumulatorV2 = funcs.head.funcs.head.accumulator - - // Python accumulator is always set in production except in tests. See SPARK-27893 - private val maybeAccumulator: Option[PythonAccumulatorV2] = Option(accumulator) - - // Expose a ServerSocket to support method calls via socket from Python side. - private[spark] var serverSocket: Option[ServerSocket] = None - - // Authentication helper used when serving method calls via socket from Python side. - private lazy val authHelper = new SocketAuthHelper(conf) - - // each python worker gets an equal part of the allocation. the worker pool will grow to the - // number of concurrent tasks, which is determined by the number of cores in this executor. - private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = { - mem.map(_ / cores) - } - - def compute( - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): Iterator[OUT] = { - val startTime = System.currentTimeMillis - val sedonaEnv = SedonaSparkEnv.get - val env = SparkEnv.get - - // Get the executor cores and pyspark memory, they are passed via the local properties when - // the user specified them in a ResourceProfile. - val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY)) - val memoryMb = Option(context.getLocalProperty(PYSPARK_MEMORY_LOCAL_PROPERTY)).map(_.toLong) - val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") - // If OMP_NUM_THREADS is not explicitly set, override it with the number of task cpus. - // See SPARK-42613 for details. - if (conf.getOption("spark.executorEnv.OMP_NUM_THREADS").isEmpty) { - envVars.put("OMP_NUM_THREADS", conf.get("spark.task.cpus", "1")) - } - envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread - if (reuseWorker) { - envVars.put("SPARK_REUSE_WORKER", "1") - } - if (simplifiedTraceback) { - envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") - } - // SPARK-30299 this could be wrong with standalone mode when executor - // cores might not be correct because it defaults to all cores on the box. 
- val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES)) - val workerMemoryMb = getWorkerMemoryMb(memoryMb, execCores) - if (workerMemoryMb.isDefined) { - envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString) - } - envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) - envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) - if (faultHandlerEnabled) { - envVars.put("PYTHON_FAULTHANDLER_DIR", SedonaBasePythonRunner.faultHandlerLogDir.toString) - } - - envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - - val (worker: Socket, pid: Option[Int]) = - env.createPythonWorker(pythonExec, envVars.asScala.toMap) - // Whether is the worker released into idle pool or closed. When any codes try to release or - // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make - // sure there is only one winner that is going to release or close the worker. - val releasedOrClosed = new AtomicBoolean(false) - - // Start a thread to feed the process input from our parent's iterator - val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) - - context.addTaskCompletionListener[Unit] { _ => - writerThread.shutdownOnTaskCompletion() - if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { - try { - worker.close() - } catch { - case e: Exception => - logWarning("Failed to close worker socket", e) - } - } - } - - writerThread.start() - new WriterMonitorThread(SparkEnv.get, worker, writerThread, context).start() - if (reuseWorker) { - val key = (worker, context.taskAttemptId) - // SPARK-35009: avoid creating multiple monitor threads for the same python worker - // and task context - if (PythonRunner.runningMonitorThreads.add(key)) { - new MonitorThread(SparkEnv.get, worker, context).start() - } - } else { - new MonitorThread(SparkEnv.get, worker, context).start() - } - - // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) - - val stdoutIterator = newReaderIterator( - stream, - writerThread, - startTime, - env, - worker, - pid, - releasedOrClosed, - context) - new InterruptibleIterator(context, stdoutIterator) - } - - protected def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): WriterThread - - protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[OUT] - - /** - * The thread responsible for writing the data from the PythonRDD's parent iterator to the - * Python process. - */ - abstract class WriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { - - @volatile private var _exception: Throwable = null - - private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet - private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - - setDaemon(true) - - /** Contains the throwable thrown while writing the parent iterator to the Python process. */ - def exception: Option[Throwable] = Option(_exception) - - /** - * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may - * occur due to cleanup. 
- */ - def shutdownOnTaskCompletion(): Unit = { - assert(context.isCompleted) - this.interrupt() - // Task completion listeners that run after this method returns may invalidate - // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized - // reader, a task completion listener will free the underlying off-heap buffers. If the writer - // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free - // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer - // thread to exit before returning. - this.join() - } - - /** - * Writes a command section to the stream connected to the Python worker. - */ - protected def writeCommand(dataOut: DataOutputStream): Unit - - /** - * Writes input data to the stream connected to the Python worker. - */ - protected def writeIteratorToStream(dataOut: DataOutputStream): Unit - - override def run(): Unit = Utils.logUncaughtExceptions { - try { - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - // Partition index - dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // Init a ServerSocket to accept method calls from Python side. - val isBarrier = context.isInstanceOf[BarrierTaskContext] - if (isBarrier) { - serverSocket = Some( - new ServerSocket( - /* port */ 0, - /* backlog */ 1, - InetAddress.getByName("localhost"))) - // A call to accept() for ServerSocket shall block infinitely. - serverSocket.foreach(_.setSoTimeout(0)) - new Thread("accept-connections") { - setDaemon(true) - - override def run(): Unit = { - while (!serverSocket.get.isClosed()) { - var sock: Socket = null - try { - sock = serverSocket.get.accept() - // Wait for function call from python side. - sock.setSoTimeout(10000) - authHelper.authClient(sock) - val input = new DataInputStream(sock.getInputStream()) - val requestMethod = input.readInt() - // The BarrierTaskContext function may wait infinitely, socket shall not timeout - // before the function finishes. - sock.setSoTimeout(0) - requestMethod match { - case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - barrierAndServe(requestMethod, sock) - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val length = input.readInt() - val message = new Array[Byte](length) - input.readFully(message) - barrierAndServe(requestMethod, sock, new String(message, UTF_8)) - case _ => - val out = - new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) - } - } catch { - case e: SocketException if e.getMessage.contains("Socket closed") => - // It is possible that the ServerSocket is not closed, but the native socket - // has already been closed, we shall catch and silently ignore this case. - } finally { - if (sock != null) { - sock.close() - } - } - } - } - }.start() - } - val secret = if (isBarrier) { - authHelper.secret - } else { - "" - } - // Close ServerSocket on task completion. - serverSocket.foreach { server => - context.addTaskCompletionListener[Unit](_ => server.close()) - } - val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) - if (boundPort == -1) { - val message = "ServerSocket failed to bind to Java side." 
- logError(message) - throw new SparkException(message) - } else if (isBarrier) { - logDebug(s"Started ServerSocket on port $boundPort.") - } - // Write out the TaskContextInfo - dataOut.writeBoolean(isBarrier) - dataOut.writeInt(boundPort) - val secretBytes = secret.getBytes(UTF_8) - dataOut.writeInt(secretBytes.length) - dataOut.write(secretBytes, 0, secretBytes.length) - dataOut.writeInt(context.stageId()) - dataOut.writeInt(context.partitionId()) - dataOut.writeInt(context.attemptNumber()) - dataOut.writeLong(context.taskAttemptId()) - dataOut.writeInt(context.cpus()) - val resources = context.resources() - dataOut.writeInt(resources.size) - resources.foreach { case (k, v) => - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v.name, dataOut) - dataOut.writeInt(v.addresses.size) - v.addresses.foreach { case addr => - PythonRDD.writeUTF(addr, dataOut) - } - } - val localProps = context.getLocalProperties.asScala - dataOut.writeInt(localProps.size) - localProps.foreach { case (k, v) => - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v, dataOut) - } - - // sparkFilesDir - val root = jobArtifactUUID - .map { uuid => - new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath - } - .getOrElse(SparkFiles.getRootDirectory()) - PythonRDD.writeUTF(root, dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val addedBids = newBids.diff(oldBids) - val cnt = toRemove.size + addedBids.size - val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty - dataOut.writeBoolean(needsDecryptionServer) - dataOut.writeInt(cnt) - def sendBidsToRemove(): Unit = { - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(-bid - 1) // bid >= 0 - oldBids.remove(bid) - } - } - if (needsDecryptionServer) { - // if there is encryption, we setup a server which reads the encrypted files, and sends - // the decrypted data to python - val idsAndFiles = broadcastVars.flatMap { broadcast => - if (!oldBids.contains(broadcast.id)) { - oldBids.add(broadcast.id) - Some((broadcast.id, broadcast.value.path)) - } else { - None - } - } - val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) - dataOut.writeInt(server.port) - logTrace(s"broadcast decryption server setup on ${server.port}") - PythonRDD.writeUTF(server.secret, dataOut) - sendBidsToRemove() - idsAndFiles.foreach { case (id, _) => - // send new broadcast - dataOut.writeLong(id) - } - dataOut.flush() - logTrace("waiting for python to read decrypted broadcast data from server") - server.waitTillBroadcastDataSent() - logTrace("done sending decrypted data to python") - } else { - sendBidsToRemove() - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - } - dataOut.flush() - - dataOut.writeInt(evalType) - writeCommand(dataOut) - writeIteratorToStream(dataOut) - - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - } catch { - case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => - if (context.isCompleted || context.isInterrupted) { - logDebug( - "Exception/NonFatal Error thrown 
after task completion (likely due to " + - "cleanup)", - t) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } else { - // We must avoid throwing exceptions/NonFatals here, because the thread uncaught - // exception handler will kill the whole executor (see - // org.apache.spark.executor.Executor). - _exception = t - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } - } - } - - /** - * Gateway to call BarrierTaskContext methods. - */ - def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { - require( - serverSocket.isDefined, - "No available ServerSocket to redirect the BarrierTaskContext method call.") - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - try { - val messages = requestMethod match { - case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - context.asInstanceOf[BarrierTaskContext].barrier() - Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS) - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - context.asInstanceOf[BarrierTaskContext].allGather(message) - } - out.writeInt(messages.length) - messages.foreach(writeUTF(_, out)) - } catch { - case e: SparkException => - writeUTF(e.getMessage, out) - } finally { - out.close() - } - } - - def writeUTF(str: String, dataOut: DataOutputStream): Unit = { - val bytes = str.getBytes(UTF_8) - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } - } - - abstract class ReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext) - extends Iterator[OUT] { - - private var nextObj: OUT = _ - private var eos = false - - override def hasNext: Boolean = nextObj != null || { - if (!eos) { - nextObj = read() - hasNext - } else { - false - } - } - - override def next(): OUT = { - if (hasNext) { - val obj = nextObj - nextObj = null.asInstanceOf[OUT] - obj - } else { - Iterator.empty.next() - } - } - - /** - * Reads next object from the stream. When the stream reaches end of data, needs to process - * the following sections, and then returns null. 
- */ - protected def read(): OUT - - protected def handleTimingData(): Unit = { - // Timing data from worker - val bootTime = stream.readLong() - val initTime = stream.readLong() - val finishTime = stream.readLong() - val boot = bootTime - startTime - val init = initTime - bootTime - val finish = finishTime - initTime - val total = finishTime - startTime - logInfo( - "Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) - val memoryBytesSpilled = stream.readLong() - val diskBytesSpilled = stream.readLong() - context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) - context.taskMetrics.incDiskBytesSpilled(diskBytesSpilled) - } - - protected def handlePythonException(): PythonException = { - // Signals that an exception has been thrown in python - val exLength = stream.readInt() - val obj = new Array[Byte](exLength) - stream.readFully(obj) - new PythonException(new String(obj, StandardCharsets.UTF_8), writerThread.exception.orNull) - } - - protected def handleEndOfDataSection(): Unit = { - // We've finished the data section of the output, but we can still - // read some accumulator updates: - val numAccumulatorUpdates = stream.readInt() - (1 to numAccumulatorUpdates).foreach { _ => - val updateLen = stream.readInt() - val update = new Array[Byte](updateLen) - stream.readFully(update) - maybeAccumulator.foreach(_.add(update)) - } - // Check whether the worker is ready to be re-used. - if (stream.readInt() == SpecialLengths.END_OF_STREAM) { - if (reuseWorker && releasedOrClosed.compareAndSet(false, true)) { - env.releasePythonWorker(pythonExec, envVars.asScala.toMap, worker) - } - } - eos = true - } - - protected val handleException: PartialFunction[Throwable, OUT] = { - case e: Exception if context.isInterrupted => - logDebug("Exception thrown after task interruption", e) - throw new TaskKilledException(context.getKillReason().getOrElse("unknown reason")) - - case e: Exception if writerThread.exception.isDefined => - logError("Python worker exited unexpectedly (crashed)", e) - logError("This may have been caused by a prior exception:", writerThread.exception.get) - throw writerThread.exception.get - - case eof: EOFException - if faultHandlerEnabled && pid.isDefined && - JavaFiles.exists(SedonaBasePythonRunner.faultHandlerLogPath(pid.get)) => - val path = SedonaBasePythonRunner.faultHandlerLogPath(pid.get) - val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n" - JavaFiles.deleteIfExists(path) - throw new SparkException(s"Python worker exited unexpectedly (crashed): $error", eof) - - case eof: EOFException => - throw new SparkException("Python worker exited unexpectedly (crashed)", eof) - } - } - - /** - * It is necessary to have a monitor thread for python workers if the user cancels with - * interrupts disabled. In that case we will need to explicitly kill the worker, otherwise the - * threads can block indefinitely. - */ - class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) - extends Thread(s"Worker Monitor for $pythonExec") { - - /** How long to wait before killing the python worker if a task cannot be interrupted. */ - private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) - - setDaemon(true) - - private def monitorWorker(): Unit = { - // Kill the worker if it is interrupted, checking until task completion. - // TODO: This has a race condition if interruption occurs, as completed may still become true. 
- while (!context.isInterrupted && !context.isCompleted) { - Thread.sleep(2000) - } - if (!context.isCompleted) { - Thread.sleep(taskKillTimeout) - if (!context.isCompleted) { - try { - // Mimic the task name used in `Executor` to help the user find out the task to blame. - val taskName = s"${context.partitionId}.${context.attemptNumber} " + - s"in stage ${context.stageId} (TID ${context.taskAttemptId})" - logWarning(s"Incomplete task $taskName interrupted: Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } - - override def run(): Unit = { - try { - monitorWorker() - } finally { - if (reuseWorker) { - val key = (worker, context.taskAttemptId) - PythonRunner.runningMonitorThreads.remove(key) - } - } - } - } - - /** - * This thread monitors the WriterThread and kills it in case of deadlock. - * - * A deadlock can arise if the task completes while the writer thread is sending input to the - * Python process (e.g. due to the use of `take()`), and the Python process is still producing - * output. When the inputs are sufficiently large, this can result in a deadlock due to the use - * of blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket. - */ - class WriterMonitorThread( - env: SparkEnv, - worker: Socket, - writerThread: WriterThread, - context: TaskContext) - extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") { - - /** - * How long to wait before closing the socket if the writer thread has not exited after the - * task ends. - */ - private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) - - setDaemon(true) - - override def run(): Unit = { - // Wait until the task is completed (or the writer thread exits, in which case this thread has - // nothing to do). - while (!context.isCompleted && writerThread.isAlive) { - Thread.sleep(2000) - } - if (writerThread.isAlive) { - Thread.sleep(taskKillTimeout) - // If the writer thread continues running, this indicates a deadlock. Kill the worker to - // resolve the deadlock. - if (writerThread.isAlive) { - try { - // Mimic the task name used in `Executor` to help the user find out the task to blame. - val taskName = s"${context.partitionId}.${context.attemptNumber} " + - s"in stage ${context.stageId} (TID ${context.taskAttemptId})" - logWarning( - s"Detected deadlock while completing task $taskName: " + - "Attempting to kill Python Worker") - env.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) - } catch { - case e: Exception => - logError("Exception when trying to kill worker", e) - } - } - } - } - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala deleted file mode 100644 index 976d034e08..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala +++ /dev/null @@ -1,93 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.execution.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import org.apache.spark.api.python._ -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types._ -import org.apache.spark.sql.vectorized.ColumnarBatch - -/** - * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. - */ -class SedonaArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - protected override val schema: StructType, - protected override val timeZoneId: String, - protected override val largeVarTypes: Boolean, - protected override val workerConf: Map[String, String], - val pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, - evalType, - argOffsets, - jobArtifactUUID) - with SedonaBasicPythonArrowInput - with SedonaBasicPythonArrowOutput { - - override val pythonExec: String = - SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(funcs.head.funcs.head.pythonExec) - - override val errorOnDuplicatedFieldNames: Boolean = true - - override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - - override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize - require( - bufferSize >= 4, - "Pandas execution requires more than 4 bytes. Please set higher buffer. 
" + - s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") -} - -object SedonaArrowPythonRunner { - - /** Return Map with conf settings to be used in ArrowPythonRunner */ - def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { - val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) - val pandasColsByName = Seq( - SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> - conf.pandasGroupedMapAssignColumnsByName.toString) - val arrowSafeTypeCheck = Seq( - SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key -> - conf.arrowSafeTypeConversion.toString) - Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*) - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala deleted file mode 100644 index 375d6536ca..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. 
- */ -package org.apache.spark.sql.execution.python - -import org.apache.sedona.sql.UDF.PythonEvalType -import org.apache.spark.api.python.ChainedPythonFunctions -import org.apache.spark.rdd.RDD -import org.apache.spark.sql.Strategy -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.InternalRow.copyValue -import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences -import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection} -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection.createObject -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.vectorized.{ColumnarBatchRow, ColumnarRow} -//import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences -//import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection -import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, CodeGeneratorWithInterpretedFallback, Expression, InterpretedUnsafeProjection, JoinedRow, MutableProjection, Projection, PythonUDF, UnsafeRow} -import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan -import org.apache.spark.sql.catalyst.util.GenericArrayData -import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.internal.SQLConf -import org.apache.spark.sql.types.{DataType, StructField, StructType} -import org.apache.spark.sql.udf.SedonaArrowEvalPython -import org.apache.spark.util.Utils -import org.apache.spark.{ContextAwareIterator, JobArtifactSet, SparkEnv, TaskContext} -import org.locationtech.jts.io.WKTReader -import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder -import java.io.File -import scala.collection.JavaConverters.asScalaIteratorConverter -import scala.collection.mutable.ArrayBuffer - -// We use custom Strategy to avoid Apache Spark assert on types, we -// can consider extending this to support other engines working with -// arrow data -class SedonaArrowStrategy extends Strategy { - override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { - case SedonaArrowEvalPython(udfs, output, child, evalType) => - SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil - case _ => Nil - } -} - -/** - * The factory object for `UnsafeProjection`. - */ -object SedonaUnsafeProjection { - - def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - GenerateUnsafeProjection.generate( - bindReferences(exprs, inputSchema), - SQLConf.get.subexpressionEliminationEnabled) -// createObject(bindReferences(exprs, inputSchema)) - } -} -// It's modification og Apache Spark's ArrowEvalPythonExec, we remove the check on the types to allow geometry types -// here, it's initial version to allow the vectorized udf for Sedona geometry types. 
We can consider extending this -// to support other engines working with arrow data -case class SedonaArrowEvalPythonExec( - udfs: Seq[PythonUDF], - resultAttrs: Seq[Attribute], - child: SparkPlan, - evalType: Int) - extends EvalPythonExec - with PythonSQLMetrics { - - private val batchSize = conf.arrowMaxRecordsPerBatch - private val sessionLocalTimeZone = conf.sessionLocalTimeZone - private val largeVarTypes = conf.arrowUseLargeVarTypes - private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) - private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) - - protected override def evaluate( - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]], - iter: Iterator[InternalRow], - schema: StructType, - context: TaskContext): Iterator[InternalRow] = { - - val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) - - val columnarBatchIter = new SedonaArrowPythonRunner( - funcs, - evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, - argOffsets, - schema, - sessionLocalTimeZone, - largeVarTypes, - pythonRunnerConf, - pythonMetrics, - jobArtifactUUID).compute(batchIter, context.partitionId(), context) - - columnarBatchIter.flatMap { batch => - batch.rowIterator.asScala - } - } - - override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = - copy(child = newChild) - - private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = { - udf.children match { - case Seq(u: PythonUDF) => - val (chained, children) = collectFunctions(u) - (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children) - case children => - // There should not be any other UDFs, or the children can't be evaluated directly. - assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF]))) - (ChainedPythonFunctions(Seq(udf.func)), udf.children) - } - } - - override def doExecute(): RDD[InternalRow] = { - - val customProjection = new Projection with Serializable { - def apply(row: InternalRow): InternalRow = { - row match { - case joinedRow: JoinedRow => - val arrowField = joinedRow.getRight.asInstanceOf[ColumnarBatchRow] - val left = joinedRow.getLeft - -// resultAttrs.zipWithIndex.map { -// case (x, y) => -// if (x.dataType.isInstanceOf[GeometryUDT]) { -// val wkbReader = new org.locationtech.jts.io.WKBReader() -// wkbReader.read(left.getBinary(y)) -// -// println("ssss") -// } -// GeometryUDT -// left.getByte(y) -// -// left.setByte(y, 1.toByte) -// -// println(left.getByte(y)) -// } -// -// println("ssss") -// arrowField. 
- row - // We need to convert JoinedRow to UnsafeRow -// val leftUnsafe = left.asInstanceOf[UnsafeRow] -// val rightUnsafe = right.asInstanceOf[UnsafeRow] -// val joinedUnsafe = new UnsafeRow(leftUnsafe.numFields + rightUnsafe.numFields) -// joinedUnsafe.pointTo( -// leftUnsafe.getBaseObject, leftUnsafe.getBaseOffset, -// leftUnsafe.getSizeInBytes + rightUnsafe.getSizeInBytes) -// joinedUnsafe.setLeft(rightUnsafe) -// joinedUnsafe.setRight(leftUnsafe) -// joinedUnsafe -// val wktReader = new WKTReader() - val resultProj = SedonaUnsafeProjection.create(output, output) -// val WKBWriter = new org.locationtech.jts.io.WKBWriter() - resultProj(new JoinedRow(left, arrowField)) - case _ => - println(row.getClass) - throw new UnsupportedOperationException("Unsupported row type") - } - } - } - val inputRDD = child.execute().map(_.copy()) - - inputRDD.mapPartitions { iter => - val context = TaskContext.get() - val contextAwareIterator = new ContextAwareIterator(context, iter) - - // The queue used to buffer input rows so we can drain it to - // combine input with output from Python. - val queue = HybridRowQueue( - context.taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), - child.output.length) - context.addTaskCompletionListener[Unit] { ctx => - queue.close() - } - - val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip - - // flatten all the arguments - val allInputs = new ArrayBuffer[Expression] - val dataTypes = new ArrayBuffer[DataType] - val argOffsets = inputs.map { input => - input.map { e => - if (allInputs.exists(_.semanticEquals(e))) { - allInputs.indexWhere(_.semanticEquals(e)) - } else { - allInputs += e - dataTypes += e.dataType - allInputs.length - 1 - } - }.toArray - }.toArray - val projection = MutableProjection.create(allInputs.toSeq, child.output) - projection.initialize(context.partitionId()) - val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => - StructField(s"_$i", dt) - }.toArray) - - // Add rows to queue to join later with the result. 
- val projectedRowIter = contextAwareIterator.map { inputRow => - queue.add(inputRow.asInstanceOf[UnsafeRow]) - projection(inputRow) - } - - val outputRowIterator = evaluate(pyFuncs, argOffsets, projectedRowIter, schema, context) - - val joined = new JoinedRow - - outputRowIterator.map { outputRow => - val joinedRow = joined(queue.remove(), outputRow) - - val projected = customProjection(joinedRow) - - val numFields = projected.numFields - val startField = numFields - resultAttrs.length - println(resultAttrs.length) - - val row = new GenericInternalRow(numFields) - - resultAttrs.zipWithIndex.map { case (attr, index) => - if (attr.dataType.isInstanceOf[GeometryUDT]) { - // Convert the geometry type to WKB - val wkbReader = new org.locationtech.jts.io.WKBReader() - val wkbWriter = new org.locationtech.jts.io.WKBWriter() - val geom = wkbReader.read(projected.getBinary(startField + index)) - - row.update(startField + index, wkbWriter.write(geom)) - - println("ssss") - } - } - - println("ssss") -// 3.2838116E-8 - row - } - } - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala deleted file mode 100644 index bf33cde1c1..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala +++ /dev/null @@ -1,256 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.execution.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -import java.util.concurrent.atomic.AtomicInteger -import scala.collection.JavaConverters._ -import org.apache.arrow.memory.RootAllocator -import org.apache.arrow.vector.complex.MapVector -import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit} -import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema} -import org.apache.spark.sql.errors.ExecutionErrors -import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types._ - -private[sql] object SedonaArrowUtils { - - val rootAllocator = new RootAllocator(Long.MaxValue) - - // todo: support more types. - - /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ - def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = - dt match { - case BooleanType => ArrowType.Bool.INSTANCE - case ByteType => new ArrowType.Int(8, true) - case ShortType => new ArrowType.Int(8 * 2, true) - case IntegerType => new ArrowType.Int(8 * 4, true) - case LongType => new ArrowType.Int(8 * 8, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE - case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE - case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE - case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE - case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) - case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType if timeZoneId == null => - throw new IllegalStateException("Missing timezoneId where it is mandatory.") - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) - case TimestampNTZType => - new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) - case NullType => ArrowType.Null.INSTANCE - case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) - case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) - case _ => - throw ExecutionErrors.unsupportedDataTypeError(dt) - } - - def fromArrowType(dt: ArrowType): DataType = dt match { - case ArrowType.Bool.INSTANCE => BooleanType - case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType - case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType - case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType - case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType - case float: ArrowType.FloatingPoint - if float.getPrecision() == FloatingPointPrecision.SINGLE => - FloatType - case float: ArrowType.FloatingPoint - if float.getPrecision() == FloatingPointPrecision.DOUBLE => - DoubleType - case ArrowType.Utf8.INSTANCE => StringType - case ArrowType.Binary.INSTANCE => BinaryType - case ArrowType.LargeUtf8.INSTANCE => StringType - case ArrowType.LargeBinary.INSTANCE => BinaryType - case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) - case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType - case ts: ArrowType.Timestamp - if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => - TimestampNTZType - case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType - case ArrowType.Null.INSTANCE => NullType - case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => - 
YearMonthIntervalType() - case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() - case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt) - } - - /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */ - def toArrowField( - name: String, - dt: DataType, - nullable: Boolean, - timeZoneId: String, - largeVarTypes: Boolean = false): Field = { - dt match { - case GeometryUDT => - val jsonData = - """{"crs": {"$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", "type": "GeographicCRS", "name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": {"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 (G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153}}, {"name": "W [...] - val metadata = Map( - "ARROW:extension:name" -> "geoarrow.wkb", - "ARROW:extension:metadata" -> jsonData).asJava - - val fieldType = new FieldType(nullable, ArrowType.Binary.INSTANCE, null, metadata) - new Field(name, fieldType, Seq.empty[Field].asJava) - - case ArrayType(elementType, containsNull) => - val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) - new Field( - name, - fieldType, - Seq( - toArrowField("element", elementType, containsNull, timeZoneId, largeVarTypes)).asJava) - case StructType(fields) => - val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) - new Field( - name, - fieldType, - fields - .map { field => - toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) - } - .toSeq - .asJava) - case MapType(keyType, valueType, valueContainsNull) => - val mapType = new FieldType(nullable, new ArrowType.Map(false), null) - // Note: Map Type struct can not be null, Struct Type key field can not be null - new Field( - name, - mapType, - Seq( - toArrowField( - MapVector.DATA_VECTOR_NAME, - new StructType() - .add(MapVector.KEY_NAME, keyType, nullable = false) - .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), - nullable = false, - timeZoneId, - largeVarTypes)).asJava) - case udt: UserDefinedType[_] => - toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) - case dataType => - val fieldType = - new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null) - new Field(name, fieldType, Seq.empty[Field].asJava) - } - } - - def fromArrowField(field: Field): DataType = { - field.getType match { - case _: ArrowType.Map => - val elementField = field.getChildren.get(0) - val keyType = fromArrowField(elementField.getChildren.get(0)) - val valueType = fromArrowField(elementField.getChildren.get(1)) - MapType(keyType, valueType, elementField.getChildren.get(1).isNullable) - case ArrowType.List.INSTANCE => - val elementField = field.getChildren().get(0) - val elementType = fromArrowField(elementField) - ArrayType(elementType, containsNull = elementField.isNullable) - case ArrowType.Struct.INSTANCE => - val fields = field.getChildren().asScala.map { child => - val dt = fromArrowField(child) - StructField(child.getName, dt, child.isNullable) - } - StructType(fields.toArray) - case arrowType => fromArrowType(arrowType) - } - } - - /** - * Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType - */ - def toArrowSchema( - schema: StructType, - timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean, - largeVarTypes: Boolean = false): Schema = { - new Schema(schema.map { field => - toArrowField( - field.name, - deduplicateFieldNames(field.dataType, errorOnDuplicatedFieldNames), - field.nullable, - timeZoneId, - largeVarTypes) - }.asJava) - } - - def fromArrowSchema(schema: Schema): StructType = { - StructType(schema.getFields.asScala.map { field => - val dt = fromArrowField(field) - StructField(field.getName, dt, field.isNullable) - }.toArray) - } - - private def deduplicateFieldNames( - dt: DataType, - errorOnDuplicatedFieldNames: Boolean): DataType = dt match { - case geometryType: GeometryUDT => geometryType - case udt: UserDefinedType[_] => - deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) - case st @ StructType(fields) => - val newNames = if (st.names.toSet.size == st.names.length) { - st.names - } else { - if (errorOnDuplicatedFieldNames) { - throw ExecutionErrors.duplicatedFieldNameInArrowStructError(st.names) - } - val genNawName = st.names.groupBy(identity).map { - case (name, names) if names.length > 1 => - val i = new AtomicInteger() - name -> { () => s"${name}_${i.getAndIncrement()}" } - case (name, _) => name -> { () => name } - } - st.names.map(genNawName(_)()) - } - val newFields = - fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => - StructField( - name, - deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), - nullable, - metadata) - } - StructType(newFields) - case ArrayType(elementType, containsNull) => - ArrayType(deduplicateFieldNames(elementType, errorOnDuplicatedFieldNames), containsNull) - case MapType(keyType, valueType, valueContainsNull) => - MapType( - deduplicateFieldNames(keyType, errorOnDuplicatedFieldNames), - deduplicateFieldNames(valueType, errorOnDuplicatedFieldNames), - valueContainsNull) - case _ => dt - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala deleted file mode 100644 index 178227a66d..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala +++ /dev/null @@ -1,171 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.execution.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. 
- * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.ArrowStreamWriter -import org.apache.spark.api.python -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD, SedonaBasePythonRunner} -import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.execution.arrow.ArrowWriter -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.util.Utils -import org.apache.spark.{SparkEnv, TaskContext} - -import java.io.DataOutputStream -import java.net.Socket - -/** - * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from JVM - * (an iterator of internal rows + additional data if required) to Python (Arrow). - */ -private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[IN, _] => - protected val workerConf: Map[String, String] - - protected val schema: StructType - - protected val timeZoneId: String - - protected val errorOnDuplicatedFieldNames: Boolean - - protected val largeVarTypes: Boolean - - protected def pythonMetrics: Map[String, SQLMetric] - - protected def writeIteratorToArrowStream( - root: VectorSchemaRoot, - writer: ArrowStreamWriter, - dataOut: DataOutputStream, - inputIterator: Iterator[IN]): Unit - - protected def writeUDF( - dataOut: DataOutputStream, - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]]): Unit = - SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - - protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { - // Write config for the worker as a number of key -> value pairs of strings - stream.writeInt(workerConf.size) - for ((k, v) <- workerConf) { - PythonRDD.writeUTF(k, stream) - PythonRDD.writeUTF(v, stream) - } - } - - protected override def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): WriterThread = { - new WriterThread(env, worker, inputIterator, partitionIndex, context) { - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - handleMetadataBeforeExec(dataOut) - writeUDF(dataOut, funcs, argOffsets) - } - - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = SedonaArrowUtils.toArrowSchema( - schema, - timeZoneId, - errorOnDuplicatedFieldNames, - largeVarTypes) - val allocator = SedonaArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for $pythonExec", - 0, - Long.MaxValue) - val root = VectorSchemaRoot.create(arrowSchema, allocator) - - Utils.tryWithSafeFinally { - val writer = new ArrowStreamWriter(root, null, dataOut) - writer.start() - - writeIteratorToArrowStream(root, writer, dataOut, inputIterator) - - // end writes footer to the output stream and doesn't clean any resources. 
- // It could throw exception if the output stream is closed, so it should be - // in the try block. - writer.end() - } { - // If we close root and allocator in TaskCompletionListener, there could be a race - // condition where the writer thread keeps writing to the VectorSchemaRoot while - // it's being closed by the TaskCompletion listener. - // Closing root and allocator here is cleaner because root and allocator is owned - // by the writer thread and is only visible to the writer thread. - // - // If the writer thread is interrupted by TaskCompletionListener, it should either - // (1) in the try block, in which case it will get an InterruptedException when - // performing io, and goes into the finally block or (2) in the finally block, - // in which case it will ignore the interruption and close the resources. - root.close() - allocator.close() - } - } - } - } -} - -private[python] trait SedonaBasicPythonArrowInput - extends SedonaPythonArrowInput[Iterator[InternalRow]] { - self: SedonaBasePythonRunner[Iterator[InternalRow], _] => - - protected def writeIteratorToArrowStream( - root: VectorSchemaRoot, - writer: ArrowStreamWriter, - dataOut: DataOutputStream, - inputIterator: Iterator[Iterator[InternalRow]]): Unit = { - val arrowWriter = ArrowWriter.create(root) - - while (inputIterator.hasNext) { - val startData = dataOut.size() - val nextBatch = inputIterator.next() - - while (nextBatch.hasNext) { - arrowWriter.write(nextBatch.next()) - } - - arrowWriter.finish() - writer.writeBatch() - arrowWriter.reset() - val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData - } - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala deleted file mode 100644 index 12f6e60eb9..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala +++ /dev/null @@ -1,169 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.execution.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. 
You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.ArrowStreamReader -import org.apache.spark.api.python.{BasePythonRunner, SedonaBasePythonRunner, SpecialLengths} -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.types.StructType -import org.apache.spark.sql.util.ArrowUtils -import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} -import org.apache.spark.{SparkEnv, TaskContext} - -import java.io.DataInputStream -import java.net.Socket -import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.JavaConverters._ - -/** - * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from Python - * (Arrow) to JVM (output type being deserialized from ColumnarBatch). - */ -private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { - self: SedonaBasePythonRunner[_, OUT] => - - protected def pythonMetrics: Map[String, SQLMetric] - - protected def handleMetadataAfterExec(stream: DataInputStream): Unit = {} - - protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT - - protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[OUT] = { - - new ReaderIterator( - stream, - writerThread, - startTime, - env, - worker, - pid, - releasedOrClosed, - context) { - - private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", - 0, - Long.MaxValue) - - private var reader: ArrowStreamReader = _ - private var root: VectorSchemaRoot = _ - private var schema: StructType = _ - private var vectors: Array[ColumnVector] = _ - - context.addTaskCompletionListener[Unit] { _ => - if (reader != null) { - reader.close(false) - } - allocator.close() - } - - private var batchLoaded = true - - protected override def handleEndOfDataSection(): Unit = { - handleMetadataAfterExec(stream) - super.handleEndOfDataSection() - } - - protected override def read(): OUT = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - if (reader != null && batchLoaded) { - val bytesReadStart = reader.bytesRead() - batchLoaded = reader.loadNextBatch() - if (batchLoaded) { - val batch = new ColumnarBatch(vectors) - val rowCount = root.getRowCount - batch.setNumRows(root.getRowCount) - val bytesReadEnd = reader.bytesRead() - pythonMetrics("pythonNumRowsReceived") += rowCount - pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart - val result = deserializeColumnarBatch(batch, schema) - result - } else { - reader.close(false) - allocator.close() - // Reach end of stream. Call `read()` again to read control data. 
- read() - } - } else { - stream.readInt() match { - case SpecialLengths.START_ARROW_STREAM => - reader = new ArrowStreamReader(stream, allocator) - root = reader.getVectorSchemaRoot() - schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root - .getFieldVectors() - .asScala - .map { vector => - new ArrowColumnVector(vector) - } - .toArray[ColumnVector] - read() - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null.asInstanceOf[OUT] - } - } - } catch handleException - } - } - } -} - -private[python] trait SedonaBasicPythonArrowOutput - extends SedonaPythonArrowOutput[ColumnarBatch] { - self: SedonaBasePythonRunner[_, ColumnarBatch] => - - protected def deserializeColumnarBatch( - batch: ColumnarBatch, - schema: StructType): ColumnarBatch = batch -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala deleted file mode 100644 index 56bfb782b1..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala +++ /dev/null @@ -1,179 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one - * or more contributor license agreements. See the NOTICE file - * distributed with this work for additional information - * regarding copyright ownership. The ASF licenses this file - * to you under the Apache License, Version 2.0 (the - * "License"); you may not use this file except in compliance - * with the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, - * software distributed under the License is distributed on an - * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY - * KIND, either express or implied. See the License for the - * specific language governing permissions and limitations - * under the License. - */ -package org.apache.spark.sql.execution.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import java.io._ -import java.net._ -import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark._ -import org.apache.spark.api.python._ -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf - -/** - * A helper class to run Python UDFs in Spark. 
- */ -abstract class SedonaBasePythonUDFRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends SedonaBasePythonRunner[Array[Byte], Array[Byte]]( - funcs, - evalType, - argOffsets, - jobArtifactUUID) { - - override val pythonExec: String = - SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(funcs.head.funcs.head.pythonExec) - - override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - - abstract class SedonaPythonUDFWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[Array[Byte]], - partitionIndex: Int, - context: TaskContext) - extends WriterThread(env, worker, inputIterator, partitionIndex, context) { - - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val startData = dataOut.size() - - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - - val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData - } - } - - protected override def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[Array[Byte]] = { - new ReaderIterator( - stream, - writerThread, - startTime, - env, - worker, - pid, - releasedOrClosed, - context) { - - protected override def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - pythonMetrics("pythonDataReceived") += length - obj - case 0 => Array.emptyByteArray - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } catch handleException - } - } - } -} - -class SedonaPythonUDFRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends SedonaBasePythonUDFRunner( - funcs, - evalType, - argOffsets, - pythonMetrics, - jobArtifactUUID) { - - protected override def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[Array[Byte]], - partitionIndex: Int, - context: TaskContext): WriterThread = { - new SedonaPythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) { - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - } - - } - } -} - -object SedonaPythonUDFRunner { - - def writeUDFs( - dataOut: DataOutputStream, - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]]): Unit = { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command.toArray) - } - } - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala new 
file mode 100644 index 0000000000..a403fa6b9e --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala @@ -0,0 +1,89 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.udf + +import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.spark.api.python.ChainedPythonFunctions +import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.sql.Strategy +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF} +import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan +import org.apache.spark.sql.execution.SparkPlan +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, EvalPythonExec, PythonSQLMetrics} +import org.apache.spark.sql.types.StructType + +import scala.collection.JavaConverters.asScalaIteratorConverter + +// We use custom Strategy to avoid Apache Spark assert on types, we +// can consider extending this to support other engines working with +// arrow data +class SedonaArrowStrategy extends Strategy { + override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match { + case SedonaArrowEvalPython(udfs, output, child, evalType) => + SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil + case _ => Nil + } +} + +// It's modification og Apache Spark's ArrowEvalPythonExec, we remove the check on the types to allow geometry types +// here, it's initial version to allow the vectorized udf for Sedona geometry types. 
We can consider extending this +// to support other engines working with arrow data +case class SedonaArrowEvalPythonExec( + udfs: Seq[PythonUDF], + resultAttrs: Seq[Attribute], + child: SparkPlan, + evalType: Int) + extends EvalPythonExec + with PythonSQLMetrics { + + private val batchSize = conf.arrowMaxRecordsPerBatch + private val sessionLocalTimeZone = conf.sessionLocalTimeZone + private val largeVarTypes = conf.arrowUseLargeVarTypes + private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf) + private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid) + + protected override def evaluate( + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]], + iter: Iterator[InternalRow], + schema: StructType, + context: TaskContext): Iterator[InternalRow] = { + + val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) + + val columnarBatchIter = new ArrowPythonRunner( + funcs, + evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, + argOffsets, + schema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID).compute(batchIter, context.partitionId(), context) + + columnarBatchIter.flatMap { batch => + batch.rowIterator.asScala + } + } + + override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = + copy(child = newChild) +} diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index 0a6b416314..adbb97819f 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.udf import org.apache.sedona.spark.SedonaContext import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions.{col, expr} -import org.apache.spark.sql.udf.ScalarUDF.{geometryToGeometryFunction, geometryToNonGeometryFunction, geopandasGeometryToGeometryFunction, nonGeometryToGeometryFunction} +import org.apache.spark.sql.functions.col +import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction import org.locationtech.jts.io.WKTReader import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers @@ -44,36 +44,24 @@ class StrategySuite extends AnyFunSuite with Matchers { import spark.implicits._ test("sedona geospatial UDF") { -// spark.sql("select 1").show() - val df = spark.read - .format("parquet") - .load("/Users/pawelkocinski/Desktop/projects/sedona-book/apache-sedona-book/book/chapter10/data/buildings/partitioned") - .select( - geometryToNonGeometryFunction(col("geometry")), - geometryToGeometryFunction(col("geometry")), - nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - col("geohash")) + val df = Seq( + (1, "value", wktReader.read("POINT(21 52)")), + (2, "value1", wktReader.read("POINT(20 50)")), + (3, "value2", wktReader.read("POINT(20 49)")), + (4, "value3", wktReader.read("POINT(20 48)")), + (5, "value4", wktReader.read("POINT(20 47)"))) + .toDF("id", "value", "geom") + .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) - df.show() + df.count shouldEqual 5 -// val df = Seq( -// (1, "value", wktReader.read("POINT(21 52)")), -// (2, "value1", wktReader.read("POINT(20 50)")), -// (3, "value2", wktReader.read("POINT(20 49)")), -// (4, "value3", wktReader.read("POINT(20 48)")), -// (5, "value4", wktReader.read("POINT(20 47)"))) -// 
.toDF("id", "value", "geom") -// .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) - -// df.count shouldEqual 5 - -// df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") -// .as[String] -// .collect() should contain theSameElementsAs Seq( -// "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", -// "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", -// "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", -// "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", -// "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") + df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") + .as[String] + .collect() should contain theSameElementsAs Seq( + "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", + "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", + "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", + "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", + "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index 1ca705e297..c0a2d8f260 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -19,13 +19,11 @@ package org.apache.spark.sql.udf import org.apache.sedona.sql.UDF -import org.apache.spark.{SparkEnv, TestUtils} +import org.apache.spark.TestUtils import org.apache.spark.api.python._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.config.Python.{PYTHON_DAEMON_MODULE, PYTHON_USE_DAEMON, PYTHON_WORKER_MODULE} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.FloatType import org.apache.spark.util.Utils import java.io.File @@ -72,22 +70,30 @@ object ScalarUDF { finally Utils.deleteRecursively(path) } - val additionalModule = "spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf" - - val geopandasGeometryToNonGeometry: Array[Byte] = { + val pandasFunc: Array[Byte] = { var binaryPandasFunc: Array[Byte] = null withTempPath { path => + println(path) Process( Seq( pythonExec, "-c", f""" - |from pyspark.sql.types import FloatType + |from pyspark.sql.types import IntegerType + |from shapely.geometry import Point + |from sedona.sql.types import GeometryType |from pyspark.serializers import CloudPickleSerializer + |from sedona.utils import geometry_serde + |from shapely import box |f = open('$path', 'wb'); - |def apply_geopandas(x): - | return x.area - |f.write(CloudPickleSerializer().dumps((apply_geopandas, FloatType()))) + |def w(x): + | def apply_function(w): + | geom, offset = geometry_serde.deserialize(w) + | bounds = geom.buffer(1).bounds + | x = box(*bounds) + | return geometry_serde.serialize(x) + | return x.apply(apply_function) + |f.write(CloudPickleSerializer().dumps((w, GeometryType()))) |""".stripMargin), None, "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! 
@@ -97,90 +103,13 @@ object ScalarUDF { binaryPandasFunc } - val geopandasGeometryToGeometryFunction: Array[Byte] = { - var binaryPandasFunc: Array[Byte] = null - withTempPath { path => - Process( - Seq( - pythonExec, - "-c", - f""" - |from sedona.sql.types import GeometryType - |from pyspark.serializers import CloudPickleSerializer - |f = open('$path', 'wb'); - |def apply_geopandas(x): - | return x.buffer(1) - |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType()))) - |""".stripMargin), - None, - "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! - binaryPandasFunc = Files.readAllBytes(path.toPath) - } - assert(binaryPandasFunc != null) - binaryPandasFunc - } - - val geopandasNonGeometryToGeometryFunction: Array[Byte] = { - var binaryPandasFunc: Array[Byte] = null - withTempPath { path => - Process( - Seq( - pythonExec, - "-c", - f""" - |from sedona.sql.types import GeometryType - |from shapely.wkt import loads - |from pyspark.serializers import CloudPickleSerializer - |f = open('$path', 'wb'); - |def apply_geopandas(x): - | return x.apply(lambda wkt: loads(wkt).buffer(1)) - |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType()))) - |""".stripMargin), - None, - "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! - binaryPandasFunc = Files.readAllBytes(path.toPath) - } - assert(binaryPandasFunc != null) - binaryPandasFunc - } - private val workerEnv = new java.util.HashMap[String, String]() workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") - SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.worker") - SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false) - - val geometryToNonGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( - name = "geospatial_udf", - func = SimplePythonFunction( - command = geopandasGeometryToNonGeometry, - envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], - pythonIncludes = List.empty[String].asJava, - pythonExec = pythonExec, - pythonVer = pythonVer, - broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, - accumulator = null), - dataType = FloatType, - pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, - udfDeterministic = true) - - val geometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( - name = "geospatial_udf", - func = SimplePythonFunction( - command = geopandasGeometryToGeometryFunction, - envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], - pythonIncludes = List.empty[String].asJava, - pythonExec = pythonExec, - pythonVer = pythonVer, - broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, - accumulator = null), - dataType = GeometryUDT, - pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, - udfDeterministic = true) - val nonGeometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( + val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( name = "geospatial_udf", func = SimplePythonFunction( - command = geopandasNonGeometryToGeometryFunction, + command = pandasFunc, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec,

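// -----------------------------------------------------------------------------
// Illustrative sketch (editor addition, not part of this patch): the key idea from
// the removed SedonaArrowUtils.toArrowField -- a GeometryUDT column is represented
// as an Arrow Binary (WKB) field tagged with geoarrow.wkb extension metadata, so
// Arrow-aware geospatial tooling can recognise it as geometry. The PROJJSON CRS
// blob shown in the diff is replaced by an empty JSON object here for brevity.
// -----------------------------------------------------------------------------
import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}
import scala.collection.JavaConverters._

object GeoArrowFieldSketch {
  def geometryField(name: String, nullable: Boolean): Field = {
    val metadata = Map(
      "ARROW:extension:name" -> "geoarrow.wkb",
      "ARROW:extension:metadata" -> "{}" // the real field carries a PROJJSON CRS document
    ).asJava
    val fieldType = new FieldType(nullable, ArrowType.Binary.INSTANCE, null, metadata)
    new Field(name, fieldType, Seq.empty[Field].asJava)
  }

  def main(args: Array[String]): Unit =
    // Prints a Binary field named "geometry" carrying the geoarrow.wkb extension metadata.
    println(geometryField("geometry", nullable = true))
}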