This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit eee700d4986da5f6f0a3a242857e4d08b213b89d
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Dec 21 23:58:42 2025 +0100

    add code so far
---
 .../execution/python/SedonaArrowPythonRunner.scala |  12 -
 .../execution/python/SedonaPythonArrowInput.scala  |  38 +--
 .../execution/python/SedonaPythonUDFRunner.scala   | 147 ---------
 .../spark/sql/execution/python/SedonaThread1.scala | 285 -----------------
 .../sql/execution/python/SedonaWriterThread.scala  | 349 ---------------------
 5 files changed, 1 insertion(+), 830 deletions(-)

diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
index 2a28eba6db..16b81b50c0 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
@@ -56,15 +56,3 @@ class SedonaArrowPythonRunner(
       "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
         s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
 }
-
-object SedonaArrowPythonRunner {
-  /** Return Map with conf settings to be used in ArrowPythonRunner */
-  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
-    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
-    val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
-      conf.pandasGroupedMapAssignColumnsByName.toString)
-    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
-      conf.arrowSafeTypeConversion.toString)
-    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
-  }
-}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
index 7bc0d322c2..9c0b0c94b4 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.util.ArrowUtils.toArrowSchema
 import org.apache.spark.util.Utils
 import org.apache.spark.{SparkEnv, TaskContext}
 
@@ -39,39 +37,7 @@ import java.net.Socket
  * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from
  * JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
*/ -private[python] trait SedonaPythonArrowInput[IN] { self: BasePythonRunner[IN, _] => - protected val workerConf: Map[String, String] - - protected val schema: StructType - - protected val timeZoneId: String - - protected val errorOnDuplicatedFieldNames: Boolean - - protected val largeVarTypes: Boolean - - protected def pythonMetrics: Map[String, SQLMetric] - - protected def writeIteratorToArrowStream( - root: VectorSchemaRoot, - writer: ArrowStreamWriter, - dataOut: DataOutputStream, - inputIterator: Iterator[IN]): Unit - - protected def writeUDF( - dataOut: DataOutputStream, - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]]): Unit = - SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - - protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = { - // Write config for the worker as a number of key -> value pairs of strings - stream.writeInt(workerConf.size) - for ((k, v) <- workerConf) { - PythonRDD.writeUTF(k, stream) - PythonRDD.writeUTF(v, stream) - } - } +private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { self: BasePythonRunner[IN, _] => protected override def newWriterThread( env: SparkEnv, @@ -79,8 +45,6 @@ private[python] trait SedonaPythonArrowInput[IN] { self: BasePythonRunner[IN, _] inputIterator: Iterator[IN], partitionIndex: Int, context: TaskContext): WriterThread = { -// createWorkerThread(env, worker, inputIterator, partitionIndex, context, schema) - new WriterThread(env, worker, inputIterator, partitionIndex, context) { protected override def writeCommand(dataOut: DataOutputStream): Unit = { diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala deleted file mode 100644 index dcf93b5213..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala +++ /dev/null @@ -1,147 +0,0 @@ -package org.apache.spark.sql.execution.python - -/* - * Licensed to the Apache Software Foundation (ASF) under one or more - * contributor license agreements. See the NOTICE file distributed with - * this work for additional information regarding copyright ownership. - * The ASF licenses this file to You under the Apache License, Version 2.0 - * (the "License"); you may not use this file except in compliance with - * the License. You may obtain a copy of the License at - * - * http://www.apache.org/licenses/LICENSE-2.0 - * - * Unless required by applicable law or agreed to in writing, software - * distributed under the License is distributed on an "AS IS" BASIS, - * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. - * See the License for the specific language governing permissions and - * limitations under the License. - */ - -import java.io._ -import java.net._ -import java.util.concurrent.atomic.AtomicBoolean -import org.apache.spark._ -import org.apache.spark.api.python._ -import org.apache.spark.sql.execution.metric.SQLMetric -import org.apache.spark.sql.internal.SQLConf - -/** - * A helper class to run Python UDFs in Spark. 
- */ -abstract class SedonaBasePythonUDFRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends BasePythonRunner[Array[Byte], Array[Byte]]( - funcs, evalType, argOffsets, jobArtifactUUID) { - - override val pythonExec: String = - SQLConf.get.pysparkWorkerPythonExecutable.getOrElse( - funcs.head.funcs.head.pythonExec) - - override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - - abstract class SedonaPythonUDFWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[Array[Byte]], - partitionIndex: Int, - context: TaskContext) - extends WriterThread(env, worker, inputIterator, partitionIndex, context) { - - protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { - val startData = dataOut.size() - - PythonRDD.writeIteratorToStream(inputIterator, dataOut) - dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION) - - val deltaData = dataOut.size() - startData - pythonMetrics("pythonDataSent") += deltaData - } - } - - protected override def newReaderIterator( - stream: DataInputStream, - writerThread: WriterThread, - startTime: Long, - env: SparkEnv, - worker: Socket, - pid: Option[Int], - releasedOrClosed: AtomicBoolean, - context: TaskContext): Iterator[Array[Byte]] = { - new ReaderIterator( - stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { - - protected override def read(): Array[Byte] = { - if (writerThread.exception.isDefined) { - throw writerThread.exception.get - } - try { - stream.readInt() match { - case length if length > 0 => - val obj = new Array[Byte](length) - stream.readFully(obj) - pythonMetrics("pythonDataReceived") += length - obj - case 0 => Array.emptyByteArray - case SpecialLengths.TIMING_DATA => - handleTimingData() - read() - case SpecialLengths.PYTHON_EXCEPTION_THROWN => - throw handlePythonException() - case SpecialLengths.END_OF_DATA_SECTION => - handleEndOfDataSection() - null - } - } catch handleException - } - } - } -} - -class SedonaPythonUDFRunner( - funcs: Seq[ChainedPythonFunctions], - evalType: Int, - argOffsets: Array[Array[Int]], - pythonMetrics: Map[String, SQLMetric], - jobArtifactUUID: Option[String]) - extends SedonaBasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) { - - protected override def newWriterThread( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[Array[Byte]], - partitionIndex: Int, - context: TaskContext): WriterThread = { - new SedonaPythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) { - - protected override def writeCommand(dataOut: DataOutputStream): Unit = { - SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets) - } - - } - } -} - -object SedonaPythonUDFRunner { - - def writeUDFs( - dataOut: DataOutputStream, - funcs: Seq[ChainedPythonFunctions], - argOffsets: Array[Array[Int]]): Unit = { - dataOut.writeInt(funcs.length) - funcs.zip(argOffsets).foreach { case (chained, offsets) => - dataOut.writeInt(offsets.length) - offsets.foreach { offset => - dataOut.writeInt(offset) - } - dataOut.writeInt(chained.funcs.length) - chained.funcs.foreach { f => - dataOut.writeInt(f.command.length) - dataOut.write(f.command.toArray) - } - } - } -} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala deleted file mode 100644 
index 41fc67df3a..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala +++ /dev/null @@ -1,285 +0,0 @@ -package org.apache.spark.sql.execution.python - -import org.apache.arrow.vector.VectorSchemaRoot -import org.apache.arrow.vector.ipc.ArrowStreamWriter -import org.apache.spark.{BarrierTaskContext, SparkEnv, SparkException, SparkFiles, TaskContext} -import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, EncryptedPythonBroadcastServer, PythonRDD, SedonaBasePythonRunner, SpecialLengths} -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.StructType -import org.apache.spark.util.Utils - -import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, File} -import java.net.{InetAddress, ServerSocket, Socket, SocketException} -import java.nio.charset.StandardCharsets.UTF_8 -import scala.util.control.NonFatal -// -//trait WorkerThreadFactory[IN] { -// self: BasePythonRunner [IN, _] => -// -// // Define common functionality for worker threads here -// -// def createWorkerThread( -// env: SparkEnv, -// worker: Socket, -// inputIterator: Iterator[IN], -// partitionIndex: Int, -// context: TaskContext, -// schema: StructType, -// ): WriterThread = { -// new SedonaThread1 (env, worker, inputIterator, partitionIndex, context, schema) -// } -// -// class SedonaThread1( -// env: SparkEnv, -// worker: Socket, -// inputIterator: Iterator[IN], -// partitionIndex: Int, -// context: TaskContext, -// schema: StructType, -// ) extends WriterThread(env, worker, inputIterator, partitionIndex, context) { -// -// -// override def run(): Unit = Utils.logUncaughtExceptions { -// try { -// val toReadCRS = inputIterator.buffered.headOption.flatMap( -// el => el.asInstanceOf[Iterator[IN]].buffered.headOption -// ) -// -// val row = toReadCRS match { -// case Some(value) => value match { -// case row: GenericInternalRow => -// Some(row) -// } -// case None => None -// } -// -// val geometryFields = schema.zipWithIndex.filter { -// case (field, index) => field.dataType == GeometryUDT -// }.map { -// case (field, index) => -// if (row.isEmpty || row.get.values(index) == null) (index, 0) else { -// val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] -// val preambleByte = geom(0) & 0xFF -// val hasSrid = (preambleByte & 0x01) != 0 -// -// var srid = 0 -// if (hasSrid) { -// val srid2 = (geom(1) & 0xFF) << 16 -// val srid1 = (geom(2) & 0xFF) << 8 -// val srid0 = geom(3) & 0xFF -// srid = srid2 | srid1 | srid0 -// } -// (index, srid) -// } -// } -// -// TaskContext.setTaskContext(context) -// val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) -// val dataOut = new DataOutputStream(stream) -// -// // Partition index -// dataOut.writeInt(partitionIndex) -// // Python version of driver -// PythonRDD.writeUTF(pythonVer, dataOut) -// // Init a ServerSocket to accept method calls from Python side. -// val isBarrier = context.isInstanceOf[BarrierTaskContext] -// if (isBarrier) { -// serverSocket = Some(new ServerSocket(/* port */ 0, -// /* backlog */ 1, -// InetAddress.getByName("localhost"))) -// // A call to accept() for ServerSocket shall block infinitely. 
-// serverSocket.foreach(_.setSoTimeout(0)) -// new Thread("accept-connections") { -// setDaemon(true) -// -// override def run(): Unit = { -// while (!serverSocket.get.isClosed()) { -// var sock: Socket = null -// try { -// sock = serverSocket.get.accept() -// // Wait for function call from python side. -// sock.setSoTimeout(10000) -// authHelper.authClient(sock) -// val input = new DataInputStream(sock.getInputStream()) -// val requestMethod = input.readInt() -// // The BarrierTaskContext function may wait infinitely, socket shall not timeout -// // before the function finishes. -// sock.setSoTimeout(10000) -// requestMethod match { -// case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => -// barrierAndServe(requestMethod, sock) -// case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => -// val length = input.readInt() -// val message = new Array[Byte](length) -// input.readFully(message) -// barrierAndServe(requestMethod, sock, new String(message, UTF_8)) -// case _ => -// val out = new DataOutputStream(new BufferedOutputStream( -// sock.getOutputStream)) -// writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) -// } -// } catch { -// case e: SocketException if e.getMessage.contains("Socket closed") => -// // It is possible that the ServerSocket is not closed, but the native socket -// // has already been closed, we shall catch and silently ignore this case. -// } finally { -// if (sock != null) { -// sock.close() -// } -// } -// } -// } -// }.start() -// } -// val secret = if (isBarrier) { -// authHelper.secret -// } else { -// "" -// } -// // Close ServerSocket on task completion. -// serverSocket.foreach { server => -// context.addTaskCompletionListener[Unit](_ => server.close()) -// } -// val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) -// if (boundPort == -1) { -// val message = "ServerSocket failed to bind to Java side." 
-// logError(message) -// throw new SparkException(message) -// } else if (isBarrier) { -// logDebug(s"Started ServerSocket on port $boundPort.") -// } -// // Write out the TaskContextInfo -// dataOut.writeBoolean(isBarrier) -// dataOut.writeInt(boundPort) -// val secretBytes = secret.getBytes(UTF_8) -// dataOut.writeInt(secretBytes.length) -// dataOut.write(secretBytes, 0, secretBytes.length) -// dataOut.writeInt(context.stageId()) -// dataOut.writeInt(context.partitionId()) -// dataOut.writeInt(context.attemptNumber()) -// dataOut.writeLong(context.taskAttemptId()) -// dataOut.writeInt(context.cpus()) -// val resources = context.resources() -// dataOut.writeInt(resources.size) -// resources.foreach { case (k, v) => -// PythonRDD.writeUTF(k, dataOut) -// PythonRDD.writeUTF(v.name, dataOut) -// dataOut.writeInt(v.addresses.size) -// v.addresses.foreach { case addr => -// PythonRDD.writeUTF(addr, dataOut) -// } -// } -// val localProps = context.getLocalProperties.asScala -// dataOut.writeInt(localProps.size) -// localProps.foreach { case (k, v) => -// PythonRDD.writeUTF(k, dataOut) -// PythonRDD.writeUTF(v, dataOut) -// } -// -// // sparkFilesDir -// val root = jobArtifactUUID.map { uuid => -// new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath -// }.getOrElse(SparkFiles.getRootDirectory()) -// PythonRDD.writeUTF(root, dataOut) -// // Python includes (*.zip and *.egg files) -// dataOut.writeInt(pythonIncludes.size) -// for (include <- pythonIncludes) { -// PythonRDD.writeUTF(include, dataOut) -// } -// // Broadcast variables -// val oldBids = PythonRDD.getWorkerBroadcasts(worker) -// val newBids = broadcastVars.map(_.id).toSet -// // number of different broadcasts -// val toRemove = oldBids.diff(newBids) -// val addedBids = newBids.diff(oldBids) -// val cnt = toRemove.size + addedBids.size -// val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty -// dataOut.writeBoolean(needsDecryptionServer) -// dataOut.writeInt(cnt) -// -// def sendBidsToRemove(): Unit = { -// for (bid <- toRemove) { -// // remove the broadcast from worker -// dataOut.writeLong(-bid - 1) // bid >= 0 -// oldBids.remove(bid) -// } -// } -// -// if (needsDecryptionServer) { -// // if there is encryption, we setup a server which reads the encrypted files, and sends -// // the decrypted data to python -// val idsAndFiles = broadcastVars.flatMap { broadcast => -// if (!oldBids.contains(broadcast.id)) { -// oldBids.add(broadcast.id) -// Some((broadcast.id, broadcast.value.path)) -// } else { -// None -// } -// } -// val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) -// dataOut.writeInt(server.port) -// logTrace(s"broadcast decryption server setup on ${server.port}") -// PythonRDD.writeUTF(server.secret, dataOut) -// sendBidsToRemove() -// idsAndFiles.foreach { case (id, _) => -// // send new broadcast -// dataOut.writeLong(id) -// } -// dataOut.flush() -// logTrace("waiting for python to read decrypted broadcast data from server") -// server.waitTillBroadcastDataSent() -// logTrace("done sending decrypted data to python") -// } else { -// sendBidsToRemove() -// for (broadcast <- broadcastVars) { -// if (!oldBids.contains(broadcast.id)) { -// // send new broadcast -// dataOut.writeLong(broadcast.id) -// PythonRDD.writeUTF(broadcast.value.path, dataOut) -// oldBids.add(broadcast.id) -// } -// } -// } -// dataOut.flush() -// -// dataOut.writeInt(evalType) -// writeCommand(dataOut) -// -// // write number of geometry fields -// dataOut.writeInt(geometryFields.length) 
-// // write geometry field indices and their SRIDs -// geometryFields.foreach { case (index, srid) => -// dataOut.writeInt(index) -// dataOut.writeInt(srid) -// } -// -// writeIteratorToStream(dataOut) -// -// dataOut.writeInt(SpecialLengths.END_OF_STREAM) -// dataOut.flush() -// } catch { -// case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => -// if (context.isCompleted || context.isInterrupted) { -// logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + -// "cleanup)", t) -// if (!worker.isClosed) { -// Utils.tryLog(worker.shutdownOutput()) -// } -// } else { -// // We must avoid throwing exceptions/NonFatals here, because the thread uncaught -// // exception handler will kill the whole executor (see -// // org.apache.spark.executor.Executor). -// _exception = t -// if (!worker.isClosed) { -// Utils.tryLog(worker.shutdownOutput()) -// } -// } -// } -// } -// -// override protected def writeCommand(dataOut: DataOutputStream): Unit = ??? -// -// override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = ??? -// } -//} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala deleted file mode 100644 index 57cf2dc7bb..0000000000 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala +++ /dev/null @@ -1,349 +0,0 @@ -package org.apache.spark.sql.execution.python - - -import org.apache.sedona.common.geometrySerde.CoordinateType -import org.apache.spark._ -import org.apache.spark.SedonaSparkEnv -import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonRDD, SpecialLengths} -import org.apache.spark.internal.Logging -import org.apache.spark.internal.config.Python._ -import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES} -import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} -import org.apache.spark.security.SocketAuthHelper -import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator} -import org.apache.spark.util._ - -import java.io._ -import java.net._ -import java.nio.charset.StandardCharsets -import java.nio.charset.StandardCharsets.UTF_8 -import java.nio.file.{Path, Files => JavaFiles} -import java.util.concurrent.atomic.AtomicBoolean -import scala.collection.JavaConverters._ -import scala.util.control.NonFatal -import org.apache.spark.sql.catalyst.expressions.GenericInternalRow -import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.StructType - -abstract class SedonaWriterThread[IN, OUT]( - env: SparkEnv, - worker: Socket, - inputIterator: Iterator[IN], - partitionIndex: Int, - context: TaskContext, - pythonExec: String, - schema: org.apache.spark.sql.types.StructType, - ) - extends Thread(s"stdout writer for $pythonExec") with Logging { - self: BasePythonRunner[IN, _] => - - @volatile private var _exception: Throwable = null - private val conf = SparkEnv.get.conf - private lazy val authHelper = new SocketAuthHelper(conf) - private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet - private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala)) - - setDaemon(true) - - /** Contains the throwable thrown while writing the parent iterator to the Python process. 
*/ - def exception: Option[Throwable] = Option(_exception) - - /** - * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur - * due to cleanup. - */ - def shutdownOnTaskCompletion(): Unit = { - assert(context.isCompleted) - this.interrupt() - // Task completion listeners that run after this method returns may invalidate - // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized - // reader, a task completion listener will free the underlying off-heap buffers. If the writer - // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free - // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer - // thread to exit before returning. - this.join() - } - - /** - * Writes a command section to the stream connected to the Python worker. - */ - protected def writeCommand(dataOut: DataOutputStream): Unit - - /** - * Writes input data to the stream connected to the Python worker. - */ - protected def writeIteratorToStream(dataOut: DataOutputStream): Unit - - override def run(): Unit = Utils.logUncaughtExceptions { - try { - println("ssss") - val toReadCRS = inputIterator.buffered.headOption.flatMap( - el => el.asInstanceOf[Iterator[IN]].buffered.headOption - ) - - val row = toReadCRS match { - case Some(value) => value match { - case row: GenericInternalRow => - Some(row) - } - case None => None - } - - val geometryFields = schema.zipWithIndex.filter { - case (field, index) => field.dataType == GeometryUDT - }.map { - case (field, index) => - if (row.isEmpty || row.get.values(index) == null) (index, 0) else { - val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] - val preambleByte = geom(0) & 0xFF - val hasSrid = (preambleByte & 0x01) != 0 - - var srid = 0 - if (hasSrid) { - val srid2 = (geom(1) & 0xFF) << 16 - val srid1 = (geom(2) & 0xFF) << 8 - val srid0 = geom(3) & 0xFF - srid = srid2 | srid1 | srid0 - } - (index, srid) - } - } - - TaskContext.setTaskContext(context) - val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize) - val dataOut = new DataOutputStream(stream) - - // Partition index - dataOut.writeInt(partitionIndex) - // Python version of driver - PythonRDD.writeUTF(pythonVer, dataOut) - // Init a ServerSocket to accept method calls from Python side. - val isBarrier = context.isInstanceOf[BarrierTaskContext] - if (isBarrier) { - serverSocket = Some(new ServerSocket(/* port */ 0, - /* backlog */ 1, - InetAddress.getByName("localhost"))) - // A call to accept() for ServerSocket shall block infinitely. - serverSocket.foreach(_.setSoTimeout(0)) - new Thread("accept-connections") { - setDaemon(true) - - override def run(): Unit = { - while (!serverSocket.get.isClosed()) { - var sock: Socket = null - try { - sock = serverSocket.get.accept() - // Wait for function call from python side. - sock.setSoTimeout(10000) - authHelper.authClient(sock) - val input = new DataInputStream(sock.getInputStream()) - val requestMethod = input.readInt() - // The BarrierTaskContext function may wait infinitely, socket shall not timeout - // before the function finishes. 
- sock.setSoTimeout(10000) - requestMethod match { - case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - barrierAndServe(requestMethod, sock) - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - val length = input.readInt() - val message = new Array[Byte](length) - input.readFully(message) - barrierAndServe(requestMethod, sock, new String(message, UTF_8)) - case _ => - val out = new DataOutputStream(new BufferedOutputStream( - sock.getOutputStream)) - writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out) - } - } catch { - case e: SocketException if e.getMessage.contains("Socket closed") => - // It is possible that the ServerSocket is not closed, but the native socket - // has already been closed, we shall catch and silently ignore this case. - } finally { - if (sock != null) { - sock.close() - } - } - } - } - }.start() - } - val secret = if (isBarrier) { - authHelper.secret - } else { - "" - } - // Close ServerSocket on task completion. - serverSocket.foreach { server => - context.addTaskCompletionListener[Unit](_ => server.close()) - } - val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0) - if (boundPort == -1) { - val message = "ServerSocket failed to bind to Java side." - logError(message) - throw new SparkException(message) - } else if (isBarrier) { - logDebug(s"Started ServerSocket on port $boundPort.") - } - // Write out the TaskContextInfo - dataOut.writeBoolean(isBarrier) - dataOut.writeInt(boundPort) - val secretBytes = secret.getBytes(UTF_8) - dataOut.writeInt(secretBytes.length) - dataOut.write(secretBytes, 0, secretBytes.length) - dataOut.writeInt(context.stageId()) - dataOut.writeInt(context.partitionId()) - dataOut.writeInt(context.attemptNumber()) - dataOut.writeLong(context.taskAttemptId()) - dataOut.writeInt(context.cpus()) - val resources = context.resources() - dataOut.writeInt(resources.size) - resources.foreach { case (k, v) => - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v.name, dataOut) - dataOut.writeInt(v.addresses.size) - v.addresses.foreach { case addr => - PythonRDD.writeUTF(addr, dataOut) - } - } - val localProps = context.getLocalProperties.asScala - dataOut.writeInt(localProps.size) - localProps.foreach { case (k, v) => - PythonRDD.writeUTF(k, dataOut) - PythonRDD.writeUTF(v, dataOut) - } - - // sparkFilesDir - val root = jobArtifactUUID.map { uuid => - new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath - }.getOrElse(SparkFiles.getRootDirectory()) - PythonRDD.writeUTF(root, dataOut) - // Python includes (*.zip and *.egg files) - dataOut.writeInt(pythonIncludes.size) - for (include <- pythonIncludes) { - PythonRDD.writeUTF(include, dataOut) - } - // Broadcast variables - val oldBids = PythonRDD.getWorkerBroadcasts(worker) - val newBids = broadcastVars.map(_.id).toSet - // number of different broadcasts - val toRemove = oldBids.diff(newBids) - val addedBids = newBids.diff(oldBids) - val cnt = toRemove.size + addedBids.size - val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty - dataOut.writeBoolean(needsDecryptionServer) - dataOut.writeInt(cnt) - - def sendBidsToRemove(): Unit = { - for (bid <- toRemove) { - // remove the broadcast from worker - dataOut.writeLong(-bid - 1) // bid >= 0 - oldBids.remove(bid) - } - } - - if (needsDecryptionServer) { - // if there is encryption, we setup a server which reads the encrypted files, and sends - // the decrypted data to python - val idsAndFiles = broadcastVars.flatMap { broadcast => - if 
(!oldBids.contains(broadcast.id)) { - oldBids.add(broadcast.id) - Some((broadcast.id, broadcast.value.path)) - } else { - None - } - } - val server = new EncryptedPythonBroadcastServer(env, idsAndFiles) - dataOut.writeInt(server.port) - logTrace(s"broadcast decryption server setup on ${server.port}") - PythonRDD.writeUTF(server.secret, dataOut) - sendBidsToRemove() - idsAndFiles.foreach { case (id, _) => - // send new broadcast - dataOut.writeLong(id) - } - dataOut.flush() - logTrace("waiting for python to read decrypted broadcast data from server") - server.waitTillBroadcastDataSent() - logTrace("done sending decrypted data to python") - } else { - sendBidsToRemove() - for (broadcast <- broadcastVars) { - if (!oldBids.contains(broadcast.id)) { - // send new broadcast - dataOut.writeLong(broadcast.id) - PythonRDD.writeUTF(broadcast.value.path, dataOut) - oldBids.add(broadcast.id) - } - } - } - dataOut.flush() - - dataOut.writeInt(evalType) - writeCommand(dataOut) - - // write number of geometry fields - dataOut.writeInt(geometryFields.length) - // write geometry field indices and their SRIDs - geometryFields.foreach { case (index, srid) => - dataOut.writeInt(index) - dataOut.writeInt(srid) - } - - writeIteratorToStream(dataOut) - - dataOut.writeInt(SpecialLengths.END_OF_STREAM) - dataOut.flush() - } catch { - case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) => - if (context.isCompleted || context.isInterrupted) { - logDebug("Exception/NonFatal Error thrown after task completion (likely due to " + - "cleanup)", t) - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } else { - // We must avoid throwing exceptions/NonFatals here, because the thread uncaught - // exception handler will kill the whole executor (see - // org.apache.spark.executor.Executor). - _exception = t - if (!worker.isClosed) { - Utils.tryLog(worker.shutdownOutput()) - } - } - } - } - - /** - * Gateway to call BarrierTaskContext methods. - */ - def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = { - require( - serverSocket.isDefined, - "No available ServerSocket to redirect the BarrierTaskContext method call." - ) - val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream)) - try { - val messages = requestMethod match { - case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION => - context.asInstanceOf[BarrierTaskContext].barrier() - Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS) - case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION => - context.asInstanceOf[BarrierTaskContext].allGather(message) - } - out.writeInt(messages.length) - messages.foreach(writeUTF(_, out)) - } catch { - case e: SparkException => - writeUTF(e.getMessage, out) - } finally { - out.close() - } - } - - def writeUTF(str: String, dataOut: DataOutputStream): Unit = { - val bytes = str.getBytes(UTF_8) - dataOut.writeInt(bytes.length) - dataOut.write(bytes) - } -} -
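
For reference, the SRID-detection logic that both deleted writer threads duplicated can be read in isolation as the following sketch. This is illustrative only and not part of the commit; it mirrors the removed snippets above, whose assumed buffer layout is a flag byte followed by a 3-byte big-endian SRID.

    // Illustrative sketch (not part of this commit): decode the SRID from the
    // first four bytes of a Sedona-serialized geometry, mirroring the logic in
    // the deleted SedonaThread1/SedonaWriterThread code above.
    object SridPreamble {
      def sridOf(geom: Array[Byte]): Int = {
        if (geom == null || geom.length < 4) {
          0
        } else {
          val preambleByte = geom(0) & 0xFF
          // Lowest preamble bit flags whether an SRID is stored (assumption taken
          // from the removed code).
          val hasSrid = (preambleByte & 0x01) != 0
          if (hasSrid) {
            // SRID is stored big-endian across the next three bytes.
            ((geom(1) & 0xFF) << 16) | ((geom(2) & 0xFF) << 8) | (geom(3) & 0xFF)
          } else {
            0
          }
        }
      }
    }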
