This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch add-geom-from-mysql-function
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 788863de3163723f629b75acb29bddf9471e123f Author: pawelkocinski <[email protected]> AuthorDate: Fri Aug 15 15:29:54 2025 +0200 SEDONA-743 Add geom from mysql function. --- .../org/apache/sedona/common/Constructors.java | 19 ++ .../common/geometrySerde/GeometrySerializer.java | 44 ++-- docs/api/sql/Constructor.md | 29 +++ spark/common/pom.xml | 12 + .../scala/org/apache/sedona/sql/UDF/Catalog.scala | 1 + .../sql/sedona_sql/expressions/Constructors.scala | 8 + .../common/src/test/resources/mysql/init_mysql.sql | 17 ++ .../apache/sedona/sql/constructorTestScala.scala | 49 +++- .../scala/org/apache/spark/SedonaSparkEnv.scala | 258 +++++++++++---------- .../spark/api/python/SedonaPythonRunner.scala | 171 ++++++++------ .../execution/python/SedonaArrowPythonRunner.scala | 57 +++-- .../sql/execution/python/SedonaArrowStrategy.scala | 61 ++--- .../sql/execution/python/SedonaArrowUtils.scala | 166 ++++++++----- .../execution/python/SedonaPythonArrowInput.scala | 67 ++++-- .../execution/python/SedonaPythonArrowOutput.scala | 73 ++++-- .../execution/python/SedonaPythonUDFRunner.scala | 108 ++++++--- .../org/apache/spark/sql/udf/StrategySuite.scala | 6 +- .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 6 +- 18 files changed, 744 insertions(+), 408 deletions(-) diff --git a/common/src/main/java/org/apache/sedona/common/Constructors.java b/common/src/main/java/org/apache/sedona/common/Constructors.java index 3cd4729243..6542e691a2 100644 --- a/common/src/main/java/org/apache/sedona/common/Constructors.java +++ b/common/src/main/java/org/apache/sedona/common/Constructors.java @@ -19,6 +19,8 @@ package org.apache.sedona.common; import java.io.IOException; +import java.nio.ByteBuffer; +import java.nio.ByteOrder; import javax.xml.parsers.ParserConfigurationException; import org.apache.sedona.common.enums.FileDataSplitter; import org.apache.sedona.common.enums.GeometryType; @@ -302,4 +304,21 @@ public class Constructors { public static Geometry geomFromKML(String kml) throws ParseException { return new KMLReader().read(kml); } + + public static Geometry geomFromMySQL(byte[] binary) throws ParseException { + ByteBuffer buffer = ByteBuffer.wrap(binary); + + buffer.order(ByteOrder.LITTLE_ENDIAN); + int srid = buffer.getInt(); + + byte[] wkb = new byte[buffer.remaining()]; + + buffer.get(wkb); + + Geometry geom = geomFromWKB(wkb); + + geom.setSRID(srid); + + return geom; + } } diff --git a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java index ba135aa6a1..c0d9154cfd 100644 --- a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java +++ b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java @@ -41,26 +41,26 @@ public class GeometrySerializer { public static byte[] serialize(Geometry geometry) { return new WKBWriter().write(geometry); -// GeometryBuffer buffer; -// if (geometry instanceof Point) { -// buffer = serializePoint((Point) geometry); -// } else if (geometry instanceof MultiPoint) { -// buffer = serializeMultiPoint((MultiPoint) geometry); -// } else if (geometry instanceof LineString) { -// buffer = serializeLineString((LineString) geometry); -// } else if (geometry instanceof MultiLineString) { -// buffer = serializeMultiLineString((MultiLineString) geometry); -// } else if (geometry instanceof Polygon) { -// buffer = serializePolygon((Polygon) geometry); -// } else if (geometry instanceof MultiPolygon) { -// 
buffer = serializeMultiPolygon((MultiPolygon) geometry); -// } else if (geometry instanceof GeometryCollection) { -// buffer = serializeGeometryCollection((GeometryCollection) geometry); -// } else { -// throw new UnsupportedOperationException( -// "Geometry type is not supported: " + geometry.getClass().getSimpleName()); -// } -// return buffer.toByteArray(); + // GeometryBuffer buffer; + // if (geometry instanceof Point) { + // buffer = serializePoint((Point) geometry); + // } else if (geometry instanceof MultiPoint) { + // buffer = serializeMultiPoint((MultiPoint) geometry); + // } else if (geometry instanceof LineString) { + // buffer = serializeLineString((LineString) geometry); + // } else if (geometry instanceof MultiLineString) { + // buffer = serializeMultiLineString((MultiLineString) geometry); + // } else if (geometry instanceof Polygon) { + // buffer = serializePolygon((Polygon) geometry); + // } else if (geometry instanceof MultiPolygon) { + // buffer = serializeMultiPolygon((MultiPolygon) geometry); + // } else if (geometry instanceof GeometryCollection) { + // buffer = serializeGeometryCollection((GeometryCollection) geometry); + // } else { + // throw new UnsupportedOperationException( + // "Geometry type is not supported: " + geometry.getClass().getSimpleName()); + // } + // return buffer.toByteArray(); } public static byte[] serializeLegacy(Geometry geometry) { @@ -93,8 +93,8 @@ public class GeometrySerializer { } catch (Exception e) { throw new IllegalArgumentException("Failed to deserialize geometry from bytes", e); } -// GeometryBuffer buffer = GeometryBufferFactory.wrap(bytes); -// return deserialize(buffer); + // GeometryBuffer buffer = GeometryBufferFactory.wrap(bytes); + // return deserialize(buffer); } public static Geometry deserializeLegacy(byte[] bytes) { diff --git a/docs/api/sql/Constructor.md b/docs/api/sql/Constructor.md index 8b4602f946..0bbd273433 100644 --- a/docs/api/sql/Constructor.md +++ b/docs/api/sql/Constructor.md @@ -842,3 +842,32 @@ Output: ``` POLYGON ((-74.0428197 40.6867969, -74.0421975 40.6921336, -74.050802 40.6912794, -74.0428197 40.6867969)) ``` + +## ST_GeomFromMySQL + +Introduction: Construct a Geometry from MySQL Geometry binary. 
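+
+MySQL stores geometry values internally as a 4-byte little-endian SRID prefix followed by the WKB payload; `ST_GeomFromMySQL` strips this prefix, parses the remaining WKB, and sets the SRID on the returned geometry (see `Constructors.geomFromMySQL` in this change).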
+ +Format: `ST_GeomFromMySQL (binary: Binary)` + +Since: `v1.0.0` + +SQL Example + +```sql +SELECT + ST_GeomFromMySQL(geomWKB) AS geom, + ST_SRID(ST_GeomFromMySQL(geomWKB)) AS srid +FROM mysql_table +``` + +Output: + +``` ++-------------+----+ +| geom|srid| ++-------------+----+ +|POINT (20 10)|4326| +|POINT (40 30)|4326| +|POINT (60 50)|4326| ++-------------+----+ +``` diff --git a/spark/common/pom.xml b/spark/common/pom.xml index a910d74480..43fc55bef1 100644 --- a/spark/common/pom.xml +++ b/spark/common/pom.xml @@ -259,6 +259,18 @@ <version>1.20.0</version> <scope>test</scope> </dependency> + <dependency> + <groupId>org.testcontainers</groupId> + <artifactId>mysql</artifactId> + <version>1.20.0</version> + <scope>test</scope> + </dependency> + <dependency> + <groupId>com.mysql</groupId> + <artifactId>mysql-connector-j</artifactId> + <version>9.1.0</version> <!-- or latest --> + <scope>test</scope> <!-- or 'runtime' if also used in prod --> + </dependency> <dependency> <groupId>io.minio</groupId> <artifactId>minio</artifactId> diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala index 074db478d0..cd83147575 100644 --- a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala +++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/Catalog.scala @@ -217,6 +217,7 @@ object Catalog extends AbstractCatalog { function[ST_MLineFromText](0), function[ST_GeomCollFromText](0), function[ST_GeogCollFromText](0), + function[ST_GeomFromMySQL](), function[ST_Split](), function[ST_S2CellIDs](), function[ST_S2ToGeom](), diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala index 98e1636b34..d4b3d3efb1 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/expressions/Constructors.scala @@ -624,3 +624,11 @@ private[apache] case class ST_GeomCollFromText(inputExpressions: Seq[Expression] copy(inputExpressions = newChildren) } } + +private[apache] case class ST_GeomFromMySQL(inputExpressions: Seq[Expression]) + extends InferredExpression(Constructors.geomFromMySQL _) { + + protected def withNewChildrenInternal(newChildren: IndexedSeq[Expression]) = { + copy(inputExpressions = newChildren) + } +} diff --git a/spark/common/src/test/resources/mysql/init_mysql.sql b/spark/common/src/test/resources/mysql/init_mysql.sql new file mode 100644 index 0000000000..f7621a85bd --- /dev/null +++ b/spark/common/src/test/resources/mysql/init_mysql.sql @@ -0,0 +1,17 @@ +CREATE TABLE points +( + name VARCHAR(50), + location GEOMETRY +); + +INSERT INTO points (name, location) +VALUES ('Point A', + ST_GeomFromText('POINT(10 20)', 4326)); + +INSERT INTO points (name, location) +VALUES ('Point B', + ST_GeomFromText('POINT(30 40)', 4326)); + +INSERT INTO points (name, location) +VALUES ('Point C', + ST_GeomFromText('POINT(50 60)', 4326)); diff --git a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala index 2dd9cdfedd..f9a5891a94 100644 --- a/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala +++ b/spark/common/src/test/scala/org/apache/sedona/sql/constructorTestScala.scala @@ -21,9 +21,12 @@ package org.apache.sedona.sql import 
org.apache.sedona.core.formatMapper.GeoJsonReader import org.apache.sedona.core.formatMapper.shapefileParser.ShapefileReader import org.apache.sedona.sql.utils.Adapter +import org.apache.spark.sql.Row import org.locationtech.jts.geom.{Geometry, LineString} +import org.scalatest.matchers.should.Matchers +import org.testcontainers.containers.MySQLContainer -class constructorTestScala extends TestBaseScala { +class constructorTestScala extends TestBaseScala with Matchers { import sparkSession.implicits._ @@ -626,5 +629,49 @@ class constructorTestScala extends TestBaseScala { val actualSrid = actualGeom.getSRID assert(4326 == actualSrid) } + + it("should properly read data from MySQL") { + val runTest = (jdbcURL: String) => { + val tableName = "points" + val properties = new java.util.Properties() + properties.setProperty("user", "sedona") + properties.setProperty("password", "sedona") + + sparkSession.read + .jdbc(jdbcURL, tableName, properties) + .selectExpr( + "ST_GeomFromMySQL(location) as geom", + "ST_SRID(ST_GeomFromMySQL(location)) AS srid") + .show + + val elements = sparkSession.read + .jdbc(jdbcURL, tableName, properties) + .selectExpr( + "ST_GeomFromMySQL(location) as geom", + "ST_SRID(ST_GeomFromMySQL(location)) AS srid") + .selectExpr("ST_AsText(geom) as geom", "srid") + .collect() + + elements.length shouldBe 3 + elements should contain theSameElementsAs Seq( + Row("POINT (20 10)", 4326), + Row("POINT (40 30)", 4326), + Row("POINT (60 50)", 4326)) + } + + val mysql = new MySQLContainer("mysql:9.1.0") + mysql.withInitScript("mysql/init_mysql.sql") + mysql.withUsername("sedona") + mysql.withPassword("sedona") + mysql.withDatabaseName("sedona") + + mysql.start() + + try { + runTest(mysql.getJdbcUrl) + } finally { + mysql.stop() + } + } } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala index 9449a291f5..b89fe93890 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/SedonaSparkEnv.scala @@ -1,20 +1,21 @@ /* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at * - * http://www.apache.org/licenses/LICENSE-2.0 + * http://www.apache.org/licenses/LICENSE-2.0 * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. 
+ * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. */ - package org.apache.spark import java.io.File @@ -49,31 +50,33 @@ import org.apache.spark.storage._ import org.apache.spark.util.{RpcUtils, Utils} /** - * :: DeveloperApi :: - * Holds all the runtime environment objects for a running Spark instance (either master or worker), - * including the serializer, RpcEnv, block manager, map output tracker, etc. Currently - * Spark code finds the SparkEnv through a global variable, so all the threads can access the same - * SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a SparkContext). + * :: DeveloperApi :: Holds all the runtime environment objects for a running Spark instance + * (either master or worker), including the serializer, RpcEnv, block manager, map output tracker, + * etc. Currently Spark code finds the SparkEnv through a global variable, so all the threads can + * access the same SparkEnv. It can be accessed by SparkEnv.get (e.g. after creating a + * SparkContext). */ @DeveloperApi -class SedonaSparkEnv ( - val executorId: String, - private[spark] val rpcEnv: RpcEnv, - val serializer: Serializer, - val closureSerializer: Serializer, - val serializerManager: SerializerManager, - val mapOutputTracker: MapOutputTracker, - val shuffleManager: ShuffleManager, - val broadcastManager: BroadcastManager, - val blockManager: BlockManager, - val securityManager: SecurityManager, - val metricsSystem: MetricsSystem, - val memoryManager: MemoryManager, - val outputCommitCoordinator: OutputCommitCoordinator, - val conf: SparkConf) extends Logging { +class SedonaSparkEnv( + val executorId: String, + private[spark] val rpcEnv: RpcEnv, + val serializer: Serializer, + val closureSerializer: Serializer, + val serializerManager: SerializerManager, + val mapOutputTracker: MapOutputTracker, + val shuffleManager: ShuffleManager, + val broadcastManager: BroadcastManager, + val blockManager: BlockManager, + val securityManager: SecurityManager, + val metricsSystem: MetricsSystem, + val memoryManager: MemoryManager, + val outputCommitCoordinator: OutputCommitCoordinator, + val conf: SparkConf) + extends Logging { @volatile private[spark] var isStopped = false - private val pythonWorkers = mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() + private val pythonWorkers = + mutable.HashMap[(String, Map[String, String]), PythonWorkerFactory]() // A general, soft-reference map for metadata needed during HadoopRDD split computation // (e.g., HadoopFileRDD uses this to cache JobConfs and InputFormats). 
@@ -115,28 +118,29 @@ class SedonaSparkEnv ( } } - private[spark] - def createPythonWorker( - pythonExec: String, - envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { + private[spark] def createPythonWorker( + pythonExec: String, + envVars: Map[String, String]): (java.net.Socket, Option[Int]) = { synchronized { val key = (pythonExec, envVars) pythonWorkers.getOrElseUpdate(key, new PythonWorkerFactory(pythonExec, envVars)).create() } } - private[spark] - def destroyPythonWorker(pythonExec: String, - envVars: Map[String, String], worker: Socket): Unit = { + private[spark] def destroyPythonWorker( + pythonExec: String, + envVars: Map[String, String], + worker: Socket): Unit = { synchronized { val key = (pythonExec, envVars) pythonWorkers.get(key).foreach(_.stopWorker(worker)) } } - private[spark] - def releasePythonWorker(pythonExec: String, - envVars: Map[String, String], worker: Socket): Unit = { + private[spark] def releasePythonWorker( + pythonExec: String, + envVars: Map[String, String], + worker: Socket): Unit = { synchronized { val key = (pythonExec, envVars) pythonWorkers.get(key).foreach(_.releaseWorker(worker)) @@ -165,13 +169,14 @@ object SedonaSparkEnv extends Logging { * Create a SparkEnv for the driver. */ private[spark] def createDriverEnv( - conf: SparkConf, - isLocal: Boolean, - listenerBus: LiveListenerBus, - numCores: Int, - sparkContext: SparkContext, - mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { - assert(conf.contains(DRIVER_HOST_ADDRESS), + conf: SparkConf, + isLocal: Boolean, + listenerBus: LiveListenerBus, + numCores: Int, + sparkContext: SparkContext, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { + assert( + conf.contains(DRIVER_HOST_ADDRESS), s"${DRIVER_HOST_ADDRESS.key} is not set on the driver!") assert(conf.contains(DRIVER_PORT), s"${DRIVER_PORT.key} is not set on the driver!") val bindAddress = conf.get(DRIVER_BIND_ADDRESS) @@ -193,45 +198,35 @@ object SedonaSparkEnv extends Logging { ioEncryptionKey, listenerBus = listenerBus, Option(sparkContext), - mockOutputCommitCoordinator = mockOutputCommitCoordinator - ) + mockOutputCommitCoordinator = mockOutputCommitCoordinator) } /** - * Create a SparkEnv for an executor. - * In coarse-grained mode, the executor provides an RpcEnv that is already instantiated. + * Create a SparkEnv for an executor. In coarse-grained mode, the executor provides an RpcEnv + * that is already instantiated. 
*/ private[spark] def createExecutorEnv( - conf: SparkConf, - executorId: String, - bindAddress: String, - hostname: String, - numCores: Int, - ioEncryptionKey: Option[Array[Byte]], - isLocal: Boolean): SparkEnv = { - val env = create( - conf, - executorId, - bindAddress, - hostname, - None, - isLocal, - numCores, - ioEncryptionKey - ) + conf: SparkConf, + executorId: String, + bindAddress: String, + hostname: String, + numCores: Int, + ioEncryptionKey: Option[Array[Byte]], + isLocal: Boolean): SparkEnv = { + val env = + create(conf, executorId, bindAddress, hostname, None, isLocal, numCores, ioEncryptionKey) SparkEnv.set(env) env } private[spark] def createExecutorEnv( - conf: SparkConf, - executorId: String, - hostname: String, - numCores: Int, - ioEncryptionKey: Option[Array[Byte]], - isLocal: Boolean): SparkEnv = { - createExecutorEnv(conf, executorId, hostname, - hostname, numCores, ioEncryptionKey, isLocal) + conf: SparkConf, + executorId: String, + hostname: String, + numCores: Int, + ioEncryptionKey: Option[Array[Byte]], + isLocal: Boolean): SparkEnv = { + createExecutorEnv(conf, executorId, hostname, hostname, numCores, ioEncryptionKey, isLocal) } /** @@ -239,17 +234,17 @@ object SedonaSparkEnv extends Logging { */ // scalastyle:off argcount private def create( - conf: SparkConf, - executorId: String, - bindAddress: String, - advertiseAddress: String, - port: Option[Int], - isLocal: Boolean, - numUsableCores: Int, - ioEncryptionKey: Option[Array[Byte]], - listenerBus: LiveListenerBus = null, - sc: Option[SparkContext] = None, - mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { + conf: SparkConf, + executorId: String, + bindAddress: String, + advertiseAddress: String, + port: Option[Int], + isLocal: Boolean, + numUsableCores: Int, + ioEncryptionKey: Option[Array[Byte]], + listenerBus: LiveListenerBus = null, + sc: Option[SparkContext] = None, + mockOutputCommitCoordinator: Option[OutputCommitCoordinator] = None): SparkEnv = { // scalastyle:on argcount val isDriver = executorId == SparkContext.DRIVER_IDENTIFIER @@ -266,14 +261,22 @@ object SedonaSparkEnv extends Logging { ioEncryptionKey.foreach { _ => if (!securityManager.isEncryptionEnabled()) { - logWarning("I/O encryption enabled without RPC encryption: keys will be visible on the " + - "wire.") + logWarning( + "I/O encryption enabled without RPC encryption: keys will be visible on the " + + "wire.") } } val systemName = if (isDriver) driverSystemName else executorSystemName - val rpcEnv = RpcEnv.create(systemName, bindAddress, advertiseAddress, port.getOrElse(-1), conf, - securityManager, numUsableCores, !isDriver) + val rpcEnv = RpcEnv.create( + systemName, + bindAddress, + advertiseAddress, + port.getOrElse(-1), + conf, + securityManager, + numUsableCores, + !isDriver) // Figure out which port RpcEnv actually bound to in case the original port is 0 or occupied. 
if (isDriver) { @@ -288,8 +291,8 @@ object SedonaSparkEnv extends Logging { val closureSerializer = new JavaSerializer(conf) def registerOrLookupEndpoint( - name: String, endpointCreator: => RpcEndpoint): - RpcEndpointRef = { + name: String, + endpointCreator: => RpcEndpoint): RpcEndpointRef = { if (isDriver) { logInfo("Registering " + name) rpcEnv.setupEndpoint(name, endpointCreator) @@ -308,9 +311,12 @@ object SedonaSparkEnv extends Logging { // Have to assign trackerEndpoint after initialization as MapOutputTrackerEndpoint // requires the MapOutputTracker itself - mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint(MapOutputTracker.ENDPOINT_NAME, + mapOutputTracker.trackerEndpoint = registerOrLookupEndpoint( + MapOutputTracker.ENDPOINT_NAME, new MapOutputTrackerMasterEndpoint( - rpcEnv, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], conf)) + rpcEnv, + mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], + conf)) // Let the user specify short names for shuffle managers val shortShuffleMgrNames = Map( @@ -319,8 +325,8 @@ object SedonaSparkEnv extends Logging { val shuffleMgrName = conf.get(config.SHUFFLE_MANAGER) val shuffleMgrClass = shortShuffleMgrNames.getOrElse(shuffleMgrName.toLowerCase(Locale.ROOT), shuffleMgrName) - val shuffleManager = Utils.instantiateSerializerOrShuffleManager[ShuffleManager]( - shuffleMgrClass, conf, isDriver) + val shuffleManager = + Utils.instantiateSerializerOrShuffleManager[ShuffleManager](shuffleMgrClass, conf, isDriver) val memoryManager: MemoryManager = UnifiedMemoryManager(conf, numUsableCores) @@ -332,8 +338,12 @@ object SedonaSparkEnv extends Logging { val externalShuffleClient = if (conf.get(config.SHUFFLE_SERVICE_ENABLED)) { val transConf = SparkTransportConf.fromSparkConf(conf, "shuffle", numUsableCores) - Some(new ExternalBlockStoreClient(transConf, securityManager, - securityManager.isAuthenticationEnabled(), conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) + Some( + new ExternalBlockStoreClient( + transConf, + securityManager, + securityManager.isAuthenticationEnabled(), + conf.get(config.SHUFFLE_REGISTRATION_TIMEOUT))) } else { None } @@ -352,7 +362,8 @@ object SedonaSparkEnv extends Logging { externalShuffleClient } else { None - }, blockManagerInfo, + }, + blockManagerInfo, mapOutputTracker.asInstanceOf[MapOutputTrackerMaster], shuffleManager, isDriver)), @@ -363,8 +374,15 @@ object SedonaSparkEnv extends Logging { isDriver) val blockTransferService = - new NettyBlockTransferService(conf, securityManager, serializerManager, bindAddress, - advertiseAddress, blockManagerPort, numUsableCores, blockManagerMaster.driverEndpoint) + new NettyBlockTransferService( + conf, + securityManager, + serializerManager, + bindAddress, + advertiseAddress, + blockManagerPort, + numUsableCores, + blockManagerMaster.driverEndpoint) // NB: blockManager is not valid until initialize() is called later. val blockManager = new BlockManager( @@ -403,7 +421,8 @@ object SedonaSparkEnv extends Logging { } } - val outputCommitCoordinatorRef = registerOrLookupEndpoint("OutputCommitCoordinator", + val outputCommitCoordinatorRef = registerOrLookupEndpoint( + "OutputCommitCoordinator", new OutputCommitCoordinatorEndpoint(rpcEnv, outputCommitCoordinator)) outputCommitCoordinator.coordinatorRef = Some(outputCommitCoordinatorRef) @@ -427,7 +446,8 @@ object SedonaSparkEnv extends Logging { // called, and we only need to do it for driver. 
Because driver may run as a service, and if we // don't delete this tmp dir when sc is stopped, then will create too many tmp dirs. if (isDriver) { - val sparkFilesDir = Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath + val sparkFilesDir = + Utils.createTempDir(Utils.getLocalDir(conf), "userFiles").getAbsolutePath envInstance.driverTmpDir = Some(sparkFilesDir) } @@ -440,20 +460,19 @@ object SedonaSparkEnv extends Logging { * attributes as a sequence of KV pairs. This is used mainly for SparkListenerEnvironmentUpdate. */ private[spark] def environmentDetails( - conf: SparkConf, - hadoopConf: Configuration, - schedulingMode: String, - addedJars: Seq[String], - addedFiles: Seq[String], - addedArchives: Seq[String], - metricsProperties: Map[String, String]): Map[String, Seq[(String, String)]] = { + conf: SparkConf, + hadoopConf: Configuration, + schedulingMode: String, + addedJars: Seq[String], + addedFiles: Seq[String], + addedArchives: Seq[String], + metricsProperties: Map[String, String]): Map[String, Seq[(String, String)]] = { import Properties._ val jvmInformation = Seq( ("Java Version", s"$javaVersion ($javaVendor)"), ("Java Home", javaHome), - ("Scala Version", versionString) - ).sorted + ("Scala Version", versionString)).sorted // Spark properties // This includes the scheduling mode whether or not it is configured (used by SparkUI) @@ -482,7 +501,9 @@ object SedonaSparkEnv extends Logging { // Add Hadoop properties, it will not ignore configs including in Spark. Some spark // conf starting with "spark.hadoop" may overwrite it. val hadoopProperties = hadoopConf.asScala - .map(entry => (entry.getKey, entry.getValue)).toSeq.sorted + .map(entry => (entry.getKey, entry.getValue)) + .toSeq + .sorted Map[String, Seq[(String, String)]]( "JVM Information" -> jvmInformation, "Spark Properties" -> sparkProperties, @@ -492,4 +513,3 @@ object SedonaSparkEnv extends Logging { "Metrics Properties" -> metricsProperties.toSeq.sorted) } } - diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala index 6656d85f5c..c510d0cd93 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
+ */ package org.apache.spark.api.python /* @@ -35,7 +53,6 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.util.control.NonFatal - /** * Enumerate the type of command that will be sent to the Python worker */ @@ -92,11 +109,11 @@ private object SedonaBasePythonRunner { * functions (from bottom to top). */ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( - protected val funcs: Seq[ChainedPythonFunctions], - protected val evalType: Int, - protected val argOffsets: Array[Array[Int]], - protected val jobArtifactUUID: Option[String]) - extends Logging { + protected val funcs: Seq[ChainedPythonFunctions], + protected val evalType: Int, + protected val argOffsets: Array[Array[Int]], + protected val jobArtifactUUID: Option[String]) + extends Logging { require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") @@ -131,9 +148,9 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( } def compute( - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): Iterator[OUT] = { + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): Iterator[OUT] = { val startTime = System.currentTimeMillis val sedonaEnv = SedonaSparkEnv.get val env = SparkEnv.get @@ -170,8 +187,8 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - val (worker: Socket, pid: Option[Int]) = env.createPythonWorker( - pythonExec, envVars.asScala.toMap) + val (worker: Socket, pid: Option[Int]) = + env.createPythonWorker(pythonExec, envVars.asScala.toMap) // Whether is the worker released into idle pool or closed. When any codes try to release or // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make // sure there is only one winner that is going to release or close the worker. @@ -209,38 +226,45 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) val stdoutIterator = newReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) + stream, + writerThread, + startTime, + env, + worker, + pid, + releasedOrClosed, + context) new InterruptibleIterator(context, stdoutIterator) } protected def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): WriterThread + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[OUT] + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[OUT] /** * The thread responsible for writing the data from the PythonRDD's parent iterator to the * Python process. 
*/ abstract class WriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext) - extends Thread(s"stdout writer for $pythonExec") { + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext) + extends Thread(s"stdout writer for $pythonExec") { @volatile private var _exception: Throwable = null @@ -253,8 +277,8 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( def exception: Option[Throwable] = Option(_exception) /** - * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur - * due to cleanup. + * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may + * occur due to cleanup. */ def shutdownOnTaskCompletion(): Unit = { assert(context.isCompleted) @@ -290,9 +314,11 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( // Init a ServerSocket to accept method calls from Python side. val isBarrier = context.isInstanceOf[BarrierTaskContext] if (isBarrier) { - serverSocket = Some(new ServerSocket(/* port */ 0, - /* backlog */ 1, - InetAddress.getByName("localhost"))) + serverSocket = Some( + new ServerSocket( + /* port */ 0, + /* backlog */ 1, + InetAddress.getByName("localhost"))) // A call to accept() for ServerSocket shall block infinitely. serverSocket.foreach(_.setSoTimeout(0)) new Thread("accept-connections") { @@ -320,8 +346,8 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( input.readFully(message) barrierAndServe(requestMethod, sock, new String(message, UTF_8)) case _ => - val out = new DataOutputStream(new BufferedOutputStream( - sock.getOutputStream)) + val out = + new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) } } catch { @@ -383,9 +409,11 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( } // sparkFilesDir - val root = jobArtifactUUID.map { uuid => - new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath - }.getOrElse(SparkFiles.getRootDirectory()) + val root = jobArtifactUUID + .map { uuid => + new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath + } + .getOrElse(SparkFiles.getRootDirectory()) PythonRDD.writeUTF(root, dataOut) // Python includes (*.zip and *.egg files) dataOut.writeInt(pythonIncludes.size) @@ -455,8 +483,10 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( } catch { case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => if (context.isCompleted || context.isInterrupted) { - logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + - "cleanup)", t) + logDebug( + "Exception/NonFatal Error thrown after task completion (likely due to " + + "cleanup)", + t) if (!worker.isClosed) { Utils.tryLog(worker.shutdownOutput()) } @@ -478,8 +508,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { require( serverSocket.isDefined, - "No available ServerSocket to redirect the BarrierTaskContext method call." 
- ) + "No available ServerSocket to redirect the BarrierTaskContext method call.") val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) try { val messages = requestMethod match { @@ -507,15 +536,15 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( } abstract class ReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext) - extends Iterator[OUT] { + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext) + extends Iterator[OUT] { private var nextObj: OUT = _ private var eos = false @@ -540,9 +569,8 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( } /** - * Reads next object from the stream. - * When the stream reaches end of data, needs to process the following sections, - * and then returns null. + * Reads next object from the stream. When the stream reaches end of data, needs to process + * the following sections, and then returns null. */ protected def read(): OUT @@ -555,8 +583,8 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( val init = initTime - bootTime val finish = finishTime - initTime val total = finishTime - startTime - logInfo("Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, - init, finish)) + logInfo( + "Times: total = %s, boot = %s, init = %s, finish = %s".format(total, boot, init, finish)) val memoryBytesSpilled = stream.readLong() val diskBytesSpilled = stream.readLong() context.taskMetrics.incMemoryBytesSpilled(memoryBytesSpilled) @@ -568,8 +596,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( val exLength = stream.readInt() val obj = new Array[Byte](exLength) stream.readFully(obj) - new PythonException(new String(obj, StandardCharsets.UTF_8), - writerThread.exception.orNull) + new PythonException(new String(obj, StandardCharsets.UTF_8), writerThread.exception.orNull) } protected def handleEndOfDataSection(): Unit = { @@ -601,8 +628,9 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( logError("This may have been caused by a prior exception:", writerThread.exception.get) throw writerThread.exception.get - case eof: EOFException if faultHandlerEnabled && pid.isDefined && - JavaFiles.exists(SedonaBasePythonRunner.faultHandlerLogPath(pid.get)) => + case eof: EOFException + if faultHandlerEnabled && pid.isDefined && + JavaFiles.exists(SedonaBasePythonRunner.faultHandlerLogPath(pid.get)) => val path = SedonaBasePythonRunner.faultHandlerLogPath(pid.get) val error = String.join("\n", JavaFiles.readAllLines(path)) + "\n" JavaFiles.deleteIfExists(path) @@ -619,7 +647,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( * threads can block indefinitely. */ class MonitorThread(env: SparkEnv, worker: Socket, context: TaskContext) - extends Thread(s"Worker Monitor for $pythonExec") { + extends Thread(s"Worker Monitor for $pythonExec") { /** How long to wait before killing the python worker if a task cannot be interrupted. */ private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) @@ -666,16 +694,19 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( * * A deadlock can arise if the task completes while the writer thread is sending input to the * Python process (e.g. 
due to the use of `take()`), and the Python process is still producing - * output. When the inputs are sufficiently large, this can result in a deadlock due to the use of - * blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket. + * output. When the inputs are sufficiently large, this can result in a deadlock due to the use + * of blocking I/O (SPARK-38677). To resolve the deadlock, we need to close the socket. */ class WriterMonitorThread( - env: SparkEnv, worker: Socket, writerThread: WriterThread, context: TaskContext) - extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") { + env: SparkEnv, + worker: Socket, + writerThread: WriterThread, + context: TaskContext) + extends Thread(s"Writer Monitor for $pythonExec (writer thread id ${writerThread.getId})") { /** - * How long to wait before closing the socket if the writer thread has not exited after the task - * ends. + * How long to wait before closing the socket if the writer thread has not exited after the + * task ends. */ private val taskKillTimeout = env.conf.get(PYTHON_TASK_KILL_TIMEOUT) diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala index 27e4b851ee..976d034e08 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ package org.apache.spark.sql.execution.python /* @@ -28,23 +46,25 @@ import org.apache.spark.sql.vectorized.ColumnarBatch * Similar to `PythonUDFRunner`, but exchange data with Python worker via Arrow stream. 
*/ class SedonaArrowPythonRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - protected override val schema: StructType, - protected override val timeZoneId: String, - protected override val largeVarTypes: Boolean, - protected override val workerConf: Map[String, String], - val pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, evalType, argOffsets, jobArtifactUUID) + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + protected override val schema: StructType, + protected override val timeZoneId: String, + protected override val largeVarTypes: Boolean, + protected override val workerConf: Map[String, String], + val pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( + funcs, + evalType, + argOffsets, + jobArtifactUUID) with SedonaBasicPythonArrowInput with SedonaBasicPythonArrowOutput { override val pythonExec: String = - SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( - funcs.head.funcs.head.pythonExec) + SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(funcs.head.funcs.head.pythonExec) override val errorOnDuplicatedFieldNames: Boolean = true @@ -58,13 +78,16 @@ class SedonaArrowPythonRunner( } object SedonaArrowPythonRunner { + /** Return Map with conf settings to be used in ArrowPythonRunner */ def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = { val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone) - val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> - conf.pandasGroupedMapAssignColumnsByName.toString) - val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key -> - conf.arrowSafeTypeConversion.toString) + val pandasColsByName = Seq( + SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key -> + conf.pandasGroupedMapAssignColumnsByName.toString) + val arrowSafeTypeCheck = Seq( + SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key -> + conf.arrowSafeTypeConversion.toString) Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*) } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala index 3869ab24b8..375d6536ca 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala @@ -64,7 +64,9 @@ class SedonaArrowStrategy extends Strategy { object SedonaUnsafeProjection { def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = { - GenerateUnsafeProjection.generate(bindReferences(exprs, inputSchema), SQLConf.get.subexpressionEliminationEnabled) + GenerateUnsafeProjection.generate( + bindReferences(exprs, inputSchema), + SQLConf.get.subexpressionEliminationEnabled) // createObject(bindReferences(exprs, inputSchema)) } } @@ -127,13 +129,12 @@ case class SedonaArrowEvalPythonExec( override def doExecute(): RDD[InternalRow] = { - val customProjection = new Projection with Serializable{ - def apply(row: InternalRow): InternalRow = { - row match { - case joinedRow: JoinedRow => - val arrowField = joinedRow.getRight.asInstanceOf[ColumnarBatchRow] - val left = joinedRow.getLeft - + val 
customProjection = new Projection with Serializable { + def apply(row: InternalRow): InternalRow = { + row match { + case joinedRow: JoinedRow => + val arrowField = joinedRow.getRight.asInstanceOf[ColumnarBatchRow] + val left = joinedRow.getLeft // resultAttrs.zipWithIndex.map { // case (x, y) => @@ -153,8 +154,8 @@ case class SedonaArrowEvalPythonExec( // // println("ssss") // arrowField. - row - // We need to convert JoinedRow to UnsafeRow + row + // We need to convert JoinedRow to UnsafeRow // val leftUnsafe = left.asInstanceOf[UnsafeRow] // val rightUnsafe = right.asInstanceOf[UnsafeRow] // val joinedUnsafe = new UnsafeRow(leftUnsafe.numFields + rightUnsafe.numFields) @@ -165,15 +166,15 @@ case class SedonaArrowEvalPythonExec( // joinedUnsafe.setRight(leftUnsafe) // joinedUnsafe // val wktReader = new WKTReader() - val resultProj = SedonaUnsafeProjection.create(output, output) + val resultProj = SedonaUnsafeProjection.create(output, output) // val WKBWriter = new org.locationtech.jts.io.WKBWriter() - resultProj(new JoinedRow(left, arrowField)) - case _ => - println(row.getClass) - throw new UnsupportedOperationException("Unsupported row type") - } + resultProj(new JoinedRow(left, arrowField)) + case _ => + println(row.getClass) + throw new UnsupportedOperationException("Unsupported row type") } } + } val inputRDD = child.execute().map(_.copy()) inputRDD.mapPartitions { iter => @@ -182,8 +183,10 @@ case class SedonaArrowEvalPythonExec( // The queue used to buffer input rows so we can drain it to // combine input with output from Python. - val queue = HybridRowQueue(context.taskMemoryManager(), - new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) + val queue = HybridRowQueue( + context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), + child.output.length) context.addTaskCompletionListener[Unit] { ctx => queue.close() } @@ -216,8 +219,7 @@ case class SedonaArrowEvalPythonExec( projection(inputRow) } - val outputRowIterator = evaluate( - pyFuncs, argOffsets, projectedRowIter, schema, context) + val outputRowIterator = evaluate(pyFuncs, argOffsets, projectedRowIter, schema, context) val joined = new JoinedRow @@ -232,18 +234,17 @@ case class SedonaArrowEvalPythonExec( val row = new GenericInternalRow(numFields) - resultAttrs.zipWithIndex.map { - case (attr, index) => - if (attr.dataType.isInstanceOf[GeometryUDT]) { - // Convert the geometry type to WKB - val wkbReader = new org.locationtech.jts.io.WKBReader() - val wkbWriter = new org.locationtech.jts.io.WKBWriter() - val geom = wkbReader.read(projected.getBinary(startField + index)) + resultAttrs.zipWithIndex.map { case (attr, index) => + if (attr.dataType.isInstanceOf[GeometryUDT]) { + // Convert the geometry type to WKB + val wkbReader = new org.locationtech.jts.io.WKBReader() + val wkbWriter = new org.locationtech.jts.io.WKBWriter() + val geom = wkbReader.read(projected.getBinary(startField + index)) - row.update(startField + index, wkbWriter.write(geom)) + row.update(startField + index, wkbWriter.write(geom)) - println("ssss") - } + println("ssss") + } } println("ssss") diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala index 58166d173d..bf33cde1c1 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala @@ 
-1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ package org.apache.spark.sql.execution.python /* @@ -34,32 +52,32 @@ private[sql] object SedonaArrowUtils { // todo: support more types. /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */ - def toArrowType( - dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match { - case BooleanType => ArrowType.Bool.INSTANCE - case ByteType => new ArrowType.Int(8, true) - case ShortType => new ArrowType.Int(8 * 2, true) - case IntegerType => new ArrowType.Int(8 * 4, true) - case LongType => new ArrowType.Int(8 * 8, true) - case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) - case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) - case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE - case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE - case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE - case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE - case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) - case DateType => new ArrowType.Date(DateUnit.DAY) - case TimestampType if timeZoneId == null => - throw new IllegalStateException("Missing timezoneId where it is mandatory.") - case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) - case TimestampNTZType => - new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) - case NullType => ArrowType.Null.INSTANCE - case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) - case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) - case _ => - throw ExecutionErrors.unsupportedDataTypeError(dt) - } + def toArrowType(dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = + dt match { + case BooleanType => ArrowType.Bool.INSTANCE + case ByteType => new ArrowType.Int(8, true) + case ShortType => new ArrowType.Int(8 * 2, true) + case IntegerType => new ArrowType.Int(8 * 4, true) + case LongType => new ArrowType.Int(8 * 8, true) + case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE) + case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE) + case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE + case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE + case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE + case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE + case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale) + case DateType => new ArrowType.Date(DateUnit.DAY) + case TimestampType if timeZoneId == null => + throw new IllegalStateException("Missing timezoneId 
where it is mandatory.") + case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId) + case TimestampNTZType => + new ArrowType.Timestamp(TimeUnit.MICROSECOND, null) + case NullType => ArrowType.Null.INSTANCE + case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH) + case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND) + case _ => + throw ExecutionErrors.unsupportedDataTypeError(dt) + } def fromArrowType(dt: ArrowType): DataType = dt match { case ArrowType.Bool.INSTANCE => BooleanType @@ -68,9 +86,11 @@ private[sql] object SedonaArrowUtils { case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType case float: ArrowType.FloatingPoint - if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType + if float.getPrecision() == FloatingPointPrecision.SINGLE => + FloatType case float: ArrowType.FloatingPoint - if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType + if float.getPrecision() == FloatingPointPrecision.DOUBLE => + DoubleType case ArrowType.Utf8.INSTANCE => StringType case ArrowType.Binary.INSTANCE => BinaryType case ArrowType.LargeUtf8.INSTANCE => StringType @@ -78,59 +98,72 @@ private[sql] object SedonaArrowUtils { case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale) case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType case ts: ArrowType.Timestamp - if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType + if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => + TimestampNTZType case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType case ArrowType.Null.INSTANCE => NullType - case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType() + case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => + YearMonthIntervalType() case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType() case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt) } /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */ def toArrowField( - name: String, - dt: DataType, - nullable: Boolean, - timeZoneId: String, - largeVarTypes: Boolean = false): Field = { + name: String, + dt: DataType, + nullable: Boolean, + timeZoneId: String, + largeVarTypes: Boolean = false): Field = { dt match { case GeometryUDT => - val jsonData = """{"crs": {"$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", "type": "GeographicCRS", "name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": {"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 (G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153}} [...] + val jsonData = + """{"crs": {"$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", "type": "GeographicCRS", "name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": {"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 (G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153}}, {"name": "W [...] 
val metadata = Map( "ARROW:extension:name" -> "geoarrow.wkb", - "ARROW:extension:metadata" -> jsonData, - ).asJava + "ARROW:extension:metadata" -> jsonData).asJava val fieldType = new FieldType(nullable, ArrowType.Binary.INSTANCE, null, metadata) new Field(name, fieldType, Seq.empty[Field].asJava) case ArrayType(elementType, containsNull) => val fieldType = new FieldType(nullable, ArrowType.List.INSTANCE, null) - new Field(name, fieldType, - Seq(toArrowField("element", elementType, containsNull, timeZoneId, - largeVarTypes)).asJava) + new Field( + name, + fieldType, + Seq( + toArrowField("element", elementType, containsNull, timeZoneId, largeVarTypes)).asJava) case StructType(fields) => val fieldType = new FieldType(nullable, ArrowType.Struct.INSTANCE, null) - new Field(name, fieldType, - fields.map { field => - toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) - }.toSeq.asJava) + new Field( + name, + fieldType, + fields + .map { field => + toArrowField(field.name, field.dataType, field.nullable, timeZoneId, largeVarTypes) + } + .toSeq + .asJava) case MapType(keyType, valueType, valueContainsNull) => val mapType = new FieldType(nullable, new ArrowType.Map(false), null) // Note: Map Type struct can not be null, Struct Type key field can not be null - new Field(name, mapType, - Seq(toArrowField(MapVector.DATA_VECTOR_NAME, - new StructType() - .add(MapVector.KEY_NAME, keyType, nullable = false) - .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), - nullable = false, - timeZoneId, - largeVarTypes)).asJava) + new Field( + name, + mapType, + Seq( + toArrowField( + MapVector.DATA_VECTOR_NAME, + new StructType() + .add(MapVector.KEY_NAME, keyType, nullable = false) + .add(MapVector.VALUE_NAME, valueType, nullable = valueContainsNull), + nullable = false, + timeZoneId, + largeVarTypes)).asJava) case udt: UserDefinedType[_] => toArrowField(name, udt.sqlType, nullable, timeZoneId, largeVarTypes) case dataType => - val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId, - largeVarTypes), null) + val fieldType = + new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null) new Field(name, fieldType, Seq.empty[Field].asJava) } } @@ -156,12 +189,14 @@ private[sql] object SedonaArrowUtils { } } - /** Maps schema from Spark to Arrow. NOTE: timeZoneId required for TimestampType in StructType */ + /** + * Maps schema from Spark to Arrow. 
NOTE: timeZoneId required for TimestampType in StructType + */ def toArrowSchema( - schema: StructType, - timeZoneId: String, - errorOnDuplicatedFieldNames: Boolean, - largeVarTypes: Boolean = false): Schema = { + schema: StructType, + timeZoneId: String, + errorOnDuplicatedFieldNames: Boolean, + largeVarTypes: Boolean = false): Schema = { new Schema(schema.map { field => toArrowField( field.name, @@ -180,9 +215,11 @@ private[sql] object SedonaArrowUtils { } private def deduplicateFieldNames( - dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match { + dt: DataType, + errorOnDuplicatedFieldNames: Boolean): DataType = dt match { case geometryType: GeometryUDT => geometryType - case udt: UserDefinedType[_] => deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) + case udt: UserDefinedType[_] => + deduplicateFieldNames(udt.sqlType, errorOnDuplicatedFieldNames) case st @ StructType(fields) => val newNames = if (st.names.toSet.size == st.names.length) { st.names @@ -201,7 +238,10 @@ private[sql] object SedonaArrowUtils { val newFields = fields.zip(newNames).map { case (StructField(_, dataType, nullable, metadata), name) => StructField( - name, deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), nullable, metadata) + name, + deduplicateFieldNames(dataType, errorOnDuplicatedFieldNames), + nullable, + metadata) } StructType(newFields) case ArrayType(elementType, containsNull) => diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala index 6791015ae9..178227a66d 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ package org.apache.spark.sql.execution.python /* @@ -33,8 +51,8 @@ import java.io.DataOutputStream import java.net.Socket /** - * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from - * JVM (an iterator of internal rows + additional data if required) to Python (Arrow). + * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from JVM + * (an iterator of internal rows + additional data if required) to Python (Arrow). 
*/ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[IN, _] => protected val workerConf: Map[String, String] @@ -50,15 +68,15 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[ protected def pythonMetrics: Map[String, SQLMetric] protected def writeIteratorToArrowStream( - root: VectorSchemaRoot, - writer: ArrowStreamWriter, - dataOut: DataOutputStream, - inputIterator: Iterator[IN]): Unit + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[IN]): Unit protected def writeUDF( - dataOut: DataOutputStream, - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]]): Unit = + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]]): Unit = SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { @@ -71,11 +89,11 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[ } protected override def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext): WriterThread = { + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext): WriterThread = { new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { @@ -85,9 +103,14 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[ protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { val arrowSchema = SedonaArrowUtils.toArrowSchema( - schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) + schema, + timeZoneId, + errorOnDuplicatedFieldNames, + largeVarTypes) val allocator = SedonaArrowUtils.rootAllocator.newChildAllocator( - s"stdout writer for $pythonExec", 0, Long.MaxValue) + s"stdout writer for $pythonExec", + 0, + Long.MaxValue) val root = VectorSchemaRoot.create(arrowSchema, allocator) Utils.tryWithSafeFinally { @@ -119,15 +142,15 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[ } } - -private[python] trait SedonaBasicPythonArrowInput extends SedonaPythonArrowInput[Iterator[InternalRow]] { +private[python] trait SedonaBasicPythonArrowInput + extends SedonaPythonArrowInput[Iterator[InternalRow]] { self: SedonaBasePythonRunner[Iterator[InternalRow], _] => protected def writeIteratorToArrowStream( - root: VectorSchemaRoot, - writer: ArrowStreamWriter, - dataOut: DataOutputStream, - inputIterator: Iterator[Iterator[InternalRow]]): Unit = { + root: VectorSchemaRoot, + writer: ArrowStreamWriter, + dataOut: DataOutputStream, + inputIterator: Iterator[Iterator[InternalRow]]): Unit = { val arrowWriter = ArrowWriter.create(root) while (inputIterator.hasNext) { diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala index f2c8543537..12f6e60eb9 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. 
See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ package org.apache.spark.sql.execution.python /* @@ -32,32 +50,42 @@ import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ /** - * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from - * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch). + * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from Python + * (Arrow) to JVM (output type being deserialized from ColumnarBatch). */ -private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: SedonaBasePythonRunner[_, OUT] => +private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { + self: SedonaBasePythonRunner[_, OUT] => protected def pythonMetrics: Map[String, SQLMetric] - protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } + protected def handleMetadataAfterExec(stream: DataInputStream): Unit = {} protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT protected def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[OUT] = { + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[OUT] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, + writerThread, + startTime, + env, + worker, + pid, + releasedOrClosed, + context) { private val allocator = ArrowUtils.rootAllocator.newChildAllocator( - s"stdin reader for $pythonExec", 0, Long.MaxValue) + s"stdin reader for $pythonExec", + 0, + Long.MaxValue) private var reader: ArrowStreamReader = _ private var root: VectorSchemaRoot = _ @@ -107,9 +135,13 @@ private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: SedonaBaseP reader = new ArrowStreamReader(stream, allocator) root = reader.getVectorSchemaRoot() schema = ArrowUtils.fromArrowSchema(root.getSchema()) - vectors = root.getFieldVectors().asScala.map { vector => - new ArrowColumnVector(vector) - }.toArray[ColumnVector] + vectors = root + .getFieldVectors() + .asScala + .map { vector => + new ArrowColumnVector(vector) + } + .toArray[ColumnVector] read() case SpecialLengths.TIMING_DATA => handleTimingData() @@ -127,10 +159,11 @@ private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: SedonaBaseP } } -private[python] trait SedonaBasicPythonArrowOutput extends SedonaPythonArrowOutput[ColumnarBatch] { +private[python] trait SedonaBasicPythonArrowOutput + extends SedonaPythonArrowOutput[ColumnarBatch] { self: SedonaBasePythonRunner[_, ColumnarBatch] => protected def 
deserializeColumnarBatch( - batch: ColumnarBatch, - schema: StructType): ColumnarBatch = batch + batch: ColumnarBatch, + schema: StructType): ColumnarBatch = batch } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala index ced32cf801..56bfb782b1 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala @@ -1,3 +1,21 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ package org.apache.spark.sql.execution.python /* @@ -29,27 +47,29 @@ import org.apache.spark.sql.internal.SQLConf * A helper class to run Python UDFs in Spark. */ abstract class SedonaBasePythonUDFRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends SedonaBasePythonRunner[Array[Byte], Array[Byte]]( - funcs, evalType, argOffsets, jobArtifactUUID) { + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends SedonaBasePythonRunner[Array[Byte], Array[Byte]]( + funcs, + evalType, + argOffsets, + jobArtifactUUID) { override val pythonExec: String = - SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( - funcs.head.funcs.head.pythonExec) + SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(funcs.head.funcs.head.pythonExec) override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback abstract class SedonaPythonUDFWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[Array[Byte]], - partitionIndex: Int, - context: TaskContext) - extends WriterThread(env, worker, inputIterator, partitionIndex, context) { + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext) + extends WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { val startData = dataOut.size() @@ -63,16 +83,23 @@ abstract class SedonaBasePythonUDFRunner( } protected override def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[Array[Byte]] = { + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, 
+ context: TaskContext): Iterator[Array[Byte]] = { new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + stream, + writerThread, + startTime, + env, + worker, + pid, + releasedOrClosed, + context) { protected override def read(): Array[Byte] = { if (writerThread.exception.isDefined) { @@ -102,19 +129,24 @@ abstract class SedonaBasePythonUDFRunner( } class SedonaPythonUDFRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends SedonaBasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) { + funcs: Seq[ChainedPythonFunctions], + evalType: Int, + argOffsets: Array[Array[Int]], + pythonMetrics: Map[String, SQLMetric], + jobArtifactUUID: Option[String]) + extends SedonaBasePythonUDFRunner( + funcs, + evalType, + argOffsets, + pythonMetrics, + jobArtifactUUID) { protected override def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[Array[Byte]], - partitionIndex: Int, - context: TaskContext): WriterThread = { + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[Array[Byte]], + partitionIndex: Int, + context: TaskContext): WriterThread = { new SedonaPythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { @@ -128,9 +160,9 @@ class SedonaPythonUDFRunner( object SedonaPythonUDFRunner { def writeUDFs( - dataOut: DataOutputStream, - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]]): Unit = { + dataOut: DataOutputStream, + funcs: Seq[ChainedPythonFunctions], + argOffsets: Array[Array[Int]]): Unit = { dataOut.writeInt(funcs.length) funcs.zip(argOffsets).foreach { case (chained, offsets) => dataOut.writeInt(offsets.length) diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index 77ab4abbb8..0a6b416314 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -45,14 +45,14 @@ class StrategySuite extends AnyFunSuite with Matchers { test("sedona geospatial UDF") { // spark.sql("select 1").show() - val df = spark.read.format("parquet") + val df = spark.read + .format("parquet") .load("/Users/pawelkocinski/Desktop/projects/sedona-book/apache-sedona-book/book/chapter10/data/buildings/partitioned") .select( geometryToNonGeometryFunction(col("geometry")), geometryToGeometryFunction(col("geometry")), nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - col("geohash") - ) + col("geohash")) df.show() diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index 3006a14e14..1ca705e297 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -145,9 +145,9 @@ object ScalarUDF { } private val workerEnv = new java.util.HashMap[String, String]() - workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") - SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.worker") - SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false) + workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") + 
SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.worker")
+  SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false)
 
   val geometryToNonGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
     name = "geospatial_udf",
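
For context on the GeometryUDT branch added to SedonaArrowUtils.toArrowField earlier in this diff, here is a minimal sketch (not part of the patch) of how that mapping could be exercised. It assumes the calling code lives under the org.apache.spark.sql package (SedonaArrowUtils is private[sql]) and that GeometryUDT resolves from org.apache.spark.sql.sedona_sql.UDT; both are assumptions for illustration, not something this commit adds.

// Illustrative sketch only; not part of this patch.
import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT // assumed import path
import org.apache.spark.sql.types.{IntegerType, StructType}

// A Spark schema with one geometry column.
val sparkSchema = new StructType()
  .add("id", IntegerType, nullable = false)
  .add("geometry", GeometryUDT, nullable = true)

// Convert to an Arrow schema using the helper touched in this patch.
val arrowSchema = SedonaArrowUtils.toArrowSchema(
  sparkSchema,
  timeZoneId = "UTC",
  errorOnDuplicatedFieldNames = true,
  largeVarTypes = false)

// The geometry field is emitted as Arrow Binary (WKB) tagged with the
// geoarrow.wkb extension name, which is what GeoArrow-aware readers key on.
val geometryField = arrowSchema.findField("geometry")
assert(geometryField.getType == ArrowType.Binary.INSTANCE)
assert(geometryField.getMetadata.get("ARROW:extension:name") == "geoarrow.wkb")

Keying the Python side on Arrow extension metadata rather than on a Spark-specific type keeps the wire format readable by any GeoArrow-aware consumer; the full PROJJSON CRS string attached as extension metadata is only partially reproduced in this mail.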

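Similarly, the read path in SedonaPythonArrowOutput wraps Arrow vectors returned by the Python worker into a ColumnarBatch. Below is a self-contained sketch of that pattern, not part of the patch: the byte-array input and the method name are hypothetical, and only stock Arrow and Spark classes are used.

// Illustrative sketch only; not part of this patch.
import java.io.ByteArrayInputStream

import scala.collection.JavaConverters._

import org.apache.arrow.memory.RootAllocator
import org.apache.arrow.vector.ipc.ArrowStreamReader
import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch}

// Hypothetical helper: reads the first record batch from an Arrow IPC stream held in memory.
def readFirstBatch(streamBytes: Array[Byte]): Option[ColumnarBatch] = {
  val allocator = new RootAllocator(Long.MaxValue)
  val reader = new ArrowStreamReader(new ByteArrayInputStream(streamBytes), allocator)
  val root = reader.getVectorSchemaRoot
  if (reader.loadNextBatch()) {
    // Same wrapping as the ReaderIterator in SedonaPythonArrowOutput:
    // each Arrow FieldVector becomes a Spark ArrowColumnVector.
    val vectors = root.getFieldVectors.asScala
      .map(v => new ArrowColumnVector(v): ColumnVector)
      .toArray
    val batch = new ColumnarBatch(vectors)
    batch.setNumRows(root.getRowCount)
    Some(batch) // caller is responsible for closing the batch, reader and allocator
  } else {
    reader.close()
    allocator.close()
    None
  }
}

Note that ArrowStreamReader reuses a single VectorSchemaRoot across loadNextBatch() calls, which is why the runner streams batches one at a time instead of materializing them.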