This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 37389d606152519b8431903944b8abf4196acfd6 Author: pawelkocinski <[email protected]> AuthorDate: Sun Dec 21 23:17:05 2025 +0100 add code so far --- .../spark/api/python/SedonaPythonRunner.scala | 99 +++--- .../execution/python/SedonaArrowPythonRunner.scala | 2 +- .../sql/execution/python/SedonaArrowUtils.scala | 8 +- .../execution/python/SedonaPythonArrowInput.scala | 2 + .../execution/python/SedonaPythonUDFRunner.scala | 2 +- .../spark/sql/execution/python/SedonaThread1.scala | 285 +++++++++++++++++ .../sql/execution/python/SedonaWriterThread.scala | 349 +++++++++++++++++++++ .../org/apache/spark/sql/udf/StrategySuite.scala | 25 +- 8 files changed, 718 insertions(+), 54 deletions(-) diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala index bdc989d82b..38a9c7182b 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala @@ -17,6 +17,7 @@ package org.apache.spark.api.python * limitations under the License. */ +import org.apache.sedona.common.geometrySerde.CoordinateType import org.apache.spark._ import org.apache.spark.SedonaSparkEnv import org.apache.spark.internal.Logging @@ -24,6 +25,7 @@ import org.apache.spark.internal.config.Python._ import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} import org.apache.spark.security.SocketAuthHelper +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator} import org.apache.spark.util._ import java.io._ @@ -34,48 +36,11 @@ import java.nio.file.{Path, Files => JavaFiles} import java.util.concurrent.atomic.AtomicBoolean import scala.collection.JavaConverters._ import scala.util.control.NonFatal +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType -/** - * Enumerate the type of command that will be sent to the Python worker - */ -private[spark] object PythonEvalType { - val NON_UDF = 0 - - val SQL_BATCHED_UDF = 100 - val SQL_ARROW_BATCHED_UDF = 101 - - val SQL_SCALAR_PANDAS_UDF = 200 - val SQL_GROUPED_MAP_PANDAS_UDF = 201 - val SQL_GROUPED_AGG_PANDAS_UDF = 202 - val SQL_WINDOW_AGG_PANDAS_UDF = 203 - val SQL_SCALAR_PANDAS_ITER_UDF = 204 - val SQL_MAP_PANDAS_ITER_UDF = 205 - val SQL_COGROUPED_MAP_PANDAS_UDF = 206 - val SQL_MAP_ARROW_ITER_UDF = 207 - val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208 - - val SQL_TABLE_UDF = 300 - val SQL_ARROW_TABLE_UDF = 301 - - def toString(pythonEvalType: Int): String = pythonEvalType match { - case NON_UDF => "NON_UDF" - case SQL_BATCHED_UDF => "SQL_BATCHED_UDF" - case SQL_ARROW_BATCHED_UDF => "SQL_ARROW_BATCHED_UDF" - case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF" - case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF" - case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF" - case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF" - case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF" - case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF" - case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF" - case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF" - case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => 
"SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE" - case SQL_TABLE_UDF => "SQL_TABLE_UDF" - case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF" - } -} - private object SedonaBasePythonRunner { private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler") @@ -92,10 +57,11 @@ private object SedonaBasePythonRunner { * functions (from bottom to top). */ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( - protected val funcs: Seq[ChainedPythonFunctions], - protected val evalType: Int, - protected val argOffsets: Array[Array[Int]], - protected val jobArtifactUUID: Option[String]) + protected val funcs: Seq[ChainedPythonFunctions], + protected val evalType: Int, + protected val argOffsets: Array[Array[Int]], + protected val jobArtifactUUID: Option[String], + schema: StructType) extends Logging { require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs") @@ -279,9 +245,43 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( override def run(): Unit = Utils.logUncaughtExceptions { try { + println("ssss") + val toReadCRS = inputIterator.buffered.headOption.flatMap( + el => el.asInstanceOf[Iterator[IN]].buffered.headOption + ) + + val row = toReadCRS match { + case Some(value) => value match { + case row: GenericInternalRow => + Some(row) + } + case None => None + } + + val geometryFields = schema.zipWithIndex.filter { + case (field, index) => field.dataType == GeometryUDT + }.map { + case (field, index) => + if (row.isEmpty || row.get.values(index) == null) (index, 0) else { + val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] + val preambleByte = geom(0) & 0xFF + val hasSrid = (preambleByte & 0x01) != 0 + + var srid = 0 + if (hasSrid) { + val srid2 = (geom(1) & 0xFF) << 16 + val srid1 = (geom(2) & 0xFF) << 8 + val srid0 = geom(3) & 0xFF + srid = srid2 | srid1 | srid0 + } + (index, srid) + } + } + TaskContext.setTaskContext(context) val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) val dataOut = new DataOutputStream(stream) + // Partition index dataOut.writeInt(partitionIndex) // Python version of driver @@ -401,6 +401,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty dataOut.writeBoolean(needsDecryptionServer) dataOut.writeInt(cnt) + def sendBidsToRemove(): Unit = { for (bid <- toRemove) { // remove the broadcast from worker @@ -408,6 +409,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( oldBids.remove(bid) } } + if (needsDecryptionServer) { // if there is encryption, we setup a server which reads the encrypted files, and sends // the decrypted data to python @@ -447,6 +449,15 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( dataOut.writeInt(evalType) writeCommand(dataOut) + + // write number of geometry fields + dataOut.writeInt(geometryFields.length) + // write geometry field indices and their SRIDs + geometryFields.foreach { case (index, srid) => + dataOut.writeInt(index) + dataOut.writeInt(srid) + } + writeIteratorToStream(dataOut) dataOut.writeInt(SpecialLengths.END_OF_STREAM) diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala index 27e4b851ee..c3eafc9766 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala +++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala @@ -38,7 +38,7 @@ class SedonaArrowPythonRunner( val pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch]( - funcs, evalType, argOffsets, jobArtifactUUID) + funcs, evalType, argOffsets, jobArtifactUUID, schema) with SedonaBasicPythonArrowInput with SedonaBasicPythonArrowOutput { diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala index ec4f7c00d0..6c1bb9edd5 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala @@ -40,10 +40,11 @@ private[sql] object SedonaArrowUtils { largeVarTypes: Boolean = false): Field = { dt match { case GeometryUDT => - val jsonData = """{"crs": {"$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", "type": "GeographicCRS", "name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": {"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 (G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153}} [...] +// val jsonData = """{"crs": {"$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", "type": "GeographicCRS", "name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": {"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 (G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153 [...] 
val metadata = Map( - "ARROW:extension:name" -> "geoarrow.wkb", - "ARROW:extension:metadata" -> jsonData, + "empty" -> "empty", +// "ARROW:extension:name" -> "geoarrow.wkb", +// "ARROW:extension:metadata" -> jsonData, ).asJava val fieldType = new FieldType(nullable, ArrowType.Binary.INSTANCE, null, metadata) @@ -77,6 +78,7 @@ private[sql] object SedonaArrowUtils { val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId, largeVarTypes), null) new Field(name, fieldType, Seq.empty[Field].asJava) + case _ => toArrowField(name, dt, nullable, timeZoneId, largeVarTypes) } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala index 8a5e241c51..bf353539bc 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -77,6 +77,8 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[ inputIterator: Iterator[IN], partitionIndex: Int, context: TaskContext): WriterThread = { +// createWorkerThread(env, worker, inputIterator, partitionIndex, context, schema) + new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala index ced32cf801..beb49c1dde 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala @@ -35,7 +35,7 @@ abstract class SedonaBasePythonUDFRunner( pythonMetrics: Map[String, SQLMetric], jobArtifactUUID: Option[String]) extends SedonaBasePythonRunner[Array[Byte], Array[Byte]]( - funcs, evalType, argOffsets, jobArtifactUUID) { + funcs, evalType, argOffsets, jobArtifactUUID, null) { override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala new file mode 100644 index 0000000000..41fc67df3a --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala @@ -0,0 +1,285 @@ +package org.apache.spark.sql.execution.python + +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamWriter +import org.apache.spark.{BarrierTaskContext, SparkEnv, SparkException, SparkFiles, TaskContext} +import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, EncryptedPythonBroadcastServer, PythonRDD, SedonaBasePythonRunner, SpecialLengths} +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType +import org.apache.spark.util.Utils + +import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, File} +import java.net.{InetAddress, ServerSocket, Socket, SocketException} +import java.nio.charset.StandardCharsets.UTF_8 +import scala.util.control.NonFatal +// +//trait 
WorkerThreadFactory[IN] { +// self: BasePythonRunner [IN, _] => +// +// // Define common functionality for worker threads here +// +// def createWorkerThread( +// env: SparkEnv, +// worker: Socket, +// inputIterator: Iterator[IN], +// partitionIndex: Int, +// context: TaskContext, +// schema: StructType, +// ): WriterThread = { +// new SedonaThread1 (env, worker, inputIterator, partitionIndex, context, schema) +// } +// +// class SedonaThread1( +// env: SparkEnv, +// worker: Socket, +// inputIterator: Iterator[IN], +// partitionIndex: Int, +// context: TaskContext, +// schema: StructType, +// ) extends WriterThread(env, worker, inputIterator, partitionIndex, context) { +// +// +// override def run(): Unit = Utils.logUncaughtExceptions { +// try { +// val toReadCRS = inputIterator.buffered.headOption.flatMap( +// el => el.asInstanceOf[Iterator[IN]].buffered.headOption +// ) +// +// val row = toReadCRS match { +// case Some(value) => value match { +// case row: GenericInternalRow => +// Some(row) +// } +// case None => None +// } +// +// val geometryFields = schema.zipWithIndex.filter { +// case (field, index) => field.dataType == GeometryUDT +// }.map { +// case (field, index) => +// if (row.isEmpty || row.get.values(index) == null) (index, 0) else { +// val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] +// val preambleByte = geom(0) & 0xFF +// val hasSrid = (preambleByte & 0x01) != 0 +// +// var srid = 0 +// if (hasSrid) { +// val srid2 = (geom(1) & 0xFF) << 16 +// val srid1 = (geom(2) & 0xFF) << 8 +// val srid0 = geom(3) & 0xFF +// srid = srid2 | srid1 | srid0 +// } +// (index, srid) +// } +// } +// +// TaskContext.setTaskContext(context) +// val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) +// val dataOut = new DataOutputStream(stream) +// +// // Partition index +// dataOut.writeInt(partitionIndex) +// // Python version of driver +// PythonRDD.writeUTF(pythonVer, dataOut) +// // Init a ServerSocket to accept method calls from Python side. +// val isBarrier = context.isInstanceOf[BarrierTaskContext] +// if (isBarrier) { +// serverSocket = Some(new ServerSocket(/* port */ 0, +// /* backlog */ 1, +// InetAddress.getByName("localhost"))) +// // A call to accept() for ServerSocket shall block infinitely. +// serverSocket.foreach(_.setSoTimeout(0)) +// new Thread("accept-connections") { +// setDaemon(true) +// +// override def run(): Unit = { +// while (!serverSocket.get.isClosed()) { +// var sock: Socket = null +// try { +// sock = serverSocket.get.accept() +// // Wait for function call from python side. +// sock.setSoTimeout(10000) +// authHelper.authClient(sock) +// val input = new DataInputStream(sock.getInputStream()) +// val requestMethod = input.readInt() +// // The BarrierTaskContext function may wait infinitely, socket shall not timeout +// // before the function finishes. 
+// sock.setSoTimeout(10000) +// requestMethod match { +// case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => +// barrierAndServe(requestMethod, sock) +// case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => +// val length = input.readInt() +// val message = new Array[Byte](length) +// input.readFully(message) +// barrierAndServe(requestMethod, sock, new String(message, UTF_8)) +// case _ => +// val out = new DataOutputStream(new BufferedOutputStream( +// sock.getOutputStream)) +// writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) +// } +// } catch { +// case e: SocketException if e.getMessage.contains("Socket closed") => +// // It is possible that the ServerSocket is not closed, but the native socket +// // has already been closed, we shall catch and silently ignore this case. +// } finally { +// if (sock != null) { +// sock.close() +// } +// } +// } +// } +// }.start() +// } +// val secret = if (isBarrier) { +// authHelper.secret +// } else { +// "" +// } +// // Close ServerSocket on task completion. +// serverSocket.foreach { server => +// context.addTaskCompletionListener[Unit](_ => server.close()) +// } +// val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) +// if (boundPort == -1) { +// val message = "ServerSocket failed to bind to Java side." +// logError(message) +// throw new SparkException(message) +// } else if (isBarrier) { +// logDebug(s"Started ServerSocket on port $boundPort.") +// } +// // Write out the TaskContextInfo +// dataOut.writeBoolean(isBarrier) +// dataOut.writeInt(boundPort) +// val secretBytes = secret.getBytes(UTF_8) +// dataOut.writeInt(secretBytes.length) +// dataOut.write(secretBytes, 0, secretBytes.length) +// dataOut.writeInt(context.stageId()) +// dataOut.writeInt(context.partitionId()) +// dataOut.writeInt(context.attemptNumber()) +// dataOut.writeLong(context.taskAttemptId()) +// dataOut.writeInt(context.cpus()) +// val resources = context.resources() +// dataOut.writeInt(resources.size) +// resources.foreach { case (k, v) => +// PythonRDD.writeUTF(k, dataOut) +// PythonRDD.writeUTF(v.name, dataOut) +// dataOut.writeInt(v.addresses.size) +// v.addresses.foreach { case addr => +// PythonRDD.writeUTF(addr, dataOut) +// } +// } +// val localProps = context.getLocalProperties.asScala +// dataOut.writeInt(localProps.size) +// localProps.foreach { case (k, v) => +// PythonRDD.writeUTF(k, dataOut) +// PythonRDD.writeUTF(v, dataOut) +// } +// +// // sparkFilesDir +// val root = jobArtifactUUID.map { uuid => +// new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath +// }.getOrElse(SparkFiles.getRootDirectory()) +// PythonRDD.writeUTF(root, dataOut) +// // Python includes (*.zip and *.egg files) +// dataOut.writeInt(pythonIncludes.size) +// for (include <- pythonIncludes) { +// PythonRDD.writeUTF(include, dataOut) +// } +// // Broadcast variables +// val oldBids = PythonRDD.getWorkerBroadcasts(worker) +// val newBids = broadcastVars.map(_.id).toSet +// // number of different broadcasts +// val toRemove = oldBids.diff(newBids) +// val addedBids = newBids.diff(oldBids) +// val cnt = toRemove.size + addedBids.size +// val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty +// dataOut.writeBoolean(needsDecryptionServer) +// dataOut.writeInt(cnt) +// +// def sendBidsToRemove(): Unit = { +// for (bid <- toRemove) { +// // remove the broadcast from worker +// dataOut.writeLong(-bid - 1) // bid >= 0 +// oldBids.remove(bid) +// } +// } +// +// if (needsDecryptionServer) 
{ +// // if there is encryption, we setup a server which reads the encrypted files, and sends +// // the decrypted data to python +// val idsAndFiles = broadcastVars.flatMap { broadcast => +// if (!oldBids.contains(broadcast.id)) { +// oldBids.add(broadcast.id) +// Some((broadcast.id, broadcast.value.path)) +// } else { +// None +// } +// } +// val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) +// dataOut.writeInt(server.port) +// logTrace(s"broadcast decryption server setup on ${server.port}") +// PythonRDD.writeUTF(server.secret, dataOut) +// sendBidsToRemove() +// idsAndFiles.foreach { case (id, _) => +// // send new broadcast +// dataOut.writeLong(id) +// } +// dataOut.flush() +// logTrace("waiting for python to read decrypted broadcast data from server") +// server.waitTillBroadcastDataSent() +// logTrace("done sending decrypted data to python") +// } else { +// sendBidsToRemove() +// for (broadcast <- broadcastVars) { +// if (!oldBids.contains(broadcast.id)) { +// // send new broadcast +// dataOut.writeLong(broadcast.id) +// PythonRDD.writeUTF(broadcast.value.path, dataOut) +// oldBids.add(broadcast.id) +// } +// } +// } +// dataOut.flush() +// +// dataOut.writeInt(evalType) +// writeCommand(dataOut) +// +// // write number of geometry fields +// dataOut.writeInt(geometryFields.length) +// // write geometry field indices and their SRIDs +// geometryFields.foreach { case (index, srid) => +// dataOut.writeInt(index) +// dataOut.writeInt(srid) +// } +// +// writeIteratorToStream(dataOut) +// +// dataOut.writeInt(SpecialLengths.END_OF_STREAM) +// dataOut.flush() +// } catch { +// case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => +// if (context.isCompleted || context.isInterrupted) { +// logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + +// "cleanup)", t) +// if (!worker.isClosed) { +// Utils.tryLog(worker.shutdownOutput()) +// } +// } else { +// // We must avoid throwing exceptions/NonFatals here, because the thread uncaught +// // exception handler will kill the whole executor (see +// // org.apache.spark.executor.Executor). +// _exception = t +// if (!worker.isClosed) { +// Utils.tryLog(worker.shutdownOutput()) +// } +// } +// } +// } +// +// override protected def writeCommand(dataOut: DataOutputStream): Unit = ??? +// +// override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = ??? 
+// } +//} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala new file mode 100644 index 0000000000..b28cf81906 --- /dev/null +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala @@ -0,0 +1,349 @@ +package org.apache.spark.sql.execution.python + + +import org.apache.sedona.common.geometrySerde.CoordinateType +import org.apache.spark._ +import org.apache.spark.SedonaSparkEnv +import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonRDD, SedonaBasePythonRunner, SpecialLengths} +import org.apache.spark.internal.Logging +import org.apache.spark.internal.config.Python._ +import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} +import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} +import org.apache.spark.security.SocketAuthHelper +import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator} +import org.apache.spark.util._ + +import java.io._ +import java.net._ +import java.nio.charset.StandardCharsets +import java.nio.charset.StandardCharsets.UTF_8 +import java.nio.file.{Path, Files => JavaFiles} +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ +import scala.util.control.NonFatal +import org.apache.spark.sql.catalyst.expressions.GenericInternalRow +import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT +import org.apache.spark.sql.types.StructType + +abstract class SedonaWriterThread[IN, OUT]( + env: SparkEnv, + worker: Socket, + inputIterator: Iterator[IN], + partitionIndex: Int, + context: TaskContext, + pythonExec: String, + schema: org.apache.spark.sql.types.StructType, + ) + extends Thread(s"stdout writer for $pythonExec") with Logging { + self: BasePythonRunner[IN, _] => + + @volatile private var _exception: Throwable = null + private val conf = SparkEnv.get.conf + private lazy val authHelper = new SocketAuthHelper(conf) + private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet + private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) + + setDaemon(true) + + /** Contains the throwable thrown while writing the parent iterator to the Python process. */ + def exception: Option[Throwable] = Option(_exception) + + /** + * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur + * due to cleanup. + */ + def shutdownOnTaskCompletion(): Unit = { + assert(context.isCompleted) + this.interrupt() + // Task completion listeners that run after this method returns may invalidate + // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized + // reader, a task completion listener will free the underlying off-heap buffers. If the writer + // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free + // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer + // thread to exit before returning. + this.join() + } + + /** + * Writes a command section to the stream connected to the Python worker. + */ + protected def writeCommand(dataOut: DataOutputStream): Unit + + /** + * Writes input data to the stream connected to the Python worker. 
+ */ + protected def writeIteratorToStream(dataOut: DataOutputStream): Unit + + override def run(): Unit = Utils.logUncaughtExceptions { + try { + println("ssss") + val toReadCRS = inputIterator.buffered.headOption.flatMap( + el => el.asInstanceOf[Iterator[IN]].buffered.headOption + ) + + val row = toReadCRS match { + case Some(value) => value match { + case row: GenericInternalRow => + Some(row) + } + case None => None + } + + val geometryFields = schema.zipWithIndex.filter { + case (field, index) => field.dataType == GeometryUDT + }.map { + case (field, index) => + if (row.isEmpty || row.get.values(index) == null) (index, 0) else { + val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] + val preambleByte = geom(0) & 0xFF + val hasSrid = (preambleByte & 0x01) != 0 + + var srid = 0 + if (hasSrid) { + val srid2 = (geom(1) & 0xFF) << 16 + val srid1 = (geom(2) & 0xFF) << 8 + val srid0 = geom(3) & 0xFF + srid = srid2 | srid1 | srid0 + } + (index, srid) + } + } + + TaskContext.setTaskContext(context) + val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) + val dataOut = new DataOutputStream(stream) + + // Partition index + dataOut.writeInt(partitionIndex) + // Python version of driver + PythonRDD.writeUTF(pythonVer, dataOut) + // Init a ServerSocket to accept method calls from Python side. + val isBarrier = context.isInstanceOf[BarrierTaskContext] + if (isBarrier) { + serverSocket = Some(new ServerSocket(/* port */ 0, + /* backlog */ 1, + InetAddress.getByName("localhost"))) + // A call to accept() for ServerSocket shall block infinitely. + serverSocket.foreach(_.setSoTimeout(0)) + new Thread("accept-connections") { + setDaemon(true) + + override def run(): Unit = { + while (!serverSocket.get.isClosed()) { + var sock: Socket = null + try { + sock = serverSocket.get.accept() + // Wait for function call from python side. + sock.setSoTimeout(10000) + authHelper.authClient(sock) + val input = new DataInputStream(sock.getInputStream()) + val requestMethod = input.readInt() + // The BarrierTaskContext function may wait infinitely, socket shall not timeout + // before the function finishes. + sock.setSoTimeout(10000) + requestMethod match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + barrierAndServe(requestMethod, sock) + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + val length = input.readInt() + val message = new Array[Byte](length) + input.readFully(message) + barrierAndServe(requestMethod, sock, new String(message, UTF_8)) + case _ => + val out = new DataOutputStream(new BufferedOutputStream( + sock.getOutputStream)) + writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) + } + } catch { + case e: SocketException if e.getMessage.contains("Socket closed") => + // It is possible that the ServerSocket is not closed, but the native socket + // has already been closed, we shall catch and silently ignore this case. + } finally { + if (sock != null) { + sock.close() + } + } + } + } + }.start() + } + val secret = if (isBarrier) { + authHelper.secret + } else { + "" + } + // Close ServerSocket on task completion. + serverSocket.foreach { server => + context.addTaskCompletionListener[Unit](_ => server.close()) + } + val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) + if (boundPort == -1) { + val message = "ServerSocket failed to bind to Java side." 
+ logError(message) + throw new SparkException(message) + } else if (isBarrier) { + logDebug(s"Started ServerSocket on port $boundPort.") + } + // Write out the TaskContextInfo + dataOut.writeBoolean(isBarrier) + dataOut.writeInt(boundPort) + val secretBytes = secret.getBytes(UTF_8) + dataOut.writeInt(secretBytes.length) + dataOut.write(secretBytes, 0, secretBytes.length) + dataOut.writeInt(context.stageId()) + dataOut.writeInt(context.partitionId()) + dataOut.writeInt(context.attemptNumber()) + dataOut.writeLong(context.taskAttemptId()) + dataOut.writeInt(context.cpus()) + val resources = context.resources() + dataOut.writeInt(resources.size) + resources.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v.name, dataOut) + dataOut.writeInt(v.addresses.size) + v.addresses.foreach { case addr => + PythonRDD.writeUTF(addr, dataOut) + } + } + val localProps = context.getLocalProperties.asScala + dataOut.writeInt(localProps.size) + localProps.foreach { case (k, v) => + PythonRDD.writeUTF(k, dataOut) + PythonRDD.writeUTF(v, dataOut) + } + + // sparkFilesDir + val root = jobArtifactUUID.map { uuid => + new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath + }.getOrElse(SparkFiles.getRootDirectory()) + PythonRDD.writeUTF(root, dataOut) + // Python includes (*.zip and *.egg files) + dataOut.writeInt(pythonIncludes.size) + for (include <- pythonIncludes) { + PythonRDD.writeUTF(include, dataOut) + } + // Broadcast variables + val oldBids = PythonRDD.getWorkerBroadcasts(worker) + val newBids = broadcastVars.map(_.id).toSet + // number of different broadcasts + val toRemove = oldBids.diff(newBids) + val addedBids = newBids.diff(oldBids) + val cnt = toRemove.size + addedBids.size + val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty + dataOut.writeBoolean(needsDecryptionServer) + dataOut.writeInt(cnt) + + def sendBidsToRemove(): Unit = { + for (bid <- toRemove) { + // remove the broadcast from worker + dataOut.writeLong(-bid - 1) // bid >= 0 + oldBids.remove(bid) + } + } + + if (needsDecryptionServer) { + // if there is encryption, we setup a server which reads the encrypted files, and sends + // the decrypted data to python + val idsAndFiles = broadcastVars.flatMap { broadcast => + if (!oldBids.contains(broadcast.id)) { + oldBids.add(broadcast.id) + Some((broadcast.id, broadcast.value.path)) + } else { + None + } + } + val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) + dataOut.writeInt(server.port) + logTrace(s"broadcast decryption server setup on ${server.port}") + PythonRDD.writeUTF(server.secret, dataOut) + sendBidsToRemove() + idsAndFiles.foreach { case (id, _) => + // send new broadcast + dataOut.writeLong(id) + } + dataOut.flush() + logTrace("waiting for python to read decrypted broadcast data from server") + server.waitTillBroadcastDataSent() + logTrace("done sending decrypted data to python") + } else { + sendBidsToRemove() + for (broadcast <- broadcastVars) { + if (!oldBids.contains(broadcast.id)) { + // send new broadcast + dataOut.writeLong(broadcast.id) + PythonRDD.writeUTF(broadcast.value.path, dataOut) + oldBids.add(broadcast.id) + } + } + } + dataOut.flush() + + dataOut.writeInt(evalType) + writeCommand(dataOut) + + // write number of geometry fields + dataOut.writeInt(geometryFields.length) + // write geometry field indices and their SRIDs + geometryFields.foreach { case (index, srid) => + dataOut.writeInt(index) + dataOut.writeInt(srid) + } + + writeIteratorToStream(dataOut) + + 
dataOut.writeInt(SpecialLengths.END_OF_STREAM) + dataOut.flush() + } catch { + case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => + if (context.isCompleted || context.isInterrupted) { + logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + + "cleanup)", t) + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } else { + // We must avoid throwing exceptions/NonFatals here, because the thread uncaught + // exception handler will kill the whole executor (see + // org.apache.spark.executor.Executor). + _exception = t + if (!worker.isClosed) { + Utils.tryLog(worker.shutdownOutput()) + } + } + } + } + + /** + * Gateway to call BarrierTaskContext methods. + */ + def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { + require( + serverSocket.isDefined, + "No available ServerSocket to redirect the BarrierTaskContext method call." + ) + val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) + try { + val messages = requestMethod match { + case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => + context.asInstanceOf[BarrierTaskContext].barrier() + Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS) + case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => + context.asInstanceOf[BarrierTaskContext].allGather(message) + } + out.writeInt(messages.length) + messages.foreach(writeUTF(_, out)) + } catch { + case e: SparkException => + writeUTF(e.getMessage, out) + } finally { + out.close() + } + } + + def writeUTF(str: String, dataOut: DataOutputStream): Unit = { + val bytes = str.getBytes(UTF_8) + dataOut.writeInt(bytes.length) + dataOut.write(bytes) + } +} + diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index 391587e586..c92af92cdd 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -20,8 +20,8 @@ package org.apache.spark.sql.udf import org.apache.sedona.spark.SedonaContext import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.functions.{col, expr} -import org.apache.spark.sql.udf.ScalarUDF.{geometryToGeometryFunction} +import org.apache.spark.sql.functions.{col, expr, lit} +import org.apache.spark.sql.udf.ScalarUDF.geometryToGeometryFunction import org.locationtech.jts.io.WKTReader import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers @@ -47,14 +47,29 @@ class StrategySuite extends AnyFunSuite with Matchers { // spark.sql("select 1").show() val df = spark.read.format("geoparquet") .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") - .selectExpr("ST_Centroid(geometry) AS geometry") + .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')")) + + df.show() df .select( + col("id"), + col("version"), + col("bbox"), // geometryToNonGeometryFunction(col("geometry")), - geometryToGeometryFunction(col("geometry")), + geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"), // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - ).show(10, false) + ).show(10) + + println(df + .select( + col("id"), + col("version"), + col("bbox"), + // geometryToNonGeometryFunction(col("geometry")), + geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"), + // 
nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), + ).count()) // df.show() 1 shouldBe 1
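
A minimal sketch of the SRID probe that the new writer code in SedonaBasePythonRunner.run() (and its copies in SedonaThread1/SedonaWriterThread) performs on the first input row. It assumes the serialized geometry layout used in the diff: byte 0 is a preamble whose lowest bit flags the presence of an SRID, and bytes 1-3 carry the SRID big-endian. The helper name is mine, not part of the commit.

    def sridOfSerializedGeometry(geom: Array[Byte]): Int = {
      // preamble byte: bit 0 set => an SRID follows in the next three bytes
      val preambleByte = geom(0) & 0xFF
      val hasSrid = (preambleByte & 0x01) != 0
      if (!hasSrid) 0
      else ((geom(1) & 0xFF) << 16) | ((geom(2) & 0xFF) << 8) | (geom(3) & 0xFF)
    }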
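The protocol extension itself is small: after writeCommand(dataOut) the JVM side now writes the number of geometry columns, followed by an (index, srid) pair for each of them, before the Arrow stream is written. A self-contained sketch of just that write step; the matching read on the Python worker side is not part of this diff and is only assumed to mirror the same order.

    import java.io.DataOutputStream

    def writeGeometryFieldSrids(dataOut: DataOutputStream, geometryFields: Seq[(Int, Int)]): Unit = {
      dataOut.writeInt(geometryFields.length)   // how many geometry columns follow
      geometryFields.foreach { case (index, srid) =>
        dataOut.writeInt(index)                 // column position in the input schema
        dataOut.writeInt(srid)                  // SRID taken from the first non-null value, 0 if unknown
      }
    }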
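For context on the SedonaArrowUtils change: the commit disables the GeoArrow extension metadata on geometry fields (replacing it with a placeholder "empty" entry) rather than removing the mechanism. Below is a sketch of the field construction the commented-out code performed, with the extension metadata JSON (a {"crs": ...} PROJJSON object in the original, truncated in the diff) left as a caller-supplied parameter.

    import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType}
    import scala.collection.JavaConverters._

    def geoArrowWkbField(name: String, nullable: Boolean, extensionMetadataJson: String): Field = {
      // geoarrow.wkb marks the binary column as WKB geometry; the metadata JSON carries the CRS
      val metadata = Map(
        "ARROW:extension:name" -> "geoarrow.wkb",
        "ARROW:extension:metadata" -> extensionMetadataJson).asJava
      val fieldType = new FieldType(nullable, ArrowType.Binary.INSTANCE, null, metadata)
      new Field(name, fieldType, Seq.empty[Field].asJava)
    }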
