This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 08dd543637d7814b80c5d6d6276a6a3e19c72052 Author: pawelkocinski <[email protected]> AuthorDate: Tue Dec 23 17:31:45 2025 +0100 add code so far --- .../common/geometrySerde/GeometrySerializer.java | 39 +- .../org/apache/sedona/spark/SedonaContext.scala | 1 - .../org/apache/sedona/sql/UDF/PythonEvalType.scala | 7 + .../scala/org/apache/spark/SedonaSparkEnv.scala | 495 --------------------- .../execution/python/SedonaArrowPythonRunner.scala | 57 +-- .../sql/execution/python/SedonaArrowStrategy.scala | 49 +- .../execution/python/SedonaBasePythonRunner.scala | 200 +++++++++ .../execution/python/SedonaDBWorkerFactory.scala | 116 +++++ .../execution/python/SedonaPythonArrowInput.scala | 97 ++-- .../sql/execution/python/WorkerContext.scala} | 33 +- .../spark/sql/udf/ExtractSedonaUDFRule.scala | 13 +- .../sedona/sql/GeoParquetMetadataTests.scala | 138 ------ .../org/apache/spark/sql/udf/StrategySuite.scala | 204 +++++++-- .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 217 +++++---- 14 files changed, 777 insertions(+), 889 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java index d0f2f39d46..508a62901d 100644 --- a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java +++ b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java @@ -32,38 +32,12 @@ import org.locationtech.jts.geom.Point; import org.locationtech.jts.geom.Polygon; import org.locationtech.jts.geom.PrecisionModel; import org.locationtech.jts.io.WKBConstants; -import org.locationtech.jts.io.WKBReader; -import org.locationtech.jts.io.WKBWriter; public class GeometrySerializer { private static final Coordinate NULL_COORDINATE = new Coordinate(Double.NaN, Double.NaN); private static final PrecisionModel PRECISION_MODEL = new PrecisionModel(); public static byte[] serialize(Geometry geometry) { -// return new WKBWriter().write(geometry); - GeometryBuffer buffer; - if (geometry instanceof Point) { - buffer = serializePoint((Point) geometry); - } else if (geometry instanceof MultiPoint) { - buffer = serializeMultiPoint((MultiPoint) geometry); - } else if (geometry instanceof LineString) { - buffer = serializeLineString((LineString) geometry); - } else if (geometry instanceof MultiLineString) { - buffer = serializeMultiLineString((MultiLineString) geometry); - } else if (geometry instanceof Polygon) { - buffer = serializePolygon((Polygon) geometry); - } else if (geometry instanceof MultiPolygon) { - buffer = serializeMultiPolygon((MultiPolygon) geometry); - } else if (geometry instanceof GeometryCollection) { - buffer = serializeGeometryCollection((GeometryCollection) geometry); - } else { - throw new UnsupportedOperationException( - "Geometry type is not supported: " + geometry.getClass().getSimpleName()); - } - return buffer.toByteArray(); - } - - public static byte[] serializeLegacy(Geometry geometry) { GeometryBuffer buffer; if (geometry instanceof Point) { buffer = serializePoint((Point) geometry); @@ -87,17 +61,6 @@ public class GeometrySerializer { } public static Geometry deserialize(byte[] bytes) { -// WKBReader reader = new WKBReader(); -// try { -// return reader.read(bytes); -// } catch (Exception e) { -// throw new IllegalArgumentException("Failed to deserialize geometry from bytes", e); -// } - GeometryBuffer buffer = GeometryBufferFactory.wrap(bytes); - return deserialize(buffer); - } - - public static Geometry 
deserializeLegacy(byte[] bytes) { GeometryBuffer buffer = GeometryBufferFactory.wrap(bytes); return deserialize(buffer); } @@ -170,7 +133,7 @@ public class GeometrySerializer { buffer.mark(8); } else { int bufferSize = 8 + coordType.bytes; -// checkBufferSize(buffer, bufferSize); + checkBufferSize(buffer, bufferSize); CoordinateSequence coordinates = buffer.getCoordinate(8); point = factory.createPoint(coordinates); buffer.mark(bufferSize); diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala index c9e8497f7e..add3caf225 100644 --- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala +++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala @@ -41,7 +41,6 @@ class InternalApi( extends StaticAnnotation object SedonaContext { - private def customOptimizationsWithSession(sparkSession: SparkSession) = Seq( new TransformNestedUDTParquet(sparkSession), diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala index aece26267d..11263dd7f6 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala @@ -23,7 +23,14 @@ object PythonEvalType { val SQL_SCALAR_SEDONA_UDF = 5200 val SEDONA_UDF_TYPE_CONSTANT = 5000 + // sedona db eval types + val SQL_SCALAR_SEDONA_DB_UDF = 6200 + val SEDONA_DB_UDF_TYPE_CONSTANT = 6000 + def toString(pythonEvalType: Int): String = pythonEvalType match { case SQL_SCALAR_SEDONA_UDF => "SQL_SCALAR_GEO_UDF" + case SQL_SCALAR_SEDONA_DB_UDF => "SQL_SCALAR_SEDONA_DB_UDF" } + + def evals(): Set[Int] = Set(SQL_SCALAR_SEDONA_UDF, SQL_SCALAR_SEDONA_DB_UDF) } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala deleted file mode 100644 index 9449a291f5..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala +++ /dev/null @@ -1,495 +0,0 @@ -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
- */ - -package org.apache.spark - -import java.io.File -import java.net.Socket -import java.util.Locale - -import scala.collection.JavaConverters._ -import scala.collection.concurrent -import scala.collection.mutable -import scala.util.Properties - -import com.google.common.cache.CacheBuilder -import org.apache.hadoop.conf.Configuration - -import org.apache.spark.annotation.DeveloperApi -import org.apache.spark.api.python.PythonWorkerFactory -import org.apache.spark.broadcast.BroadcastManager -import org.apache.spark.executor.ExecutorBackend -import org.apache.spark.internal.{config, Logging} -import org.apache.spark.internal.config._ -import org.apache.spark.memory.{MemoryManager, UnifiedMemoryManager} -import org.apache.spark.metrics.{MetricsSystem, MetricsSystemInstances} -import org.apache.spark.network.netty.{NettyBlockTransferService, SparkTransportConf} -import org.apache.spark.network.shuffle.ExternalBlockStoreClient -import org.apache.spark.rpc.{RpcEndpoint, RpcEndpointRef, RpcEnv} -import org.apache.spark.scheduler.{LiveListenerBus, OutputCommitCoordinator} -import org.apache.spark.scheduler.OutputCommitCoordinator.OutputCommitCoordinatorEndpoint -import org.apache.spark.security.CryptoStreamUtils -import org.apache.spark.serializer.{JavaSerializer, Serializer, SerializerManager} -import org.apache.spark.shuffle.ShuffleManager -import org.apache.spark.storage._ -import org.apache.spark.util.{RpcUtils, Utils} - -/** - * :: DeveloperApi :: - * Holds all the runtime environment objects for a running Spark instance (either master or worker), - * including the serializer, RpcEnv, block manager, map output tracker, etc. Currently - * Spark code finds the SparkEnv through a global variable, so all the threads can access the same - * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext). - */ -@DeveloperApi -class SedonaSparkEnv ( - val executorId: String, - private[spark] val rpcEnv: RpcEnv, - val serializer: Serializer, - val closureSerializer: Serializer, - val serializerManager: SerializerManager, - val mapOutputTracker: MapOutputTracker, - val shuffleManager: ShuffleManager, - val broadcastManager: BroadcastManager, - val blockManager: BlockManager, - val securityManager: SecurityManager, - val metricsSystem: MetricsSystem, - val memoryManager: MemoryManager, - val outputCommitCoordinator: OutputCommitCoordinator, - val conf: SparkConf) extends Logging { - - @volatile private[spark] var isStopped = false - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() - - // A general, soft-reference map for metadata needed during HadoopRDD split computation - // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). - private[spark] val hadoopJobMetadata = - CacheBuilder.newBuilder().maximumSize(1000).softValues().build[String, AnyRef]().asMap() - - private[spark] var driverTmpDir: Option[String] = None - - private[spark] var executorBackend: Option[ExecutorBackend] = None - - private[spark] def stop(): Unit = { - - if (!isStopped) { - isStopped = true - pythonWorkers.values.foreach(_.stop()) - mapOutputTracker.stop() - shuffleManager.stop() - broadcastManager.stop() - blockManager.stop() - blockManager.master.stop() - metricsSystem.stop() - outputCommitCoordinator.stop() - rpcEnv.shutdown() - rpcEnv.awaitTermination() - - // If we only stop sc, but the driver process still run as a services then we need to delete - // the tmp dir, if not, it will create too many tmp dirs. 
- // We only need to delete the tmp dir create by driver - driverTmpDir match { - case Some(path) => - try { - Utils.deleteRecursively(new File(path)) - } catch { - case e: Exception => - logWarning(s"Exception while deleting Spark temp dir: $path", e) - } - case None => // We just need to delete tmp dir created by driver, so do nothing on executor - } - } - } - - private[spark] - def createPythonWorker( - pythonExec: String, - envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { - synchronized { - val key = (pythonExec, envVars) - pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() - } - } - - private[spark] - def destroyPythonWorker(pythonExec: String, - envVars: Map[String, String], worker: Socket): Unit = { - synchronized { - val key = (pythonExec, envVars) - pythonWorkers.get(key).foreach(_.stopWorker(worker)) - } - } - - private[spark] - def releasePythonWorker(pythonExec: String, - envVars: Map[String, String], worker: Socket): Unit = { - synchronized { - val key = (pythonExec, envVars) - pythonWorkers.get(key).foreach(_.releaseWorker(worker)) - } - } -} - -object SedonaSparkEnv extends Logging { - @volatile private var env: SedonaSparkEnv = _ - - private[spark] val driverSystemName = "sparkDriver" - private[spark] val executorSystemName = "sparkExecutor" - - def set(e: SedonaSparkEnv): Unit = { - env = e - } - - /** - * Returns the SparkEnv. - */ - def get: SedonaSparkEnv = { - env - } - - /** - * Create a SparkEnv for the driver. - */ - private[spark] def createDriverEnv( - conf: SparkConf, - isLocal: Boolean, - listenerBus: LiveListenerBus, - numCores: Int, - sparkContext: SparkContext, - mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { - assert(conf.contains(DRIVER_HOST_ADDRESS), - s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!") - assert(conf.contains(DRIVER_PORT), s"${DRIVER_PORT.key} is not set on the driver!") - val bindAddress = conf.get(DRIVER_BIND_ADDRESS) - val advertiseAddress = conf.get(DRIVER_HOST_ADDRESS) - val port = conf.get(DRIVER_PORT) - val ioEncryptionKey = if (conf.get(IO_ENCRYPTION_ENABLED)) { - Some(CryptoStreamUtils.createKey(conf)) - } else { - None - } - create( - conf, - SparkContext.DRIVER_IDENTIFIER, - bindAddress, - advertiseAddress, - Option(port), - isLocal, - numCores, - ioEncryptionKey, - listenerBus = listenerBus, - Option(sparkContext), - mockOutputCommitCoordinator = mockOutputCommitCoordinator - ) - } - - /** - * Create a SparkEnv for an executor. - * In coarse-grained mode, the executor provides an RpcEnv that is already instantiated. - */ - private[spark] def createExecutorEnv( - conf: SparkConf, - executorId: String, - bindAddress: String, - hostname: String, - numCores: Int, - ioEncryptionKey: Option[Array[Byte]], - isLocal: Boolean): SparkEnv = { - val env = create( - conf, - executorId, - bindAddress, - hostname, - None, - isLocal, - numCores, - ioEncryptionKey - ) - SparkEnv.set(env) - env - } - - private[spark] def createExecutorEnv( - conf: SparkConf, - executorId: String, - hostname: String, - numCores: Int, - ioEncryptionKey: Option[Array[Byte]], - isLocal: Boolean): SparkEnv = { - createExecutorEnv(conf, executorId, hostname, - hostname, numCores, ioEncryptionKey, isLocal) - } - - /** - * Helper method to create a SparkEnv for a driver or an executor. 
- */ - // scalastyle:off argcount - private def create( - conf: SparkConf, - executorId: String, - bindAddress: String, - advertiseAddress: String, - port: Option[Int], - isLocal: Boolean, - numUsableCores: Int, - ioEncryptionKey: Option[Array[Byte]], - listenerBus: LiveListenerBus = null, - sc: Option[SparkContext] = None, - mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { - // scalastyle:on argcount - - val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER - - // Listener bus is only used on the driver - if (isDriver) { - assert(listenerBus != null, "Attempted to create driver SparkEnv with null listener bus!") - } - val authSecretFileConf = if (isDriver) AUTH_SECRET_FILE_DRIVER else AUTH_SECRET_FILE_EXECUTOR - val securityManager = new SecurityManager(conf, ioEncryptionKey, authSecretFileConf) - if (isDriver) { - securityManager.initializeAuth() - } - - ioEncryptionKey.foreach { _ => - if (!securityManager.isEncryptionEnabled()) { - logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + - "wire.") - } - } - - val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf, - securityManager, numUsableCores, !isDriver) - - // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. - if (isDriver) { - conf.set(DRIVER_PORT, rpcEnv.address.port) - } - - val serializer = Utils.instantiateSerializerFromConf[Serializer](SERIALIZER, conf, isDriver) - logDebug(s"Using serializer: ${serializer.getClass}") - - val serializerManager = new SerializerManager(serializer, conf, ioEncryptionKey) - - val closureSerializer = new JavaSerializer(conf) - - def registerOrLookupEndpoint( - name: String, endpointCreator: => RpcEndpoint): - RpcEndpointRef = { - if (isDriver) { - logInfo("Registering " + name) - rpcEnv.setupEndpoint(name, endpointCreator) - } else { - RpcUtils.makeDriverRef(name, conf, rpcEnv) - } - } - - val broadcastManager = new BroadcastManager(isDriver, conf) - - val mapOutputTracker = if (isDriver) { - new MapOutputTrackerMaster(conf, broadcastManager, isLocal) - } else { - new MapOutputTrackerWorker(conf) - } - - // Have to assign trackerEndpoint after initialization as MapOutputTrackerEndpoint - // requires the MapOutputTracker itself - mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, - new MapOutputTrackerMasterEndpoint( - rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) - - // Let the user specify short names for shuffle managers - val shortShuffleMgrNames = Map( - "sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName, - "tungsten-sort" -> classOf[org.apache.spark.shuffle.sort.SortShuffleManager].getName) - val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER) - val shuffleMgrClass = - shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) - val shuffleManager = Utils.instantiateSerializerOrShuffleManager[ShuffleManager]( - shuffleMgrClass, conf, isDriver) - - val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores) - - val blockManagerPort = if (isDriver) { - conf.get(DRIVER_BLOCK_MANAGER_PORT) - } else { - conf.get(BLOCK_MANAGER_PORT) - } - - val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { - val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - Some(new 
ExternalBlockStoreClient(transConf, securityManager, - securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) - } else { - None - } - - // Mapping from block manager id to the block manager's information. - val blockManagerInfo = new concurrent.TrieMap[BlockManagerId, BlockManagerInfo]() - val blockManagerMaster = new BlockManagerMaster( - registerOrLookupEndpoint( - BlockManagerMaster.DRIVER_ENDPOINT_NAME, - new BlockManagerMasterEndpoint( - rpcEnv, - isLocal, - conf, - listenerBus, - if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { - externalShuffleClient - } else { - None - }, blockManagerInfo, - mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], - shuffleManager, - isDriver)), - registerOrLookupEndpoint( - BlockManagerMaster.DRIVER_HEARTBEAT_ENDPOINT_NAME, - new BlockManagerMasterHeartbeatEndpoint(rpcEnv, isLocal, blockManagerInfo)), - conf, - isDriver) - - val blockTransferService = - new NettyBlockTransferService(conf, securityManager, serializerManager, bindAddress, - advertiseAddress, blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint) - - // NB: blockManager is not valid until initialize() is called later. - val blockManager = new BlockManager( - executorId, - rpcEnv, - blockManagerMaster, - serializerManager, - conf, - memoryManager, - mapOutputTracker, - shuffleManager, - blockTransferService, - securityManager, - externalShuffleClient) - - val metricsSystem = if (isDriver) { - // Don't start metrics system right now for Driver. - // We need to wait for the task scheduler to give us an app ID. - // Then we can start the metrics system. - MetricsSystem.createMetricsSystem(MetricsSystemInstances.DRIVER, conf) - } else { - // We need to set the executor ID before the MetricsSystem is created because sources and - // sinks specified in the metrics configuration file will want to incorporate this executor's - // ID into the metrics they report. - conf.set(EXECUTOR_ID, executorId) - val ms = MetricsSystem.createMetricsSystem(MetricsSystemInstances.EXECUTOR, conf) - ms.start(conf.get(METRICS_STATIC_SOURCES_ENABLED)) - ms - } - - val outputCommitCoordinator = mockOutputCommitCoordinator.getOrElse { - if (isDriver) { - new OutputCommitCoordinator(conf, isDriver, sc) - } else { - new OutputCommitCoordinator(conf, isDriver) - } - - } - val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", - new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) - outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) - - val envInstance = new SparkEnv( - executorId, - rpcEnv, - serializer, - closureSerializer, - serializerManager, - mapOutputTracker, - shuffleManager, - broadcastManager, - blockManager, - securityManager, - metricsSystem, - memoryManager, - outputCommitCoordinator, - conf) - - // Add a reference to tmp dir created by driver, we will delete this tmp dir when stop() is - // called, and we only need to do it for driver. Because driver may run as a service, and if we - // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. - if (isDriver) { - val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath - envInstance.driverTmpDir = Some(sparkFilesDir) - } - - envInstance - } - - /** - * Return a map representation of jvm information, Spark properties, system properties, and - * class paths. Map keys define the category, and map values represent the corresponding - * attributes as a sequence of KV pairs. 
This is used mainly for SparkListenerEnvironmentUpdate. - */ - private[spark] def environmentDetails( - conf: SparkConf, - hadoopConf: Configuration, - schedulingMode: String, - addedJars: Seq[String], - addedFiles: Seq[String], - addedArchives: Seq[String], - metricsProperties: Map[String, String]): Map[String, Seq[(String, String)]] = { - - import Properties._ - val jvmInformation = Seq( - ("Java Version", s"$javaVersion ($javaVendor)"), - ("Java Home", javaHome), - ("Scala Version", versionString) - ).sorted - - // Spark properties - // This includes the scheduling mode whether or not it is configured (used by SparkUI) - val schedulerMode = - if (!conf.contains(SCHEDULER_MODE)) { - Seq((SCHEDULER_MODE.key, schedulingMode)) - } else { - Seq.empty[(String, String)] - } - val sparkProperties = (conf.getAll ++ schedulerMode).sorted - - // System properties that are not java classpaths - val systemProperties = Utils.getSystemProperties.toSeq - val otherProperties = systemProperties.filter { case (k, _) => - k != "java.class.path" && !k.startsWith("spark.") - }.sorted - - // Class paths including all added jars and files - val classPathEntries = javaClassPath - .split(File.pathSeparator) - .filterNot(_.isEmpty) - .map((_, "System Classpath")) - val addedJarsAndFiles = (addedJars ++ addedFiles ++ addedArchives).map((_, "Added By User")) - val classPaths = (addedJarsAndFiles ++ classPathEntries).sorted - - // Add Hadoop properties, it will not ignore configs including in Spark. Some spark - // conf starting with "spark.hadoop" may overwrite it. - val hadoopProperties = hadoopConf.asScala - .map(entry => (entry.getKey, entry.getValue)).toSeq.sorted - Map[String, Seq[(String, String)]]( - "JVM Information" -> jvmInformation, - "Spark Properties" -> sparkProperties, - "Hadoop Properties" -> hadoopProperties, - "System Properties" -> otherProperties, - "Classpath Entries" -> classPaths, - "Metrics Properties" -> metricsProperties.toSeq.sorted) - } -} - diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala index 16b81b50c0..598cd830e9 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala @@ -1,21 +1,22 @@ -package org.apache.spark.sql.execution.python - /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. 
You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ +package org.apache.spark.sql.execution.python import org.apache.spark.api.python._ import org.apache.spark.sql.catalyst.InternalRow @@ -28,23 +29,25 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. */ class SedonaArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - protected override val schema: StructType, - protected override val timeZoneId: String, - protected override val largeVarTypes: Boolean, - protected override val workerConf: Map[String, String], - val pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends BasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, evalType, argOffsets, jobArtifactUUID) + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + protected override val schema: StructType, + protected override val timeZoneId: String, + protected override val largeVarTypes: Boolean, + protected override val workerConf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + funcs, + evalType, + argOffsets, + jobArtifactUUID) with SedonaBasicPythonArrowInput with BasicPythonArrowOutput { override val pythonExec: String = - SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( - funcs.head.funcs.head.pythonExec) + SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(funcs.head.funcs.head.pythonExec) override val errorOnDuplicatedFieldNames: Boolean = true diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala index ff3c027c5d..c025653029 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala @@ -19,19 +19,17 @@ package org.apache.spark.sql.execution.python import org.apache.sedona.sql.UDF.PythonEvalType +import org.apache.sedona.sql.UDF.PythonEvalType.{SQL_SCALAR_SEDONA_DB_UDF, SQL_SCALAR_SEDONA_UDF} import org.apache.spark.api.python.ChainedPythonFunctions import org.apache.spark.sql.Strategy import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences -import org.apache.spark.sql.catalyst.expressions.UnsafeProjection -import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF} import 
org.apache.spark.sql.catalyst.plans.logical.LogicalPlan import org.apache.spark.sql.execution.SparkPlan -import org.apache.spark.sql.internal.SQLConf import org.apache.spark.sql.types.StructType import org.apache.spark.sql.udf.SedonaArrowEvalPython import org.apache.spark.{JobArtifactSet, TaskContext} + import scala.collection.JavaConverters.asScalaIteratorConverter // We use custom Strategy to avoid Apache Spark assert on types, we @@ -71,19 +69,38 @@ case class SedonaArrowEvalPythonExec( val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) - val columnarBatchIter = new SedonaArrowPythonRunner( - funcs, - evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, - argOffsets, - schema, - sessionLocalTimeZone, - largeVarTypes, - pythonRunnerConf, - pythonMetrics, - jobArtifactUUID).compute(batchIter, context.partitionId(), context) + evalType match { + case SQL_SCALAR_SEDONA_DB_UDF => + val columnarBatchIter = new SedonaArrowPythonRunner( + funcs, + evalType - PythonEvalType.SEDONA_DB_UDF_TYPE_CONSTANT, + argOffsets, + schema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID).compute(batchIter, context.partitionId(), context) + + columnarBatchIter.flatMap { batch => + batch.rowIterator.asScala + } + + case SQL_SCALAR_SEDONA_UDF => + val columnarBatchIter = new ArrowPythonRunner( + funcs, + evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT, + argOffsets, + schema, + sessionLocalTimeZone, + largeVarTypes, + pythonRunnerConf, + pythonMetrics, + jobArtifactUUID).compute(batchIter, context.partitionId(), context) - columnarBatchIter.flatMap { batch => - batch.rowIterator.asScala + columnarBatchIter.flatMap { batch => + batch.rowIterator.asScala + } } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala new file mode 100644 index 0000000000..b3e0878c91 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala @@ -0,0 +1,200 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ +package org.apache.spark.sql.execution.python + +import org.apache.sedona.spark.SedonaContext + +import java.io._ +import java.net._ +import java.nio.file.Path +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ +import org.apache.spark._ +import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonAccumulatorV2, PythonException, PythonRDD, PythonRunner, SpecialLengths} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.EXECUTOR_CORES +import org.apache.spark.internal.config.Python._ +import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} +import org.apache.spark.util._ + +private object SedonaBasePythonRunner { + + private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler") + + private def faultHandlerLogPath(pid: Int): Path = { + new File(faultHandlerLogDir, pid.toString).toPath + } +} + +/** + * A helper class to run Python mapPartition/UDFs in Spark. + * + * funcs is a list of independent Python functions, each one of them is a list of chained Python + * functions (from bottom to top). + */ +private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + jobArtifactUUID: Option[String]) + extends BasePythonRunner[IN, OUT](funcs, evalType, argOffsets, jobArtifactUUID) + with Logging { + + require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") + + private val conf = SparkEnv.get.conf + private val reuseWorker = conf.get(PYTHON_WORKER_REUSE) + private val faultHandlerEnabled = conf.get(PYTHON_WORKER_FAULTHANLDER_ENABLED) + + private def getWorkerMemoryMb(mem: Option[Long], cores: Int): Option[Long] = { + mem.map(_ / cores) + } + + override def compute( + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): Iterator[OUT] = { + val startTime = System.currentTimeMillis + val env = SparkEnv.get + + // Get the executor cores and pyspark memory, they are passed via the local properties when + // the user specified them in a ResourceProfile. + val execCoresProp = Option(context.getLocalProperty(EXECUTOR_CORES_LOCAL_PROPERTY)) + val memoryMb = Option(context.getLocalProperty(PYSPARK_MEMORY_LOCAL_PROPERTY)).map(_.toLong) + val localdir = env.blockManager.diskBlockManager.localDirs.map(f => f.getPath()).mkString(",") + // If OMP_NUM_THREADS is not explicitly set, override it with the number of task cpus. + // See SPARK-42613 for details. + if (conf.getOption("spark.executorEnv.OMP_NUM_THREADS").isEmpty) { + envVars.put("OMP_NUM_THREADS", conf.get("spark.task.cpus", "1")) + } + envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread + if (reuseWorker) { + envVars.put("SPARK_REUSE_WORKER", "1") + } + if (simplifiedTraceback) { + envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") + } + // SPARK-30299 this could be wrong with standalone mode when executor + // cores might not be correct because it defaults to all cores on the box. 
+ val execCores = execCoresProp.map(_.toInt).getOrElse(conf.get(EXECUTOR_CORES)) + val workerMemoryMb = getWorkerMemoryMb(memoryMb, execCores) + if (workerMemoryMb.isDefined) { + envVars.put("PYSPARK_EXECUTOR_MEMORY_MB", workerMemoryMb.get.toString) + } + envVars.put("SPARK_AUTH_SOCKET_TIMEOUT", authSocketTimeout.toString) + envVars.put("SPARK_BUFFER_SIZE", bufferSize.toString) + if (faultHandlerEnabled) { + envVars.put("PYTHON_FAULTHANDLER_DIR", SedonaBasePythonRunner.faultHandlerLogDir.toString) + } + + envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) + + val (worker: Socket, pid: Option[Int]) = + WorkerContext.createPythonWorker(pythonExec, envVars.asScala.toMap) + // Whether is the worker released into idle pool or closed. When any codes try to release or + // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make + // sure there is only one winner that is going to release or close the worker. + val releasedOrClosed = new AtomicBoolean(false) + + // Start a thread to feed the process input from our parent's iterator + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + + context.addTaskCompletionListener[Unit] { _ => + writerThread.shutdownOnTaskCompletion() + if (!reuseWorker || releasedOrClosed.compareAndSet(false, true)) { + try { + worker.close() + } catch { + case e: Exception => + logWarning("Failed to close worker socket", e) + } + } + } + + writerThread.start() + new SedonaMonitorThread(SparkEnv.get, worker, writerThread, context).start() + if (reuseWorker) { + val key = (worker, context.taskAttemptId) + // SPARK-35009: avoid creating multiple monitor threads for the same python worker + // and task context + if (PythonRunner.runningMonitorThreads.add(key)) { + new MonitorThread(SparkEnv.get, worker, context).start() + } + } else { + new MonitorThread(SparkEnv.get, worker, context).start() + } + + // Return an iterator that read lines from the process's stdout + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) + + val stdoutIterator = newReaderIterator( + stream, + writerThread, + startTime, + env, + worker, + pid, + releasedOrClosed, + context) + new InterruptibleIterator(context, stdoutIterator) + } + + private class SedonaMonitorThread( + env: SparkEnv, + worker: Socket, + writerThread: WriterThread, + context: TaskContext) + extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") { + + /** + * How long to wait before closing the socket if the writer thread has not exited after the + * task ends. + */ + private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) + + setDaemon(true) + + override def run(): Unit = { + // Wait until the task is completed (or the writer thread exits, in which case this thread has + // nothing to do). + while (!context.isCompleted && writerThread.isAlive) { + Thread.sleep(2000) + } + if (writerThread.isAlive) { + Thread.sleep(taskKillTimeout) + // If the writer thread continues running, this indicates a deadlock. Kill the worker to + // resolve the deadlock. + if (writerThread.isAlive) { + try { + // Mimic the task name used in `Executor` to help the user find out the task to blame. 
+ val taskName = s"${context.partitionId}.${context.attemptNumber} " + + s"in stage ${context.stageId} (TID ${context.taskAttemptId})" + logWarning( + s"Detected deadlock while completing task $taskName: " + + "Attempting to kill Python Worker") + WorkerContext.destroyPythonWorker(pythonExec, envVars.asScala.toMap, worker) + } catch { + case e: Exception => + logError("Exception when trying to kill worker", e) + } + } + } + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala new file mode 100644 index 0000000000..db46ff6d8c --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala @@ -0,0 +1,116 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.execution.python + +import org.apache.spark.{SparkException, SparkFiles} +import org.apache.spark.api.python.{PythonUtils, PythonWorkerFactory} +import org.apache.spark.util.Utils + +import java.io.{DataInputStream, File} +import java.net.{InetAddress, ServerSocket, Socket} +import java.util.Arrays +import java.io.InputStream +import scala.collection.JavaConverters._ +import scala.collection.mutable +import org.apache.spark._ +import org.apache.spark.security.SocketAuthHelper +import org.apache.spark.util.RedirectThread + +class SedonaDBWorkerFactory(pythonExec: String, envVars: Map[String, String]) + extends PythonWorkerFactory(pythonExec, envVars) { + self => + + private val sedonaWorkerModule = "sedonaworker.work" + + private val simpleWorkers = new mutable.WeakHashMap[Socket, Process]() + private val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + + private val pythonPath = PythonUtils.mergePythonPaths( + PythonUtils.sparkPythonPath, + envVars.getOrElse("PYTHONPATH", ""), + sys.env.getOrElse("PYTHONPATH", "")) + + override def create(): (Socket, Option[Int]) = { + createSimpleWorker(sedonaWorkerModule) + } + + private def createSimpleWorker(workerModule: String): (Socket, Option[Int]) = { + var serverSocket: ServerSocket = null + try { + serverSocket = new ServerSocket(0, 1, InetAddress.getLoopbackAddress()) + + // Create and start the worker + val pb = new ProcessBuilder(Arrays.asList(pythonExec, "-m", workerModule)) + val jobArtifactUUID = envVars.getOrElse("SPARK_JOB_ARTIFACT_UUID", "default") + if (jobArtifactUUID != "default") { + val f = new File(SparkFiles.getRootDirectory(), jobArtifactUUID) + f.mkdir() + pb.directory(f) + } + val workerEnv = pb.environment() + workerEnv.putAll(envVars.asJava) + workerEnv.put("PYTHONPATH", pythonPath) + // This is equivalent to setting the -u flag; we use it because ipython doesn't support 
-u: + workerEnv.put("PYTHONUNBUFFERED", "YES") + workerEnv.put("PYTHON_WORKER_FACTORY_PORT", serverSocket.getLocalPort.toString) + workerEnv.put("PYTHON_WORKER_FACTORY_SECRET", authHelper.secret) + if (Utils.preferIPv6) { + workerEnv.put("SPARK_PREFER_IPV6", "True") + } + val worker = pb.start() + + // Redirect worker stdout and stderr + redirectStreamsToStderr(worker.getInputStream, worker.getErrorStream) + + // Wait for it to connect to our socket, and validate the auth secret. + serverSocket.setSoTimeout(10000) + + try { + val socket = serverSocket.accept() + authHelper.authClient(socket) + // TODO: When we drop JDK 8, we can just use worker.pid() + val pid = new DataInputStream(socket.getInputStream).readInt() + if (pid < 0) { + throw new IllegalStateException("Python failed to launch worker with code " + pid) + } + self.synchronized { + simpleWorkers.put(socket, worker) + } + return (socket, Some(pid)) + } catch { + case e: Exception => + throw new SparkException("Python worker failed to connect back.", e) + } + } finally { + if (serverSocket != null) { + serverSocket.close() + } + } + } + + private def redirectStreamsToStderr(stdout: InputStream, stderr: InputStream): Unit = { + try { + new RedirectThread(stdout, System.err, "stdout reader for " + pythonExec).start() + new RedirectThread(stderr, System.err, "stderr reader for " + pythonExec).start() + } catch { + case e: Exception => + logError("Exception in redirecting streams", e) + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala index 5567ef28b5..fee3c22e64 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ package org.apache.spark.sql.execution.python /* @@ -20,13 +38,10 @@ package org.apache.spark.sql.execution.python import org.apache.arrow.vector.VectorSchemaRoot import org.apache.arrow.vector.ipc.ArrowStreamWriter import org.apache.spark.api.python -import org.apache.spark.api.python.{BasePythonRunner, ChainedPythonFunctions, PythonRDD} import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.expressions.GenericInternalRow import org.apache.spark.sql.execution.arrow.ArrowWriter -import org.apache.spark.sql.execution.metric.SQLMetric import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils import org.apache.spark.{SparkEnv, TaskContext} @@ -35,54 +50,56 @@ import java.io.DataOutputStream import java.net.Socket /** - * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from - * JVM (an iterator of internal rows + additional data if required) to Python (Arrow). + * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from JVM + * (an iterator of internal rows + additional data if required) to Python (Arrow). */ -private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => - +private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { + self: SedonaBasePythonRunner[IN, _] => protected override def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): WriterThread = { + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { handleMetadataBeforeExec(dataOut) writeUDF(dataOut, funcs, argOffsets) - val toReadCRS = inputIterator.buffered.headOption.flatMap( - el => el.asInstanceOf[Iterator[IN]].buffered.headOption - ) + val toReadCRS = inputIterator.buffered.headOption.flatMap(el => + el.asInstanceOf[Iterator[IN]].buffered.headOption) val row = toReadCRS match { - case Some(value) => value match { - case row: GenericInternalRow => - Some(row) - } + case Some(value) => + value match { + case row: GenericInternalRow => + Some(row) + } case None => None } - val geometryFields = schema.zipWithIndex.filter { - case (field, index) => field.dataType == GeometryUDT - }.map { - case (field, index) => - if (row.isEmpty || row.get.values(index) == null) (index, 0) else { + val geometryFields = schema.zipWithIndex + .filter { case (field, index) => + field.dataType == GeometryUDT + } + .map { case (field, index) => + if (row.isEmpty || row.get.values(index) == null) (index, 0) + else { val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] - val preambleByte = geom(0) & 0xFF + val preambleByte = geom(0) & 0xff val hasSrid = (preambleByte & 0x01) != 0 var srid = 0 if (hasSrid) { - val srid2 = (geom(1) & 0xFF) << 16 - val srid1 = (geom(2) & 0xFF) << 8 - val srid0 = geom(3) & 0xFF + val srid2 = (geom(1) & 0xff) << 16 + val srid1 = (geom(2) & 0xff) << 8 + val srid0 = geom(3) & 0xff srid = srid2 | srid1 | srid0 } (index, srid) } - } + } // write number of geometry fields dataOut.writeInt(geometryFields.length) @@ -94,10 +111,12 @@ private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { } 
protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val arrowSchema = ArrowUtils.toArrowSchema( - schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + val arrowSchema = + ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for $pythonExec", 0, Long.MaxValue) + s"stdout writer for $pythonExec", + 0, + Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { @@ -129,15 +148,15 @@ private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { } } - -private[python] trait SedonaBasicPythonArrowInput extends SedonaPythonArrowInput[Iterator[InternalRow]] { - self: BasePythonRunner[Iterator[InternalRow], _] => +private[python] trait SedonaBasicPythonArrowInput + extends SedonaPythonArrowInput[Iterator[InternalRow]] { + self: SedonaBasePythonRunner[Iterator[InternalRow], _] => protected def writeIteratorToArrowStream( - root: VectorSchemaRoot, - writer: ArrowStreamWriter, - dataOut: DataOutputStream, - inputIterator: Iterator[Iterator[InternalRow]]): Unit = { + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[Iterator[InternalRow]]): Unit = { val arrowWriter = ArrowWriter.create(root) while (inputIterator.hasNext) { diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala similarity index 51% copy from spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala copy to spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala index aece26267d..c1193cb7fa 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala @@ -16,14 +16,33 @@ * specific language governing permissions and limitations * under the License. 
*/ -package org.apache.sedona.sql.UDF +package org.apache.spark.sql.execution.python -// We use constant 5000 for Sedona UDFs, 200 is Apache Spark scalar UDF -object PythonEvalType { - val SQL_SCALAR_SEDONA_UDF = 5200 - val SEDONA_UDF_TYPE_CONSTANT = 5000 +import java.net.Socket +import scala.collection.mutable - def toString(pythonEvalType: Int): String = pythonEvalType match { - case SQL_SCALAR_SEDONA_UDF => "SQL_SCALAR_GEO_UDF" +object WorkerContext { + + def createPythonWorker( + pythonExec: String, + envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.getOrElseUpdate(key, new SedonaDBWorkerFactory(pythonExec, envVars)).create() + } + } + + private[spark] def destroyPythonWorker( + pythonExec: String, + envVars: Map[String, String], + worker: Socket): Unit = { + synchronized { + val key = (pythonExec, envVars) + pythonWorkers.get(key).foreach(_.stopWorker(worker)) + } } + + private val pythonWorkers = + mutable.HashMap[(String, Map[String, String]), SedonaDBWorkerFactory]() + } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala index 3d3301580c..ebb5a568e1 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala @@ -44,9 +44,7 @@ class ExtractSedonaUDFRule extends Rule[LogicalPlan] with Logging { } def isScalarPythonUDF(e: Expression): Boolean = { - e.isInstanceOf[PythonUDF] && e - .asInstanceOf[PythonUDF] - .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF + e.isInstanceOf[PythonUDF] && PythonEvalType.evals.contains(e.asInstanceOf[PythonUDF].evalType) } private def collectEvaluableUDFsFromExpressions( @@ -168,13 +166,12 @@ class ExtractSedonaUDFRule extends Rule[LogicalPlan] with Logging { evalTypes.mkString(",")) } val evalType = evalTypes.head - val evaluation = evalType match { - case PythonEvalType.SQL_SCALAR_SEDONA_UDF => - SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType) - case _ => - throw new IllegalStateException("Unexpected UDF evalType") + if (!PythonEvalType.evals().contains(evalType)) { + throw new IllegalStateException(s"Unexpected UDF evalType: $evalType") } + val evaluation = SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType) + attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs) evaluation } else { diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala index 5198008392..421890c700 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/GeoParquetMetadataTests.scala @@ -18,157 +18,19 @@ */ package org.apache.sedona.sql -import org.apache.sedona.sql.utils.GeometrySerializer import org.apache.spark.sql.Row import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.types.{IntegerType, StructField, StructType} -import org.locationtech.jts.geom.Geometry -import org.locationtech.jts.io.WKTReader import org.scalatest.BeforeAndAfterAll import java.util.Collections import scala.collection.JavaConverters._ -case class GeoDataHex(id: Int, geometry_hex: String) -case class GeoData(id: Int, geometry: Geometry) - class GeoParquetMetadataTests extends 
TestBaseScala with BeforeAndAfterAll { val geoparquetdatalocation: String = resourceFolder + "geoparquet/" val geoparquetoutputlocation: String = resourceFolder + "geoparquet/geoparquet_output/" - import sparkSession.implicits._ - describe("GeoParquet Metadata tests") { - it("reading and writing GeoParquet files") { -// 'POINT(30.0123 10.2131)', \ - // 'POINT(-20 20)', \ - // 'POINT(10 30)', \ - // 'POINT(40 -40)' \ - -// [0] = {u8} 18 -//[1] = {u8} 0 -//[2] = {u8} 0 -//[3] = {u8} 0 -//[4] = {u8} 1 -//[5] = {u8} 165 -//[6] = {u8} 189 -//[7] = {u8} 193 -//[8] = {u8} 23 -//[9] = {u8} 38 -//[10] = {u8} 3 -//[11] = {u8} 62 -//[12] = {u8} 64 -//[13] = {u8} 34 -//[14] = {u8} 142 -//[15] = {u8} 117 -//[16] = {u8} 113 -//[17] = {u8} 27 -//[18] = {u8} 109 -//[19] = {u8} 36 -//[20] = {u8} 64 - val byteArray = Array[Int]( - 18, 0, 0, 0, 1, - 165, 189, 193, 23, 38, 3, 62, 64, 34, 142, 117, 113, 27, 109, 36, 64 - ) - .map(_.toByte) - -// [ 18, 0, 0, 0, 1, -91, -67, -63, 23, 38, 3, 62, 64, 34, -114, 117, 113, 27, 109, 36, 64 ] - -// GeometrySerializer.deserialize(byteArray) -// [18, 0, 0, 0, 1, -91, -67, -63, 23, 38, 3, 62, 64, 34, -114, 117, 113, 27, 109, 36, 64] -// 18 18 -// 0 0 -// 0 0 -// 0 0 -// 1 1 -// 0 -// 0 -// 0 -// -91 -91 -// -67 -67 -// -63 -63 -// 23 23 -// 38 38 -// 3 3 -// 62 62 -// 64 64 -// 34 34 -// -114 -114 -// 117 117 -// 113 113 -// 27 27 -// 109 109 -// 36 36 -// 64 64 - -// val wktReader = new WKTReader() -// val pointWKT = "POINT(30.0123 10.2131)" -// val point = wktReader.read(pointWKT) -// val serializedBytes = GeometrySerializer.serialize(point) -// serializedBytes.foreach( -// byte => println(byte) -// ) -// - def hexToBytes(hex: String): Array[Byte] = - hex.grouped(2).map(Integer.parseInt(_, 16).toByte).toArray -// - def bytesToHex(bytes: Array[Byte]): String = - bytes.map("%02x".format(_)).mkString - -// Seq( -// (1, "POINT(30.0123 10.2131)"), -//// (2, "POINT(-20 20)"), -//// (3, "POINT(10 30)"), -//// (4, "POINT(40 -40)") -// ).toDF("id", "wkt") -// .selectExpr("id", "ST_GeomFromWKT(wkt) AS geometry") -// .as[GeoData] -// .map( -// row => (row.id, bytesToHex(GeometrySerializer.serialize(row.geometry))) -// ).show(4, false) - - -// -// val data = Seq( -// (1, "1200000001000000a5bdc11726033e40228e75711b6d2440"), -// ) -// .toDF("id", "geometry_hex") -// .as[GeoDataHex] -// -// data.map( -// row => GeoData(row.id, GeometrySerializer.deserialize(hexToBytes(row.geometry_hex))) -// ).show - - val wkt = "LINESTRING ( 20.9972017 52.1696936, 20.9971687 52.1696659, 20.997156 52.169644, 20.9971487 52.1696213 ) " - val reader = new WKTReader() - val geometry = reader.read(wkt) - val serialized = GeometrySerializer.serialize(geometry) - - Seq( - (1, serialized) - ).toDF("id", "geometry_bytes") - .show(1, false) - -// println(bytesToHex(serialized)) - -// -// val binaryData = "1200000001000000bb9d61f7b6c92c40f1ba168a85" -// val binaryData2 = "120000000100000046b6f3fdd4083e404e62105839342440" -// val value = new GeometryUDT().deserialize(hexToBytes(binaryData)) -// val value3 = new GeometryUDT().deserialize(hexToBytes(binaryData2)) -// println(value) -// println(value3) -// -// val reader = new WKTReader() -// val geometryPoint = "POINT (30.0345 10.1020)" -// val point = reader.read(geometryPoint) -// val result = new GeometryUDT().serialize(point) -// -// val value2 = new GeometryUDT().deserialize(result) -// println(bytesToHex(result)) -// println(value2) -// println("ssss") - } it("Reading GeoParquet Metadata") { val df = 
sparkSession.read.format("geoparquet.metadata").load(geoparquetdatalocation) val metadataArray = df.collect() diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index c92af92cdd..9dc6677035 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -19,79 +19,203 @@ package org.apache.spark.sql.udf import org.apache.sedona.spark.SedonaContext +import org.apache.sedona.sql.TestBaseScala import org.apache.spark.sql.SparkSession +import org.apache.spark.sql.execution.python.WorkerContext import org.apache.spark.sql.functions.{col, expr, lit} -import org.apache.spark.sql.udf.ScalarUDF.geometryToGeometryFunction +import org.apache.spark.sql.udf.ScalarUDF.{geometryToGeometryFunction, nonGeometryVectorizedUDF, nonGeometryVectorizedUDF2} import org.locationtech.jts.io.WKTReader import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers -class StrategySuite extends AnyFunSuite with Matchers { +class StrategySuite extends TestBaseScala with Matchers { val wktReader = new WKTReader() val spark: SparkSession = { - val builder = SedonaContext - .builder() - .master("local[*]") - .appName("sedonasqlScalaTest") - - val spark = SedonaContext.create(builder.getOrCreate()) - - spark.sparkContext.setLogLevel("ALL") - spark + sparkSession.sparkContext.setLogLevel("ALL") + sparkSession } import spark.implicits._ - test("sedona geospatial UDF") { + it("sedona geospatial UDF - geopandas") { + val df = Seq( + (1, "value", wktReader.read("POINT(21 52)")), + (2, "value1", wktReader.read("POINT(20 50)")), + (3, "value2", wktReader.read("POINT(20 49)")), + (4, "value3", wktReader.read("POINT(20 48)")), + (5, "value4", wktReader.read("POINT(20 47)"))) + .toDF("id", "value", "geom") + .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) + + df.count shouldEqual 5 + + df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") + .as[String] + .collect() should contain theSameElementsAs Seq( + "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", + "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", + "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", + "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", + "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") + } + + it("sedona geospatial UDF") { // spark.sql("select 1").show() - val df = spark.read.format("geoparquet") + val df = spark.read + .format("geoparquet") .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')")) df.show() - df + df .select( col("id"), col("version"), col("bbox"), -// geometryToNonGeometryFunction(col("geometry")), + // nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"), geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"), -// nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - ).show(10) + nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"), + // nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"), + // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") + // geometryToNonGeometryFunction(col("geometry")), + // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") + // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), + ) + .show(10) - println(df + df .select( col("id"), col("version"), col("bbox"), - // 
+  it("sedona geospatial UDF") {
     // spark.sql("select 1").show()
-    val df = spark.read.format("geoparquet")
+    val df = spark.read
+      .format("geoparquet")
       .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings")
       .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')"))

     df.show()

-    df
+    df
       .select(
         col("id"),
         col("version"),
         col("bbox"),
-//        geometryToNonGeometryFunction(col("geometry")),
+        // nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"),
         geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"),
-//        nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
-      ).show(10)
+        nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"),
+        // nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"),
+        // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+        // geometryToNonGeometryFunction(col("geometry")),
+        // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+        // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+      )
+      .show(10)

-    println(df
+    df
       .select(
         col("id"),
         col("version"),
         col("bbox"),
-//        geometryToNonGeometryFunction(col("geometry")),
+        // nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"),
         geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"),
+        nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"),
+        // nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"),
+        // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+        // geometryToNonGeometryFunction(col("geometry")),
+        // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+        // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+      )
+      .explain(true)
+//
+//    df
+//      .select(
+//        col("id"),
+//        col("version"),
+//        col("bbox"),
+////        nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"),
+//        geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"),
+////        nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"),
+////        nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"),
+////        nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"),
+////        geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+////        geometryToNonGeometryFunction(col("geometry")),
+////        geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+////        nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+//      )
+//      .show(10)
+  }
+
+  it("sedona db 1 geospatial UDF") {
+    // spark.sql("select 1").show()
+    val df = spark.read
+      .format("geoparquet")
+      .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings")
+      .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')"))
+
+    df.show()
+
+    df.printSchema()
+
+    df
+      .select(
+        col("id"),
+        col("version"),
+        col("bbox"),
+        // geometryToNonGeometryFunction(col("geometry")),
+        geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
         // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
-      ).count())
+      )
+      .show(10)

-//    df.show()
+    println(
+      df
+        .select(
+          col("id"),
+          col("version"),
+          col("bbox"),
+          // geometryToNonGeometryFunction(col("geometry")),
+          geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+          // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+        )
+        .count())
+
+    WorkerContext
+
+    // df.show()
     1 shouldBe 1
-//    val df = Seq(
-//      (1, "value", wktReader.read("POINT(21 52)")),
-//      (2, "value1", wktReader.read("POINT(20 50)")),
-//      (3, "value2", wktReader.read("POINT(20 49)")),
-//      (4, "value3", wktReader.read("POINT(20 48)")),
-//      (5, "value4", wktReader.read("POINT(20 47)")))
-//      .toDF("id", "value", "geom")
-//      .withColumn("geom_buffer", geoPandasScalaFunction(col("geom")))
-
-//    df.count shouldEqual 5
-
-//    df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))")
-//      .as[String]
-//      .collect() should contain theSameElementsAs Seq(
-//      "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))",
-//      "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))",
-//      "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))",
-//      "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))",
-//      "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))")
   }

+  it("sedona db geospatial UDF") {
+//    spark.sql("select 1").show()
+    val df = spark.read
+      .format("geoparquet")
+      .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings")
+      .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')"))
+
+    df.show()
+
+    df
+      .select(
+        col("id"),
+        col("version"),
+        col("bbox"),
+        // geometryToNonGeometryFunction(col("geometry")),
+        geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+        // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+      )
+      .show(10)
+
+    println(
+      df
+        .select(
+          col("id"),
+          col("version"),
+          col("bbox"),
+          // geometryToNonGeometryFunction(col("geometry")),
+          geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom")
+          // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+        )
+        .count())
+
+    WorkerContext
+
+    // df.show()
+    1 shouldBe 1
+  }
+
 }
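Note: the suite above runs the UDFs against a hard-coded local GeoParquet path. A self-contained variant of the pandas-vectorized check could instead use an in-memory DataFrame; the sketch below assumes it sits inside StrategySuite and that nonGeometryVectorizedUDF2 (defined in ScalarUDF in the next file) adds 3.0 to a float column.

    it("vectorized non-geometry UDF adds 3.0") {
      import spark.implicits._
      // build a tiny in-memory frame and push it through the pickled pandas UDF
      val result = Seq(1.0f, 2.0f, 3.0f)
        .toDF("x")
        .select(nonGeometryVectorizedUDF2(col("x")).alias("y"))
        .as[Float]
        .collect()

      result should contain theSameElementsAs Seq(4.0f, 5.0f, 6.0f)
    }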
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
index bf8134e423..cc83f5f852 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
@@ -73,29 +73,54 @@ object ScalarUDF {
   }

   val additionalModule = "spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf"
-//
-//  val geopandasGeometryToNonGeometry: Array[Byte] = {
-//    var binaryPandasFunc: Array[Byte] = null
-//    withTempPath { path =>
-//      Process(
-//        Seq(
-//          pythonExec,
-//          "-c",
-//          f"""
-//             |from pyspark.sql.types import FloatType
-//             |from pyspark.serializers import CloudPickleSerializer
-//             |f = open('$path', 'wb');
-//             |def apply_geopandas(x):
-//             |    return x.area
-//             |f.write(CloudPickleSerializer().dumps((apply_geopandas, FloatType())))
-//             |""".stripMargin),
-//        None,
-//        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
-//      binaryPandasFunc = Files.readAllBytes(path.toPath)
-//    }
-//    assert(binaryPandasFunc != null)
-//    binaryPandasFunc
-//  }
+
+  val vectorizedFunction2: Array[Byte] = {
+    var binaryPandasFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          pythonExec,
+          "-c",
+          f"""
+             |from pyspark.sql.types import FloatType
+             |from pyspark.serializers import CloudPickleSerializer
+             |f = open('$path', 'wb');
+             |
+             |def apply_function_on_number(x):
+             |    return x + 3.0
+             |f.write(CloudPickleSerializer().dumps((apply_function_on_number, FloatType())))
+             |""".stripMargin),
+        None,
+        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+      binaryPandasFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryPandasFunc != null)
+    binaryPandasFunc
+  }
+
+  val vectorizedFunction: Array[Byte] = {
+    var binaryPandasFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          pythonExec,
+          "-c",
+          f"""
+             |from pyspark.sql.types import FloatType
+             |from pyspark.serializers import CloudPickleSerializer
+             |f = open('$path', 'wb');
+             |
+             |def apply_function_on_number(x):
+             |    return x + 1.0
+             |f.write(CloudPickleSerializer().dumps((apply_function_on_number, FloatType())))
+             |""".stripMargin),
+        None,
+        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+      binaryPandasFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryPandasFunc != null)
+    binaryPandasFunc
+  }

   val geopandasGeometryToGeometryFunction: Array[Byte] = {
     var binaryPandasFunc: Array[Byte] = null
@@ -131,49 +156,81 @@ object ScalarUDF {
     assert(binaryPandasFunc != null)
     binaryPandasFunc
   }
-//
-//  val geopandasNonGeometryToGeometryFunction: Array[Byte] = {
-//    var binaryPandasFunc: Array[Byte] = null
-//    withTempPath { path =>
-//      Process(
-//        Seq(
-//          pythonExec,
-//          "-c",
-//          f"""
-//             |from sedona.sql.types import GeometryType
-//             |from shapely.wkt import loads
-//             |from pyspark.serializers import CloudPickleSerializer
-//             |f = open('$path', 'wb');
-//             |def apply_geopandas(x):
-//             |    return x.apply(lambda wkt: loads(wkt).buffer(1))
-//             |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType())))
-//             |""".stripMargin),
-//        None,
-//        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
-//      binaryPandasFunc = Files.readAllBytes(path.toPath)
-//    }
-//    assert(binaryPandasFunc != null)
-//    binaryPandasFunc
-//  }
+  //
+  //  val geopandasNonGeometryToGeometryFunction: Array[Byte] = {
+  //    var binaryPandasFunc: Array[Byte] = null
+  //    withTempPath { path =>
+  //      Process(
+  //        Seq(
+  //          pythonExec,
+  //          "-c",
+  //          f"""
+  //             |from sedona.sql.types import GeometryType
+  //             |from shapely.wkt import loads
+  //             |from pyspark.serializers import CloudPickleSerializer
+  //             |f = open('$path', 'wb');
+  //             |def apply_geopandas(x):
+  //             |    return x.apply(lambda wkt: loads(wkt).buffer(1))
+  //             |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType())))
+  //             |""".stripMargin),
+  //        None,
+  //        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+  //      binaryPandasFunc = Files.readAllBytes(path.toPath)
+  //    }
+  //    assert(binaryPandasFunc != null)
+  //    binaryPandasFunc
+  //  }

   private val workerEnv = new java.util.HashMap[String, String]()
-  workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
-  SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.work")
-  SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false)
-//
-//  val geometryToNonGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
-//    name = "geospatial_udf",
-//    func = SimplePythonFunction(
-//      command = geopandasGeometryToNonGeometry,
-//      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
-//      pythonIncludes = List.empty[String].asJava,
-//      pythonExec = pythonExec,
-//      pythonVer = pythonVer,
-//      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
-//      accumulator = null),
-//    dataType = FloatType,
-//    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
-//    udfDeterministic = true)
+  workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
+  SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.initialworker")
+//  SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false)
+  //
+  //  val geometryToNonGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+  //    name = "geospatial_udf",
+  //    func = SimplePythonFunction(
+  //      command = geopandasGeometryToNonGeometry,
+  //      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+  //      pythonIncludes = List.empty[String].asJava,
+  //      pythonExec = pythonExec,
+  //      pythonVer = pythonVer,
+  //      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+  //      accumulator = null),
+  //    dataType = FloatType,
+  //    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
+  //    udfDeterministic = true)
+
+  val nonGeometryVectorizedUDF: UserDefinedPythonFunction = UserDefinedPythonFunction(
+    name = "vectorized_udf",
+    func = SimplePythonFunction(
+      command = vectorizedFunction,
+      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+      pythonIncludes = List.empty[String].asJava,
+      pythonExec = pythonExec,
+      pythonVer = pythonVer,
+      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+      accumulator = null
+    ),
+    dataType = FloatType,
+    pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+    udfDeterministic = true
+  )
+
+  val nonGeometryVectorizedUDF2: UserDefinedPythonFunction = UserDefinedPythonFunction(
+    name = "vectorized_udf",
+    func = SimplePythonFunction(
+      command = vectorizedFunction2,
+      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+      pythonIncludes = List.empty[String].asJava,
+      pythonExec = pythonExec,
+      pythonVer = pythonVer,
+      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+      accumulator = null
+    ),
+    dataType = FloatType,
+    pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF,
+    udfDeterministic = true
+  )

   val geometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
     name = "geospatial_udf",
     func = SimplePythonFunction(
@@ -186,20 +243,20 @@ object ScalarUDF {
       broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
       accumulator = null),
     dataType = GeometryUDT,
-    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
-    udfDeterministic = true)
-//
-//  val nonGeometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
-//    name = "geospatial_udf",
-//    func = SimplePythonFunction(
-//      command = geopandasNonGeometryToGeometryFunction,
-//      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
-//      pythonIncludes = List.empty[String].asJava,
-//      pythonExec = pythonExec,
-//      pythonVer = pythonVer,
-//      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
-//      accumulator = null),
-//    dataType = GeometryUDT,
-//    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
-//    udfDeterministic = true)
+    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_DB_UDF,
+    udfDeterministic = false)
+  //
+  //  val nonGeometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+  //    name = "geospatial_udf",
+  //    func = SimplePythonFunction(
+  //      command = geopandasNonGeometryToGeometryFunction,
+  //      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+  //      pythonIncludes = List.empty[String].asJava,
+  //      pythonExec = pythonExec,
+  //      pythonVer = pythonVer,
+  //      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+  //      accumulator = null),
+  //    dataType = GeometryUDT,
+  //    pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
+  //    udfDeterministic = true)
 }
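Note on the two pickling helpers: vectorizedFunction and vectorizedFunction2 differ only in the constant added inside the Python body. If more variants are needed, a single parameterized helper could generate them; the sketch below reuses only machinery already present in ScalarUDF (withTempPath, Process, pythonExec, pysparkPythonPath, pythonPath), and the helper name pickledFloatUdf is an assumption, not part of the commit.

    // Sketch: pickle an arbitrary Python float-returning function body with CloudPickleSerializer.
    private def pickledFloatUdf(pythonBody: String): Array[Byte] = {
      var binaryPandasFunc: Array[Byte] = null
      withTempPath { path =>
        // run a local Python process that cloudpickles (function, FloatType()) into a temp file
        Process(
          Seq(
            pythonExec,
            "-c",
            f"""
               |from pyspark.sql.types import FloatType
               |from pyspark.serializers import CloudPickleSerializer
               |f = open('$path', 'wb')
               |$pythonBody
               |f.write(CloudPickleSerializer().dumps((apply_function_on_number, FloatType())))
               |""".stripMargin),
          None,
          "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
        binaryPandasFunc = Files.readAllBytes(path.toPath)
      }
      assert(binaryPandasFunc != null)
      binaryPandasFunc
    }

    // e.g.
    // val vectorizedFunction  = pickledFloatUdf("def apply_function_on_number(x):\n    return x + 1.0")
    // val vectorizedFunction2 = pickledFloatUdf("def apply_function_on_number(x):\n    return x + 3.0")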
