This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit eee700d4986da5f6f0a3a242857e4d08b213b89d
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Dec 21 23:58:42 2025 +0100

    add code so far
---
 .../execution/python/SedonaArrowPythonRunner.scala |  12 -
 .../execution/python/SedonaPythonArrowInput.scala  |  38 +--
 .../execution/python/SedonaPythonUDFRunner.scala   | 147 ---------
 .../spark/sql/execution/python/SedonaThread1.scala | 285 -----------------
 .../sql/execution/python/SedonaWriterThread.scala  | 349 ---------------------
 5 files changed, 1 insertion(+), 830 deletions(-)

diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
index 2a28eba6db..16b81b50c0 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
@@ -56,15 +56,3 @@ class SedonaArrowPythonRunner(
     "Pandas execution requires more than 4 bytes. Please set higher buffer. " +
       s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.")
 }
-
-object SedonaArrowPythonRunner {
-  /** Return Map with conf settings to be used in ArrowPythonRunner */
-  def getPythonRunnerConfMap(conf: SQLConf): Map[String, String] = {
-    val timeZoneConf = Seq(SQLConf.SESSION_LOCAL_TIMEZONE.key -> conf.sessionLocalTimeZone)
-    val pandasColsByName = Seq(SQLConf.PANDAS_GROUPED_MAP_ASSIGN_COLUMNS_BY_NAME.key ->
-      conf.pandasGroupedMapAssignColumnsByName.toString)
-    val arrowSafeTypeCheck = Seq(SQLConf.PANDAS_ARROW_SAFE_TYPE_CONVERSION.key ->
-      conf.arrowSafeTypeConversion.toString)
-    Map(timeZoneConf ++ pandasColsByName ++ arrowSafeTypeCheck: _*)
-  }
-}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
index 7bc0d322c2..9c0b0c94b4 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
@@ -27,8 +27,6 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types.StructType
-import org.apache.spark.sql.util.ArrowUtils
-import org.apache.spark.sql.util.ArrowUtils.toArrowSchema
 import org.apache.spark.util.Utils
 import org.apache.spark.{SparkEnv, TaskContext}
 
@@ -39,39 +37,7 @@ import java.net.Socket
 * A trait that can be mixed-in with [[python.BasePythonRunner]]. It implements the logic from
 * JVM (an iterator of internal rows + additional data if required) to Python (Arrow).
  */
-private[python] trait SedonaPythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
-  protected val workerConf: Map[String, String]
-
-  protected val schema: StructType
-
-  protected val timeZoneId: String
-
-  protected val errorOnDuplicatedFieldNames: Boolean
-
-  protected val largeVarTypes: Boolean
-
-  protected def pythonMetrics: Map[String, SQLMetric]
-
-  protected def writeIteratorToArrowStream(
-                                            root: VectorSchemaRoot,
-                                            writer: ArrowStreamWriter,
-                                            dataOut: DataOutputStream,
-                                            inputIterator: Iterator[IN]): Unit
-
-  protected def writeUDF(
-                          dataOut: DataOutputStream,
-                          funcs: Seq[ChainedPythonFunctions],
-                          argOffsets: Array[Array[Int]]): Unit =
-    SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
-
-  protected def handleMetadataBeforeExec(stream: DataOutputStream): Unit = {
-    // Write config for the worker as a number of key -> value pairs of strings
-    stream.writeInt(workerConf.size)
-    for ((k, v) <- workerConf) {
-      PythonRDD.writeUTF(k, stream)
-      PythonRDD.writeUTF(v, stream)
-    }
-  }
+private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { self: BasePythonRunner[IN, _] =>
 
   protected override def newWriterThread(
                                           env: SparkEnv,
@@ -79,8 +45,6 @@ private[python] trait SedonaPythonArrowInput[IN] { self: BasePythonRunner[IN, _]
                                           inputIterator: Iterator[IN],
                                           partitionIndex: Int,
                                          context: TaskContext): WriterThread = {
-//    createWorkerThread(env, worker, inputIterator, partitionIndex, context, schema)
-
     new WriterThread(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
deleted file mode 100644
index dcf93b5213..0000000000
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
+++ /dev/null
@@ -1,147 +0,0 @@
-package org.apache.spark.sql.execution.python
-
-/*
- * Licensed to the Apache Software Foundation (ASF) under one or more
- * contributor license agreements.  See the NOTICE file distributed with
- * this work for additional information regarding copyright ownership.
- * The ASF licenses this file to You under the Apache License, Version 2.0
- * (the "License"); you may not use this file except in compliance with
- * the License.  You may obtain a copy of the License at
- *
- *    http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing, software
- * distributed under the License is distributed on an "AS IS" BASIS,
- * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
- * See the License for the specific language governing permissions and
- * limitations under the License.
- */
-
-import java.io._
-import java.net._
-import java.util.concurrent.atomic.AtomicBoolean
-import org.apache.spark._
-import org.apache.spark.api.python._
-import org.apache.spark.sql.execution.metric.SQLMetric
-import org.apache.spark.sql.internal.SQLConf
-
-/**
- * A helper class to run Python UDFs in Spark.
- */
-abstract class SedonaBasePythonUDFRunner(
-                                    funcs: Seq[ChainedPythonFunctions],
-                                    evalType: Int,
-                                    argOffsets: Array[Array[Int]],
-                                    pythonMetrics: Map[String, SQLMetric],
-                                    jobArtifactUUID: Option[String])
-  extends BasePythonRunner[Array[Byte], Array[Byte]](
-    funcs, evalType, argOffsets, jobArtifactUUID) {
-
-  override val pythonExec: String =
-    SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
-      funcs.head.funcs.head.pythonExec)
-
-  override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback
-
-  abstract class SedonaPythonUDFWriterThread(
-                                        env: SparkEnv,
-                                        worker: Socket,
-                                        inputIterator: Iterator[Array[Byte]],
-                                        partitionIndex: Int,
-                                        context: TaskContext)
-    extends WriterThread(env, worker, inputIterator, partitionIndex, context) {
-
-    protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = {
-      val startData = dataOut.size()
-
-      PythonRDD.writeIteratorToStream(inputIterator, dataOut)
-      dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-
-      val deltaData = dataOut.size() - startData
-      pythonMetrics("pythonDataSent") += deltaData
-    }
-  }
-
-  protected override def newReaderIterator(
-                                            stream: DataInputStream,
-                                            writerThread: WriterThread,
-                                            startTime: Long,
-                                            env: SparkEnv,
-                                            worker: Socket,
-                                            pid: Option[Int],
-                                            releasedOrClosed: AtomicBoolean,
-                                            context: TaskContext): Iterator[Array[Byte]] = {
-    new ReaderIterator(
-      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) {
-
-      protected override def read(): Array[Byte] = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
-        }
-        try {
-          stream.readInt() match {
-            case length if length > 0 =>
-              val obj = new Array[Byte](length)
-              stream.readFully(obj)
-              pythonMetrics("pythonDataReceived") += length
-              obj
-            case 0 => Array.emptyByteArray
-            case SpecialLengths.TIMING_DATA =>
-              handleTimingData()
-              read()
-            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-              throw handlePythonException()
-            case SpecialLengths.END_OF_DATA_SECTION =>
-              handleEndOfDataSection()
-              null
-          }
-        } catch handleException
-      }
-    }
-  }
-}
-
-class SedonaPythonUDFRunner(
-                       funcs: Seq[ChainedPythonFunctions],
-                       evalType: Int,
-                       argOffsets: Array[Array[Int]],
-                       pythonMetrics: Map[String, SQLMetric],
-                       jobArtifactUUID: Option[String])
-  extends SedonaBasePythonUDFRunner(funcs, evalType, argOffsets, pythonMetrics, jobArtifactUUID) {
-
-  protected override def newWriterThread(
-                                          env: SparkEnv,
-                                          worker: Socket,
-                                          inputIterator: Iterator[Array[Byte]],
-                                          partitionIndex: Int,
-                                          context: TaskContext): WriterThread = {
-    new SedonaPythonUDFWriterThread(env, worker, inputIterator, partitionIndex, context) {
-
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        SedonaPythonUDFRunner.writeUDFs(dataOut, funcs, argOffsets)
-      }
-
-    }
-  }
-}
-
-object SedonaPythonUDFRunner {
-
-  def writeUDFs(
-                 dataOut: DataOutputStream,
-                 funcs: Seq[ChainedPythonFunctions],
-                 argOffsets: Array[Array[Int]]): Unit = {
-    dataOut.writeInt(funcs.length)
-    funcs.zip(argOffsets).foreach { case (chained, offsets) =>
-      dataOut.writeInt(offsets.length)
-      offsets.foreach { offset =>
-        dataOut.writeInt(offset)
-      }
-      dataOut.writeInt(chained.funcs.length)
-      chained.funcs.foreach { f =>
-        dataOut.writeInt(f.command.length)
-        dataOut.write(f.command.toArray)
-      }
-    }
-  }
-}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala
deleted file mode 100644
index 41fc67df3a..0000000000
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala
+++ /dev/null
@@ -1,285 +0,0 @@
-package org.apache.spark.sql.execution.python
-
-import org.apache.arrow.vector.VectorSchemaRoot
-import org.apache.arrow.vector.ipc.ArrowStreamWriter
-import org.apache.spark.{BarrierTaskContext, SparkEnv, SparkException, SparkFiles, TaskContext}
-import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, EncryptedPythonBroadcastServer, PythonRDD, SedonaBasePythonRunner, SpecialLengths}
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.types.StructType
-import org.apache.spark.util.Utils
-
-import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, File}
-import java.net.{InetAddress, ServerSocket, Socket, SocketException}
-import java.nio.charset.StandardCharsets.UTF_8
-import scala.util.control.NonFatal
-//
-//trait WorkerThreadFactory[IN] {
-//  self: BasePythonRunner [IN, _] =>
-//
-//  // Define common functionality for worker threads here
-//
-//  def createWorkerThread(
-//                          env: SparkEnv,
-//                          worker: Socket,
-//                          inputIterator: Iterator[IN],
-//                          partitionIndex: Int,
-//                          context: TaskContext,
-//                          schema: StructType,
-//                        ): WriterThread = {
-//    new SedonaThread1 (env, worker, inputIterator, partitionIndex, context, schema)
-//  }
-//
-//  class SedonaThread1(
-//                       env: SparkEnv,
-//                       worker: Socket,
-//                       inputIterator: Iterator[IN],
-//                       partitionIndex: Int,
-//                       context: TaskContext,
-//                       schema: StructType,
-//                     ) extends WriterThread(env, worker, inputIterator, partitionIndex, context) {
-//
-//
-//    override def run(): Unit = Utils.logUncaughtExceptions {
-//      try {
-//        val toReadCRS = inputIterator.buffered.headOption.flatMap(
-//          el => el.asInstanceOf[Iterator[IN]].buffered.headOption
-//        )
-//
-//        val row = toReadCRS match {
-//          case Some(value) => value match {
-//            case row: GenericInternalRow =>
-//              Some(row)
-//          }
-//          case None => None
-//        }
-//
-//        val geometryFields = schema.zipWithIndex.filter {
-//          case (field, index) => field.dataType == GeometryUDT
-//        }.map {
-//          case (field, index) =>
-//            if (row.isEmpty || row.get.values(index) == null) (index, 0) else {
-//              val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]]
-//              val preambleByte = geom(0) & 0xFF
-//              val hasSrid = (preambleByte & 0x01) != 0
-//
-//              var srid = 0
-//              if (hasSrid) {
-//                val srid2 = (geom(1) & 0xFF) << 16
-//                val srid1 = (geom(2) & 0xFF) << 8
-//                val srid0 = geom(3) & 0xFF
-//                srid = srid2 | srid1 | srid0
-//              }
-//              (index, srid)
-//            }
-//        }
-//
-//        TaskContext.setTaskContext(context)
-//        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-//        val dataOut = new DataOutputStream(stream)
-//
-//        // Partition index
-//        dataOut.writeInt(partitionIndex)
-//        // Python version of driver
-//        PythonRDD.writeUTF(pythonVer, dataOut)
-//        // Init a ServerSocket to accept method calls from Python side.
-//        val isBarrier = context.isInstanceOf[BarrierTaskContext]
-//        if (isBarrier) {
-//          serverSocket = Some(new ServerSocket(/* port */ 0,
-//            /* backlog */ 1,
-//            InetAddress.getByName("localhost")))
-//          // A call to accept() for ServerSocket shall block infinitely.
-//          serverSocket.foreach(_.setSoTimeout(0))
-//          new Thread("accept-connections") {
-//            setDaemon(true)
-//
-//            override def run(): Unit = {
-//              while (!serverSocket.get.isClosed()) {
-//                var sock: Socket = null
-//                try {
-//                  sock = serverSocket.get.accept()
-//                  // Wait for function call from python side.
-//                  sock.setSoTimeout(10000)
-//                  authHelper.authClient(sock)
-//                  val input = new DataInputStream(sock.getInputStream())
-//                  val requestMethod = input.readInt()
-//                  // The BarrierTaskContext function may wait infinitely, socket shall not timeout
-//                  // before the function finishes.
-//                  sock.setSoTimeout(10000)
-//                  requestMethod match {
-//                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-//                      barrierAndServe(requestMethod, sock)
-//                    case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
-//                      val length = input.readInt()
-//                      val message = new Array[Byte](length)
-//                      input.readFully(message)
-//                      barrierAndServe(requestMethod, sock, new String(message, UTF_8))
-//                    case _ =>
-//                      val out = new DataOutputStream(new BufferedOutputStream(
-//                        sock.getOutputStream))
-//                      writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
-//                  }
-//                } catch {
-//                  case e: SocketException if e.getMessage.contains("Socket closed") =>
-//                  // It is possible that the ServerSocket is not closed, but the native socket
-//                  // has already been closed, we shall catch and silently ignore this case.
-//                } finally {
-//                  if (sock != null) {
-//                    sock.close()
-//                  }
-//                }
-//              }
-//            }
-//          }.start()
-//        }
-//        val secret = if (isBarrier) {
-//          authHelper.secret
-//        } else {
-//          ""
-//        }
-//        // Close ServerSocket on task completion.
-//        serverSocket.foreach { server =>
-//          context.addTaskCompletionListener[Unit](_ => server.close())
-//        }
-//        val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
-//        if (boundPort == -1) {
-//          val message = "ServerSocket failed to bind to Java side."
-//          logError(message)
-//          throw new SparkException(message)
-//        } else if (isBarrier) {
-//          logDebug(s"Started ServerSocket on port $boundPort.")
-//        }
-//        // Write out the TaskContextInfo
-//        dataOut.writeBoolean(isBarrier)
-//        dataOut.writeInt(boundPort)
-//        val secretBytes = secret.getBytes(UTF_8)
-//        dataOut.writeInt(secretBytes.length)
-//        dataOut.write(secretBytes, 0, secretBytes.length)
-//        dataOut.writeInt(context.stageId())
-//        dataOut.writeInt(context.partitionId())
-//        dataOut.writeInt(context.attemptNumber())
-//        dataOut.writeLong(context.taskAttemptId())
-//        dataOut.writeInt(context.cpus())
-//        val resources = context.resources()
-//        dataOut.writeInt(resources.size)
-//        resources.foreach { case (k, v) =>
-//          PythonRDD.writeUTF(k, dataOut)
-//          PythonRDD.writeUTF(v.name, dataOut)
-//          dataOut.writeInt(v.addresses.size)
-//          v.addresses.foreach { case addr =>
-//            PythonRDD.writeUTF(addr, dataOut)
-//          }
-//        }
-//        val localProps = context.getLocalProperties.asScala
-//        dataOut.writeInt(localProps.size)
-//        localProps.foreach { case (k, v) =>
-//          PythonRDD.writeUTF(k, dataOut)
-//          PythonRDD.writeUTF(v, dataOut)
-//        }
-//
-//        // sparkFilesDir
-//        val root = jobArtifactUUID.map { uuid =>
-//          new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
-//        }.getOrElse(SparkFiles.getRootDirectory())
-//        PythonRDD.writeUTF(root, dataOut)
-//        // Python includes (*.zip and *.egg files)
-//        dataOut.writeInt(pythonIncludes.size)
-//        for (include <- pythonIncludes) {
-//          PythonRDD.writeUTF(include, dataOut)
-//        }
-//        // Broadcast variables
-//        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-//        val newBids = broadcastVars.map(_.id).toSet
-//        // number of different broadcasts
-//        val toRemove = oldBids.diff(newBids)
-//        val addedBids = newBids.diff(oldBids)
-//        val cnt = toRemove.size + addedBids.size
-//        val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
-//        dataOut.writeBoolean(needsDecryptionServer)
-//        dataOut.writeInt(cnt)
-//
-//        def sendBidsToRemove(): Unit = {
-//          for (bid <- toRemove) {
-//            // remove the broadcast from worker
-//            dataOut.writeLong(-bid - 1) // bid >= 0
-//            oldBids.remove(bid)
-//          }
-//        }
-//
-//        if (needsDecryptionServer) {
-//          // if there is encryption, we setup a server which reads the encrypted files, and sends
-//          // the decrypted data to python
-//          val idsAndFiles = broadcastVars.flatMap { broadcast =>
-//            if (!oldBids.contains(broadcast.id)) {
-//              oldBids.add(broadcast.id)
-//              Some((broadcast.id, broadcast.value.path))
-//            } else {
-//              None
-//            }
-//          }
-//          val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
-//          dataOut.writeInt(server.port)
-//          logTrace(s"broadcast decryption server setup on ${server.port}")
-//          PythonRDD.writeUTF(server.secret, dataOut)
-//          sendBidsToRemove()
-//          idsAndFiles.foreach { case (id, _) =>
-//            // send new broadcast
-//            dataOut.writeLong(id)
-//          }
-//          dataOut.flush()
-//          logTrace("waiting for python to read decrypted broadcast data from server")
-//          server.waitTillBroadcastDataSent()
-//          logTrace("done sending decrypted data to python")
-//        } else {
-//          sendBidsToRemove()
-//          for (broadcast <- broadcastVars) {
-//            if (!oldBids.contains(broadcast.id)) {
-//              // send new broadcast
-//              dataOut.writeLong(broadcast.id)
-//              PythonRDD.writeUTF(broadcast.value.path, dataOut)
-//              oldBids.add(broadcast.id)
-//            }
-//          }
-//        }
-//        dataOut.flush()
-//
-//        dataOut.writeInt(evalType)
-//        writeCommand(dataOut)
-//
-//        // write number of geometry fields
-//        dataOut.writeInt(geometryFields.length)
-//        // write geometry field indices and their SRIDs
-//        geometryFields.foreach { case (index, srid) =>
-//          dataOut.writeInt(index)
-//          dataOut.writeInt(srid)
-//        }
-//
-//        writeIteratorToStream(dataOut)
-//
-//        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
-//        dataOut.flush()
-//      } catch {
-//        case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
-//          if (context.isCompleted || context.isInterrupted) {
-//            logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
-//              "cleanup)", t)
-//            if (!worker.isClosed) {
-//              Utils.tryLog(worker.shutdownOutput())
-//            }
-//          } else {
-//            // We must avoid throwing exceptions/NonFatals here, because the thread uncaught
-//            // exception handler will kill the whole executor (see
-//            // org.apache.spark.executor.Executor).
-//            _exception = t
-//            if (!worker.isClosed) {
-//              Utils.tryLog(worker.shutdownOutput())
-//            }
-//          }
-//      }
-//    }
-//
-//    override protected def writeCommand(dataOut: DataOutputStream): Unit = ???
-//
-//    override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = ???
-//  }
-//}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala
deleted file mode 100644
index 57cf2dc7bb..0000000000
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala
+++ /dev/null
@@ -1,349 +0,0 @@
-package org.apache.spark.sql.execution.python
-
-
-import org.apache.sedona.common.geometrySerde.CoordinateType
-import org.apache.spark._
-import org.apache.spark.SedonaSparkEnv
-import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonRDD, SpecialLengths}
-import org.apache.spark.internal.Logging
-import org.apache.spark.internal.config.Python._
-import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
-import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
-import org.apache.spark.security.SocketAuthHelper
-import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator}
-import org.apache.spark.util._
-
-import java.io._
-import java.net._
-import java.nio.charset.StandardCharsets
-import java.nio.charset.StandardCharsets.UTF_8
-import java.nio.file.{Path, Files => JavaFiles}
-import java.util.concurrent.atomic.AtomicBoolean
-import scala.collection.JavaConverters._
-import scala.util.control.NonFatal
-import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.types.StructType
-
-abstract class SedonaWriterThread[IN, OUT](
-                                            env: SparkEnv,
-                                            worker: Socket,
-                                            inputIterator: Iterator[IN],
-                                            partitionIndex: Int,
-                                            context: TaskContext,
-                                            pythonExec: String,
-                                            schema: org.apache.spark.sql.types.StructType,
-                                          )
-  extends Thread(s"stdout writer for $pythonExec") with Logging {
-  self: BasePythonRunner[IN, _] =>
-
-  @volatile private var _exception: Throwable = null
-  private val conf = SparkEnv.get.conf
-  private lazy val authHelper = new SocketAuthHelper(conf)
-  private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
-  private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
-
-  setDaemon(true)
-
-  /** Contains the throwable thrown while writing the parent iterator to the Python process. */
-  def exception: Option[Throwable] = Option(_exception)
-
-  /**
-   * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur
-   * due to cleanup.
-   */
-  def shutdownOnTaskCompletion(): Unit = {
-    assert(context.isCompleted)
-    this.interrupt()
-    // Task completion listeners that run after this method returns may invalidate
-    // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized
-    // reader, a task completion listener will free the underlying off-heap buffers. If the writer
-    // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free
-    // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer
-    // thread to exit before returning.
-    this.join()
-  }
-
-  /**
-   * Writes a command section to the stream connected to the Python worker.
-   */
-  protected def writeCommand(dataOut: DataOutputStream): Unit
-
-  /**
-   * Writes input data to the stream connected to the Python worker.
-   */
-  protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
-
-  override def run(): Unit = Utils.logUncaughtExceptions {
-    try {
-      println("ssss")
-      val toReadCRS = inputIterator.buffered.headOption.flatMap(
-        el => el.asInstanceOf[Iterator[IN]].buffered.headOption
-      )
-
-      val row = toReadCRS match {
-        case Some(value) => value match {
-          case row: GenericInternalRow =>
-            Some(row)
-        }
-        case None => None
-      }
-
-      val geometryFields = schema.zipWithIndex.filter {
-        case (field, index) => field.dataType == GeometryUDT
-      }.map {
-        case (field, index) =>
-          if (row.isEmpty || row.get.values(index) == null) (index, 0) else {
-            val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]]
-            val preambleByte = geom(0) & 0xFF
-            val hasSrid = (preambleByte & 0x01) != 0
-
-            var srid = 0
-            if (hasSrid) {
-              val srid2 = (geom(1) & 0xFF) << 16
-              val srid1 = (geom(2) & 0xFF) << 8
-              val srid0 = geom(3) & 0xFF
-              srid = srid2 | srid1 | srid0
-            }
-            (index, srid)
-          }
-      }
-
-      TaskContext.setTaskContext(context)
-      val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
-      val dataOut = new DataOutputStream(stream)
-
-      // Partition index
-      dataOut.writeInt(partitionIndex)
-      // Python version of driver
-      PythonRDD.writeUTF(pythonVer, dataOut)
-      // Init a ServerSocket to accept method calls from Python side.
-      val isBarrier = context.isInstanceOf[BarrierTaskContext]
-      if (isBarrier) {
-        serverSocket = Some(new ServerSocket(/* port */ 0,
-          /* backlog */ 1,
-          InetAddress.getByName("localhost")))
-        // A call to accept() for ServerSocket shall block infinitely.
-        serverSocket.foreach(_.setSoTimeout(0))
-        new Thread("accept-connections") {
-          setDaemon(true)
-
-          override def run(): Unit = {
-            while (!serverSocket.get.isClosed()) {
-              var sock: Socket = null
-              try {
-                sock = serverSocket.get.accept()
-                // Wait for function call from python side.
-                sock.setSoTimeout(10000)
-                authHelper.authClient(sock)
-                val input = new DataInputStream(sock.getInputStream())
-                val requestMethod = input.readInt()
-                // The BarrierTaskContext function may wait infinitely, socket shall not timeout
-                // before the function finishes.
-                sock.setSoTimeout(10000)
-                requestMethod match {
-                  case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-                    barrierAndServe(requestMethod, sock)
-                  case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
-                    val length = input.readInt()
-                    val message = new Array[Byte](length)
-                    input.readFully(message)
-                    barrierAndServe(requestMethod, sock, new String(message, UTF_8))
-                  case _ =>
-                    val out = new DataOutputStream(new BufferedOutputStream(
-                      sock.getOutputStream))
-                    writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
-                }
-              } catch {
-                case e: SocketException if e.getMessage.contains("Socket closed") =>
-                // It is possible that the ServerSocket is not closed, but the native socket
-                // has already been closed, we shall catch and silently ignore this case.
-              } finally {
-                if (sock != null) {
-                  sock.close()
-                }
-              }
-            }
-          }
-        }.start()
-      }
-      val secret = if (isBarrier) {
-        authHelper.secret
-      } else {
-        ""
-      }
-      // Close ServerSocket on task completion.
-      serverSocket.foreach { server =>
-        context.addTaskCompletionListener[Unit](_ => server.close())
-      }
-      val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
-      if (boundPort == -1) {
-        val message = "ServerSocket failed to bind to Java side."
-        logError(message)
-        throw new SparkException(message)
-      } else if (isBarrier) {
-        logDebug(s"Started ServerSocket on port $boundPort.")
-      }
-      // Write out the TaskContextInfo
-      dataOut.writeBoolean(isBarrier)
-      dataOut.writeInt(boundPort)
-      val secretBytes = secret.getBytes(UTF_8)
-      dataOut.writeInt(secretBytes.length)
-      dataOut.write(secretBytes, 0, secretBytes.length)
-      dataOut.writeInt(context.stageId())
-      dataOut.writeInt(context.partitionId())
-      dataOut.writeInt(context.attemptNumber())
-      dataOut.writeLong(context.taskAttemptId())
-      dataOut.writeInt(context.cpus())
-      val resources = context.resources()
-      dataOut.writeInt(resources.size)
-      resources.foreach { case (k, v) =>
-        PythonRDD.writeUTF(k, dataOut)
-        PythonRDD.writeUTF(v.name, dataOut)
-        dataOut.writeInt(v.addresses.size)
-        v.addresses.foreach { case addr =>
-          PythonRDD.writeUTF(addr, dataOut)
-        }
-      }
-      val localProps = context.getLocalProperties.asScala
-      dataOut.writeInt(localProps.size)
-      localProps.foreach { case (k, v) =>
-        PythonRDD.writeUTF(k, dataOut)
-        PythonRDD.writeUTF(v, dataOut)
-      }
-
-      // sparkFilesDir
-      val root = jobArtifactUUID.map { uuid =>
-        new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
-      }.getOrElse(SparkFiles.getRootDirectory())
-      PythonRDD.writeUTF(root, dataOut)
-      // Python includes (*.zip and *.egg files)
-      dataOut.writeInt(pythonIncludes.size)
-      for (include <- pythonIncludes) {
-        PythonRDD.writeUTF(include, dataOut)
-      }
-      // Broadcast variables
-      val oldBids = PythonRDD.getWorkerBroadcasts(worker)
-      val newBids = broadcastVars.map(_.id).toSet
-      // number of different broadcasts
-      val toRemove = oldBids.diff(newBids)
-      val addedBids = newBids.diff(oldBids)
-      val cnt = toRemove.size + addedBids.size
-      val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
-      dataOut.writeBoolean(needsDecryptionServer)
-      dataOut.writeInt(cnt)
-
-      def sendBidsToRemove(): Unit = {
-        for (bid <- toRemove) {
-          // remove the broadcast from worker
-          dataOut.writeLong(-bid - 1) // bid >= 0
-          oldBids.remove(bid)
-        }
-      }
-
-      if (needsDecryptionServer) {
-        // if there is encryption, we setup a server which reads the encrypted files, and sends
-        // the decrypted data to python
-        val idsAndFiles = broadcastVars.flatMap { broadcast =>
-          if (!oldBids.contains(broadcast.id)) {
-            oldBids.add(broadcast.id)
-            Some((broadcast.id, broadcast.value.path))
-          } else {
-            None
-          }
-        }
-        val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
-        dataOut.writeInt(server.port)
-        logTrace(s"broadcast decryption server setup on ${server.port}")
-        PythonRDD.writeUTF(server.secret, dataOut)
-        sendBidsToRemove()
-        idsAndFiles.foreach { case (id, _) =>
-          // send new broadcast
-          dataOut.writeLong(id)
-        }
-        dataOut.flush()
-        logTrace("waiting for python to read decrypted broadcast data from server")
-        server.waitTillBroadcastDataSent()
-        logTrace("done sending decrypted data to python")
-      } else {
-        sendBidsToRemove()
-        for (broadcast <- broadcastVars) {
-          if (!oldBids.contains(broadcast.id)) {
-            // send new broadcast
-            dataOut.writeLong(broadcast.id)
-            PythonRDD.writeUTF(broadcast.value.path, dataOut)
-            oldBids.add(broadcast.id)
-          }
-        }
-      }
-      dataOut.flush()
-
-      dataOut.writeInt(evalType)
-      writeCommand(dataOut)
-
-      // write number of geometry fields
-      dataOut.writeInt(geometryFields.length)
-      // write geometry field indices and their SRIDs
-      geometryFields.foreach { case (index, srid) =>
-        dataOut.writeInt(index)
-        dataOut.writeInt(srid)
-      }
-
-      writeIteratorToStream(dataOut)
-
-      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
-      dataOut.flush()
-    } catch {
-      case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
-        if (context.isCompleted || context.isInterrupted) {
-          logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
-            "cleanup)", t)
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-        } else {
-          // We must avoid throwing exceptions/NonFatals here, because the thread uncaught
-          // exception handler will kill the whole executor (see
-          // org.apache.spark.executor.Executor).
-          _exception = t
-          if (!worker.isClosed) {
-            Utils.tryLog(worker.shutdownOutput())
-          }
-        }
-    }
-  }
-
-  /**
-   * Gateway to call BarrierTaskContext methods.
-   */
-  def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
-    require(
-      serverSocket.isDefined,
-      "No available ServerSocket to redirect the BarrierTaskContext method call."
-    )
-    val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
-    try {
-      val messages = requestMethod match {
-        case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
-          context.asInstanceOf[BarrierTaskContext].barrier()
-          Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
-        case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
-          context.asInstanceOf[BarrierTaskContext].allGather(message)
-      }
-      out.writeInt(messages.length)
-      messages.foreach(writeUTF(_, out))
-    } catch {
-      case e: SparkException =>
-        writeUTF(e.getMessage, out)
-    } finally {
-      out.close()
-    }
-  }
-
-  def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
-    val bytes = str.getBytes(UTF_8)
-    dataOut.writeInt(bytes.length)
-    dataOut.write(bytes)
-  }
-}
-

