This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 37389d606152519b8431903944b8abf4196acfd6
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Dec 21 23:17:05 2025 +0100

    add code so far
---
 .../spark/api/python/SedonaPythonRunner.scala      |  99 +++---
 .../execution/python/SedonaArrowPythonRunner.scala |   2 +-
 .../sql/execution/python/SedonaArrowUtils.scala    |   8 +-
 .../execution/python/SedonaPythonArrowInput.scala  |   2 +
 .../execution/python/SedonaPythonUDFRunner.scala   |   2 +-
 .../spark/sql/execution/python/SedonaThread1.scala | 285 +++++++++++++++++
 .../sql/execution/python/SedonaWriterThread.scala  | 349 +++++++++++++++++++++
 .../org/apache/spark/sql/udf/StrategySuite.scala   |  25 +-
 8 files changed, 718 insertions(+), 54 deletions(-)

diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
index bdc989d82b..38a9c7182b 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
@@ -17,6 +17,7 @@ package org.apache.spark.api.python
  * limitations under the License.
  */
 
+import org.apache.sedona.common.geometrySerde.CoordinateType
 import org.apache.spark._
 import org.apache.spark.SedonaSparkEnv
 import org.apache.spark.internal.Logging
@@ -24,6 +25,7 @@ import org.apache.spark.internal.config.Python._
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
 import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
 import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator}
 import org.apache.spark.util._
 
 import java.io._
@@ -34,48 +36,11 @@ import java.nio.file.{Path, Files => JavaFiles}
 import java.util.concurrent.atomic.AtomicBoolean
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types.StructType
 
 
-/**
- * Enumerate the type of command that will be sent to the Python worker
- */
-private[spark] object PythonEvalType {
-  val NON_UDF = 0
-
-  val SQL_BATCHED_UDF = 100
-  val SQL_ARROW_BATCHED_UDF = 101
-
-  val SQL_SCALAR_PANDAS_UDF = 200
-  val SQL_GROUPED_MAP_PANDAS_UDF = 201
-  val SQL_GROUPED_AGG_PANDAS_UDF = 202
-  val SQL_WINDOW_AGG_PANDAS_UDF = 203
-  val SQL_SCALAR_PANDAS_ITER_UDF = 204
-  val SQL_MAP_PANDAS_ITER_UDF = 205
-  val SQL_COGROUPED_MAP_PANDAS_UDF = 206
-  val SQL_MAP_ARROW_ITER_UDF = 207
-  val SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE = 208
-
-  val SQL_TABLE_UDF = 300
-  val SQL_ARROW_TABLE_UDF = 301
-
-  def toString(pythonEvalType: Int): String = pythonEvalType match {
-    case NON_UDF => "NON_UDF"
-    case SQL_BATCHED_UDF => "SQL_BATCHED_UDF"
-    case SQL_ARROW_BATCHED_UDF => "SQL_ARROW_BATCHED_UDF"
-    case SQL_SCALAR_PANDAS_UDF => "SQL_SCALAR_PANDAS_UDF"
-    case SQL_GROUPED_MAP_PANDAS_UDF => "SQL_GROUPED_MAP_PANDAS_UDF"
-    case SQL_GROUPED_AGG_PANDAS_UDF => "SQL_GROUPED_AGG_PANDAS_UDF"
-    case SQL_WINDOW_AGG_PANDAS_UDF => "SQL_WINDOW_AGG_PANDAS_UDF"
-    case SQL_SCALAR_PANDAS_ITER_UDF => "SQL_SCALAR_PANDAS_ITER_UDF"
-    case SQL_MAP_PANDAS_ITER_UDF => "SQL_MAP_PANDAS_ITER_UDF"
-    case SQL_COGROUPED_MAP_PANDAS_UDF => "SQL_COGROUPED_MAP_PANDAS_UDF"
-    case SQL_MAP_ARROW_ITER_UDF => "SQL_MAP_ARROW_ITER_UDF"
-    case SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE => "SQL_GROUPED_MAP_PANDAS_UDF_WITH_STATE"
-    case SQL_TABLE_UDF => "SQL_TABLE_UDF"
-    case SQL_ARROW_TABLE_UDF => "SQL_ARROW_TABLE_UDF"
-  }
-}
-
 private object SedonaBasePythonRunner {
 
   private lazy val faultHandlerLogDir = Utils.createTempDir(namePrefix = "faulthandler")
@@ -92,10 +57,11 @@ private object SedonaBasePythonRunner {
  * functions (from bottom to top).
  */
 private[spark] abstract class SedonaBasePythonRunner[IN, OUT](
-                                                         protected val funcs: Seq[ChainedPythonFunctions],
-                                                         protected val evalType: Int,
-                                                         protected val argOffsets: Array[Array[Int]],
-                                                         protected val jobArtifactUUID: Option[String])
+                                                               protected val funcs: Seq[ChainedPythonFunctions],
+                                                               protected val evalType: Int,
+                                                               protected val argOffsets: Array[Array[Int]],
+                                                               protected val jobArtifactUUID: Option[String],
+                                                               schema: StructType)
   extends Logging {
 
   require(funcs.length == argOffsets.length, "argOffsets should have the same length as funcs")
@@ -279,9 +245,43 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT](
 
     override def run(): Unit = Utils.logUncaughtExceptions {
       try {
+        println("ssss")
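+        // Peek at the first row of the first input batch to detect, for each geometry
+        // column, whether an SRID is embedded in the serialized geometry value.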
+        val toReadCRS = inputIterator.buffered.headOption.flatMap(
+          el => el.asInstanceOf[Iterator[IN]].buffered.headOption
+        )
+
+        val row = toReadCRS match {
+          case Some(value) => value match {
+            case row: GenericInternalRow =>
+              Some(row)
+          }
+          case None => None
+        }
+
+        val geometryFields = schema.zipWithIndex.filter {
+          case (field, index) => field.dataType == GeometryUDT
+        }.map {
+          case (field, index) =>
+            if (row.isEmpty || row.get.values(index) == null) (index, 0) else {
+              val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]]
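+              // The low bit of the first (preamble) byte flags SRID presence; when set,
+              // bytes 1-3 are read as a 24-bit big-endian SRID (assumed serde layout).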
+              val preambleByte = geom(0) & 0xFF
+              val hasSrid = (preambleByte & 0x01) != 0
+
+              var srid = 0
+              if (hasSrid) {
+                val srid2 = (geom(1) & 0xFF) << 16
+                val srid1 = (geom(2) & 0xFF) << 8
+                val srid0 = geom(3) & 0xFF
+                srid = srid2 | srid1 | srid0
+              }
+              (index, srid)
+            }
+        }
+
         TaskContext.setTaskContext(context)
         val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
         val dataOut = new DataOutputStream(stream)
+
         // Partition index
         dataOut.writeInt(partitionIndex)
         // Python version of driver
@@ -401,6 +401,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT](
         val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
         dataOut.writeBoolean(needsDecryptionServer)
         dataOut.writeInt(cnt)
+
         def sendBidsToRemove(): Unit = {
           for (bid <- toRemove) {
             // remove the broadcast from worker
@@ -408,6 +409,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT](
             oldBids.remove(bid)
           }
         }
+
         if (needsDecryptionServer) {
           // if there is encryption, we setup a server which reads the encrypted files, and sends
           // the decrypted data to python
@@ -447,6 +449,15 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT](
 
         dataOut.writeInt(evalType)
         writeCommand(dataOut)
+
+        // write number of geometry fields
+        dataOut.writeInt(geometryFields.length)
+        // write geometry field indices and their SRIDs
+        geometryFields.foreach { case (index, srid) =>
+          dataOut.writeInt(index)
+          dataOut.writeInt(srid)
+        }
+
         writeIteratorToStream(dataOut)
 
         dataOut.writeInt(SpecialLengths.END_OF_STREAM)
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
index 27e4b851ee..c3eafc9766 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala
@@ -38,7 +38,7 @@ class SedonaArrowPythonRunner(
                          val pythonMetrics: Map[String, SQLMetric],
                          jobArtifactUUID: Option[String])
   extends SedonaBasePythonRunner[Iterator[InternalRow], ColumnarBatch](
-    funcs, evalType, argOffsets, jobArtifactUUID)
+    funcs, evalType, argOffsets, jobArtifactUUID, schema)
     with SedonaBasicPythonArrowInput
     with SedonaBasicPythonArrowOutput {
 
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
index ec4f7c00d0..6c1bb9edd5 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
@@ -40,10 +40,11 @@ private[sql] object SedonaArrowUtils {
                     largeVarTypes: Boolean = false): Field = {
     dt match {
       case GeometryUDT =>
-        val jsonData = """{"crs": {"$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", "type": "GeographicCRS", "name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": {"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 (G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153}} [...]
+//        val jsonData = """{"crs": {"$schema": "https://proj.org/schemas/v0.7/projjson.schema.json", "type": "GeographicCRS", "name": "WGS 84", "datum_ensemble": {"name": "World Geodetic System 1984 ensemble", "members": [{"name": "World Geodetic System 1984 (Transit)", "id": {"authority": "EPSG", "code": 1166}}, {"name": "World Geodetic System 1984 (G730)", "id": {"authority": "EPSG", "code": 1152}}, {"name": "World Geodetic System 1984 (G873)", "id": {"authority": "EPSG", "code": 1153 [...]
         val metadata = Map(
-          "ARROW:extension:name" -> "geoarrow.wkb",
-          "ARROW:extension:metadata" -> jsonData,
+          "empty" -> "empty",
+//          "ARROW:extension:name" -> "geoarrow.wkb",
+//          "ARROW:extension:metadata" -> jsonData,
         ).asJava
 
         val fieldType = new FieldType(nullable, ArrowType.Binary.INSTANCE, null, metadata)
@@ -77,6 +78,7 @@ private[sql] object SedonaArrowUtils {
         val fieldType = new FieldType(nullable, toArrowType(dataType, timeZoneId,
           largeVarTypes), null)
         new Field(name, fieldType, Seq.empty[Field].asJava)
+      case _ => toArrowField(name, dt, nullable, timeZoneId, largeVarTypes)
     }
   }
 
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
index 8a5e241c51..bf353539bc 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
@@ -77,6 +77,8 @@ private[python] trait SedonaPythonArrowInput[IN] { self: SedonaBasePythonRunner[
                                           inputIterator: Iterator[IN],
                                           partitionIndex: Int,
                                           context: TaskContext): WriterThread = {
+//    createWorkerThread(env, worker, inputIterator, partitionIndex, context, schema)
+
     new WriterThread(env, worker, inputIterator, partitionIndex, context) {
 
       protected override def writeCommand(dataOut: DataOutputStream): Unit = {
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
index ced32cf801..beb49c1dde 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonUDFRunner.scala
@@ -35,7 +35,7 @@ abstract class SedonaBasePythonUDFRunner(
                                     pythonMetrics: Map[String, SQLMetric],
                                     jobArtifactUUID: Option[String])
   extends SedonaBasePythonRunner[Array[Byte], Array[Byte]](
-    funcs, evalType, argOffsets, jobArtifactUUID) {
+    funcs, evalType, argOffsets, jobArtifactUUID, null) {
 
   override val pythonExec: String =
     SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala
new file mode 100644
index 0000000000..41fc67df3a
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaThread1.scala
@@ -0,0 +1,285 @@
+package org.apache.spark.sql.execution.python
+
+import org.apache.arrow.vector.VectorSchemaRoot
+import org.apache.arrow.vector.ipc.ArrowStreamWriter
+import org.apache.spark.{BarrierTaskContext, SparkEnv, SparkException, SparkFiles, TaskContext}
+import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, EncryptedPythonBroadcastServer, PythonRDD, SedonaBasePythonRunner, SpecialLengths}
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.Utils
+
+import java.io.{BufferedOutputStream, DataInputStream, DataOutputStream, File}
+import java.net.{InetAddress, ServerSocket, Socket, SocketException}
+import java.nio.charset.StandardCharsets.UTF_8
+import scala.util.control.NonFatal
+//
+//trait WorkerThreadFactory[IN] {
+//  self: BasePythonRunner [IN, _] =>
+//
+//  // Define common functionality for worker threads here
+//
+//  def createWorkerThread(
+//                          env: SparkEnv,
+//                          worker: Socket,
+//                          inputIterator: Iterator[IN],
+//                          partitionIndex: Int,
+//                          context: TaskContext,
+//                          schema: StructType,
+//                        ): WriterThread = {
+//    new SedonaThread1 (env, worker, inputIterator, partitionIndex, context, schema)
+//  }
+//
+//  class SedonaThread1(
+//                       env: SparkEnv,
+//                       worker: Socket,
+//                       inputIterator: Iterator[IN],
+//                       partitionIndex: Int,
+//                       context: TaskContext,
+//                       schema: StructType,
+//                     ) extends WriterThread(env, worker, inputIterator, partitionIndex, context) {
+//
+//
+//    override def run(): Unit = Utils.logUncaughtExceptions {
+//      try {
+//        val toReadCRS = inputIterator.buffered.headOption.flatMap(
+//          el => el.asInstanceOf[Iterator[IN]].buffered.headOption
+//        )
+//
+//        val row = toReadCRS match {
+//          case Some(value) => value match {
+//            case row: GenericInternalRow =>
+//              Some(row)
+//          }
+//          case None => None
+//        }
+//
+//        val geometryFields = schema.zipWithIndex.filter {
+//          case (field, index) => field.dataType == GeometryUDT
+//        }.map {
+//          case (field, index) =>
+//            if (row.isEmpty || row.get.values(index) == null) (index, 0) else {
+//              val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]]
+//              val preambleByte = geom(0) & 0xFF
+//              val hasSrid = (preambleByte & 0x01) != 0
+//
+//              var srid = 0
+//              if (hasSrid) {
+//                val srid2 = (geom(1) & 0xFF) << 16
+//                val srid1 = (geom(2) & 0xFF) << 8
+//                val srid0 = geom(3) & 0xFF
+//                srid = srid2 | srid1 | srid0
+//              }
+//              (index, srid)
+//            }
+//        }
+//
+//        TaskContext.setTaskContext(context)
+//        val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+//        val dataOut = new DataOutputStream(stream)
+//
+//        // Partition index
+//        dataOut.writeInt(partitionIndex)
+//        // Python version of driver
+//        PythonRDD.writeUTF(pythonVer, dataOut)
+//        // Init a ServerSocket to accept method calls from Python side.
+//        val isBarrier = context.isInstanceOf[BarrierTaskContext]
+//        if (isBarrier) {
+//          serverSocket = Some(new ServerSocket(/* port */ 0,
+//            /* backlog */ 1,
+//            InetAddress.getByName("localhost")))
+//          // A call to accept() for ServerSocket shall block infinitely.
+//          serverSocket.foreach(_.setSoTimeout(0))
+//          new Thread("accept-connections") {
+//            setDaemon(true)
+//
+//            override def run(): Unit = {
+//              while (!serverSocket.get.isClosed()) {
+//                var sock: Socket = null
+//                try {
+//                  sock = serverSocket.get.accept()
+//                  // Wait for function call from python side.
+//                  sock.setSoTimeout(10000)
+//                  authHelper.authClient(sock)
+//                  val input = new DataInputStream(sock.getInputStream())
+//                  val requestMethod = input.readInt()
+//                  // The BarrierTaskContext function may wait infinitely, socket shall not timeout
+//                  // before the function finishes.
+//                  sock.setSoTimeout(10000)
+//                  requestMethod match {
+//                    case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+//                      barrierAndServe(requestMethod, sock)
+//                    case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
+//                      val length = input.readInt()
+//                      val message = new Array[Byte](length)
+//                      input.readFully(message)
+//                      barrierAndServe(requestMethod, sock, new String(message, UTF_8))
+//                    case _ =>
+//                      val out = new DataOutputStream(new BufferedOutputStream(
+//                        sock.getOutputStream))
+//                      writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
+//                  }
+//                } catch {
+//                  case e: SocketException if e.getMessage.contains("Socket closed") =>
+//                  // It is possible that the ServerSocket is not closed, but the native socket
+//                  // has already been closed, we shall catch and silently ignore this case.
+//                } finally {
+//                  if (sock != null) {
+//                    sock.close()
+//                  }
+//                }
+//              }
+//            }
+//          }.start()
+//        }
+//        val secret = if (isBarrier) {
+//          authHelper.secret
+//        } else {
+//          ""
+//        }
+//        // Close ServerSocket on task completion.
+//        serverSocket.foreach { server =>
+//          context.addTaskCompletionListener[Unit](_ => server.close())
+//        }
+//        val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
+//        if (boundPort == -1) {
+//          val message = "ServerSocket failed to bind to Java side."
+//          logError(message)
+//          throw new SparkException(message)
+//        } else if (isBarrier) {
+//          logDebug(s"Started ServerSocket on port $boundPort.")
+//        }
+//        // Write out the TaskContextInfo
+//        dataOut.writeBoolean(isBarrier)
+//        dataOut.writeInt(boundPort)
+//        val secretBytes = secret.getBytes(UTF_8)
+//        dataOut.writeInt(secretBytes.length)
+//        dataOut.write(secretBytes, 0, secretBytes.length)
+//        dataOut.writeInt(context.stageId())
+//        dataOut.writeInt(context.partitionId())
+//        dataOut.writeInt(context.attemptNumber())
+//        dataOut.writeLong(context.taskAttemptId())
+//        dataOut.writeInt(context.cpus())
+//        val resources = context.resources()
+//        dataOut.writeInt(resources.size)
+//        resources.foreach { case (k, v) =>
+//          PythonRDD.writeUTF(k, dataOut)
+//          PythonRDD.writeUTF(v.name, dataOut)
+//          dataOut.writeInt(v.addresses.size)
+//          v.addresses.foreach { case addr =>
+//            PythonRDD.writeUTF(addr, dataOut)
+//          }
+//        }
+//        val localProps = context.getLocalProperties.asScala
+//        dataOut.writeInt(localProps.size)
+//        localProps.foreach { case (k, v) =>
+//          PythonRDD.writeUTF(k, dataOut)
+//          PythonRDD.writeUTF(v, dataOut)
+//        }
+//
+//        // sparkFilesDir
+//        val root = jobArtifactUUID.map { uuid =>
+//          new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
+//        }.getOrElse(SparkFiles.getRootDirectory())
+//        PythonRDD.writeUTF(root, dataOut)
+//        // Python includes (*.zip and *.egg files)
+//        dataOut.writeInt(pythonIncludes.size)
+//        for (include <- pythonIncludes) {
+//          PythonRDD.writeUTF(include, dataOut)
+//        }
+//        // Broadcast variables
+//        val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+//        val newBids = broadcastVars.map(_.id).toSet
+//        // number of different broadcasts
+//        val toRemove = oldBids.diff(newBids)
+//        val addedBids = newBids.diff(oldBids)
+//        val cnt = toRemove.size + addedBids.size
+//        val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
+//        dataOut.writeBoolean(needsDecryptionServer)
+//        dataOut.writeInt(cnt)
+//
+//        def sendBidsToRemove(): Unit = {
+//          for (bid <- toRemove) {
+//            // remove the broadcast from worker
+//            dataOut.writeLong(-bid - 1) // bid >= 0
+//            oldBids.remove(bid)
+//          }
+//        }
+//
+//        if (needsDecryptionServer) {
+//          // if there is encryption, we setup a server which reads the encrypted files, and sends
+//          // the decrypted data to python
+//          val idsAndFiles = broadcastVars.flatMap { broadcast =>
+//            if (!oldBids.contains(broadcast.id)) {
+//              oldBids.add(broadcast.id)
+//              Some((broadcast.id, broadcast.value.path))
+//            } else {
+//              None
+//            }
+//          }
+//          val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+//          dataOut.writeInt(server.port)
+//          logTrace(s"broadcast decryption server setup on ${server.port}")
+//          PythonRDD.writeUTF(server.secret, dataOut)
+//          sendBidsToRemove()
+//          idsAndFiles.foreach { case (id, _) =>
+//            // send new broadcast
+//            dataOut.writeLong(id)
+//          }
+//          dataOut.flush()
+//          logTrace("waiting for python to read decrypted broadcast data from server")
+//          server.waitTillBroadcastDataSent()
+//          logTrace("done sending decrypted data to python")
+//        } else {
+//          sendBidsToRemove()
+//          for (broadcast <- broadcastVars) {
+//            if (!oldBids.contains(broadcast.id)) {
+//              // send new broadcast
+//              dataOut.writeLong(broadcast.id)
+//              PythonRDD.writeUTF(broadcast.value.path, dataOut)
+//              oldBids.add(broadcast.id)
+//            }
+//          }
+//        }
+//        dataOut.flush()
+//
+//        dataOut.writeInt(evalType)
+//        writeCommand(dataOut)
+//
+//        // write number of geometry fields
+//        dataOut.writeInt(geometryFields.length)
+//        // write geometry field indices and their SRIDs
+//        geometryFields.foreach { case (index, srid) =>
+//          dataOut.writeInt(index)
+//          dataOut.writeInt(srid)
+//        }
+//
+//        writeIteratorToStream(dataOut)
+//
+//        dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+//        dataOut.flush()
+//      } catch {
+//        case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
+//          if (context.isCompleted || context.isInterrupted) {
+//            logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
+//              "cleanup)", t)
+//            if (!worker.isClosed) {
+//              Utils.tryLog(worker.shutdownOutput())
+//            }
+//          } else {
+//            // We must avoid throwing exceptions/NonFatals here, because the thread uncaught
+//            // exception handler will kill the whole executor (see
+//            // org.apache.spark.executor.Executor).
+//            _exception = t
+//            if (!worker.isClosed) {
+//              Utils.tryLog(worker.shutdownOutput())
+//            }
+//          }
+//      }
+//    }
+//
+//    override protected def writeCommand(dataOut: DataOutputStream): Unit = ???
+//
+//    override protected def writeIteratorToStream(dataOut: DataOutputStream): Unit = ???
+//  }
+//}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala
new file mode 100644
index 0000000000..b28cf81906
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaWriterThread.scala
@@ -0,0 +1,349 @@
+package org.apache.spark.sql.execution.python
+
+
+import org.apache.sedona.common.geometrySerde.CoordinateType
+import org.apache.spark._
+import org.apache.spark.SedonaSparkEnv
+import org.apache.spark.api.python.{BarrierTaskContextMessageProtocol, BasePythonRunner, ChainedPythonFunctions, EncryptedPythonBroadcastServer, PythonRDD, SedonaBasePythonRunner, SpecialLengths}
+import org.apache.spark.internal.Logging
+import org.apache.spark.internal.config.Python._
+import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
+import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY}
+import org.apache.spark.security.SocketAuthHelper
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator}
+import org.apache.spark.util._
+
+import java.io._
+import java.net._
+import java.nio.charset.StandardCharsets
+import java.nio.charset.StandardCharsets.UTF_8
+import java.nio.file.{Path, Files => JavaFiles}
+import java.util.concurrent.atomic.AtomicBoolean
+import scala.collection.JavaConverters._
+import scala.util.control.NonFatal
+import org.apache.spark.sql.catalyst.expressions.GenericInternalRow
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.types.StructType
+
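+/**
+ * Stand-alone copy of Spark's BasePythonRunner.WriterThread that, in addition to the
+ * standard worker handshake, writes the geometry column indices and their SRIDs to the
+ * Python worker before streaming the input data (work in progress).
+ */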
+abstract class SedonaWriterThread[IN, OUT](
+                                            env: SparkEnv,
+                                            worker: Socket,
+                                            inputIterator: Iterator[IN],
+                                            partitionIndex: Int,
+                                            context: TaskContext,
+                                            pythonExec: String,
+                                            schema: org.apache.spark.sql.types.StructType,
+                                          )
+  extends Thread(s"stdout writer for $pythonExec") with Logging {
+  self: BasePythonRunner[IN, _] =>
+
+  @volatile private var _exception: Throwable = null
+  private val conf = SparkEnv.get.conf
+  private lazy val authHelper = new SocketAuthHelper(conf)
+  private val pythonIncludes = funcs.flatMap(_.funcs.flatMap(_.pythonIncludes.asScala)).toSet
+  private val broadcastVars = funcs.flatMap(_.funcs.flatMap(_.broadcastVars.asScala))
+
+  setDaemon(true)
+
+  /** Contains the throwable thrown while writing the parent iterator to the Python process. */
+  def exception: Option[Throwable] = Option(_exception)
+
+  /**
+   * Terminates the writer thread and waits for it to exit, ignoring any exceptions that may occur
+   * due to cleanup.
+   */
+  def shutdownOnTaskCompletion(): Unit = {
+    assert(context.isCompleted)
+    this.interrupt()
+    // Task completion listeners that run after this method returns may invalidate
+    // `inputIterator`. For example, when `inputIterator` was generated by the off-heap vectorized
+    // reader, a task completion listener will free the underlying off-heap buffers. If the writer
+    // thread is still running when `inputIterator` is invalidated, it can cause a use-after-free
+    // bug that crashes the executor (SPARK-33277). Therefore this method must wait for the writer
+    // thread to exit before returning.
+    this.join()
+  }
+
+  /**
+   * Writes a command section to the stream connected to the Python worker.
+   */
+  protected def writeCommand(dataOut: DataOutputStream): Unit
+
+  /**
+   * Writes input data to the stream connected to the Python worker.
+   */
+  protected def writeIteratorToStream(dataOut: DataOutputStream): Unit
+
+  override def run(): Unit = Utils.logUncaughtExceptions {
+    try {
+      println("ssss")
+      val toReadCRS = inputIterator.buffered.headOption.flatMap(
+        el => el.asInstanceOf[Iterator[IN]].buffered.headOption
+      )
+
+      val row = toReadCRS match {
+        case Some(value) => value match {
+          case row: GenericInternalRow =>
+            Some(row)
+        }
+        case None => None
+      }
+
+      val geometryFields = schema.zipWithIndex.filter {
+        case (field, index) => field.dataType == GeometryUDT
+      }.map {
+        case (field, index) =>
+          if (row.isEmpty || row.get.values(index) == null) (index, 0) else {
+            val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]]
+            val preambleByte = geom(0) & 0xFF
+            val hasSrid = (preambleByte & 0x01) != 0
+
+            var srid = 0
+            if (hasSrid) {
+              val srid2 = (geom(1) & 0xFF) << 16
+              val srid1 = (geom(2) & 0xFF) << 8
+              val srid0 = geom(3) & 0xFF
+              srid = srid2 | srid1 | srid0
+            }
+            (index, srid)
+          }
+      }
+
+      TaskContext.setTaskContext(context)
+      val stream = new BufferedOutputStream(worker.getOutputStream, bufferSize)
+      val dataOut = new DataOutputStream(stream)
+
+      // Partition index
+      dataOut.writeInt(partitionIndex)
+      // Python version of driver
+      PythonRDD.writeUTF(pythonVer, dataOut)
+      // Init a ServerSocket to accept method calls from Python side.
+      val isBarrier = context.isInstanceOf[BarrierTaskContext]
+      if (isBarrier) {
+        serverSocket = Some(new ServerSocket(/* port */ 0,
+          /* backlog */ 1,
+          InetAddress.getByName("localhost")))
+        // A call to accept() for ServerSocket shall block infinitely.
+        serverSocket.foreach(_.setSoTimeout(0))
+        new Thread("accept-connections") {
+          setDaemon(true)
+
+          override def run(): Unit = {
+            while (!serverSocket.get.isClosed()) {
+              var sock: Socket = null
+              try {
+                sock = serverSocket.get.accept()
+                // Wait for function call from python side.
+                sock.setSoTimeout(10000)
+                authHelper.authClient(sock)
+                val input = new DataInputStream(sock.getInputStream())
+                val requestMethod = input.readInt()
+                // The BarrierTaskContext function may wait infinitely, socket shall not timeout
+                // before the function finishes.
+                sock.setSoTimeout(10000)
+                requestMethod match {
+                  case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+                    barrierAndServe(requestMethod, sock)
+                  case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
+                    val length = input.readInt()
+                    val message = new Array[Byte](length)
+                    input.readFully(message)
+                    barrierAndServe(requestMethod, sock, new String(message, UTF_8))
+                  case _ =>
+                    val out = new DataOutputStream(new BufferedOutputStream(
+                      sock.getOutputStream))
+                    writeUTF(BarrierTaskContextMessageProtocol.ERROR_UNRECOGNIZED_FUNCTION, out)
+                }
+              } catch {
+                case e: SocketException if e.getMessage.contains("Socket closed") =>
+                // It is possible that the ServerSocket is not closed, but the native socket
+                // has already been closed, we shall catch and silently ignore this case.
+              } finally {
+                if (sock != null) {
+                  sock.close()
+                }
+              }
+            }
+          }
+        }.start()
+      }
+      val secret = if (isBarrier) {
+        authHelper.secret
+      } else {
+        ""
+      }
+      // Close ServerSocket on task completion.
+      serverSocket.foreach { server =>
+        context.addTaskCompletionListener[Unit](_ => server.close())
+      }
+      val boundPort: Int = serverSocket.map(_.getLocalPort).getOrElse(0)
+      if (boundPort == -1) {
+        val message = "ServerSocket failed to bind to Java side."
+        logError(message)
+        throw new SparkException(message)
+      } else if (isBarrier) {
+        logDebug(s"Started ServerSocket on port $boundPort.")
+      }
+      // Write out the TaskContextInfo
+      dataOut.writeBoolean(isBarrier)
+      dataOut.writeInt(boundPort)
+      val secretBytes = secret.getBytes(UTF_8)
+      dataOut.writeInt(secretBytes.length)
+      dataOut.write(secretBytes, 0, secretBytes.length)
+      dataOut.writeInt(context.stageId())
+      dataOut.writeInt(context.partitionId())
+      dataOut.writeInt(context.attemptNumber())
+      dataOut.writeLong(context.taskAttemptId())
+      dataOut.writeInt(context.cpus())
+      val resources = context.resources()
+      dataOut.writeInt(resources.size)
+      resources.foreach { case (k, v) =>
+        PythonRDD.writeUTF(k, dataOut)
+        PythonRDD.writeUTF(v.name, dataOut)
+        dataOut.writeInt(v.addresses.size)
+        v.addresses.foreach { case addr =>
+          PythonRDD.writeUTF(addr, dataOut)
+        }
+      }
+      val localProps = context.getLocalProperties.asScala
+      dataOut.writeInt(localProps.size)
+      localProps.foreach { case (k, v) =>
+        PythonRDD.writeUTF(k, dataOut)
+        PythonRDD.writeUTF(v, dataOut)
+      }
+
+      // sparkFilesDir
+      val root = jobArtifactUUID.map { uuid =>
+        new File(SparkFiles.getRootDirectory(), uuid).getAbsolutePath
+      }.getOrElse(SparkFiles.getRootDirectory())
+      PythonRDD.writeUTF(root, dataOut)
+      // Python includes (*.zip and *.egg files)
+      dataOut.writeInt(pythonIncludes.size)
+      for (include <- pythonIncludes) {
+        PythonRDD.writeUTF(include, dataOut)
+      }
+      // Broadcast variables
+      val oldBids = PythonRDD.getWorkerBroadcasts(worker)
+      val newBids = broadcastVars.map(_.id).toSet
+      // number of different broadcasts
+      val toRemove = oldBids.diff(newBids)
+      val addedBids = newBids.diff(oldBids)
+      val cnt = toRemove.size + addedBids.size
+      val needsDecryptionServer = env.serializerManager.encryptionEnabled && addedBids.nonEmpty
+      dataOut.writeBoolean(needsDecryptionServer)
+      dataOut.writeInt(cnt)
+
+      def sendBidsToRemove(): Unit = {
+        for (bid <- toRemove) {
+          // remove the broadcast from worker
+          dataOut.writeLong(-bid - 1) // bid >= 0
+          oldBids.remove(bid)
+        }
+      }
+
+      if (needsDecryptionServer) {
+        // if there is encryption, we setup a server which reads the encrypted files, and sends
+        // the decrypted data to python
+        val idsAndFiles = broadcastVars.flatMap { broadcast =>
+          if (!oldBids.contains(broadcast.id)) {
+            oldBids.add(broadcast.id)
+            Some((broadcast.id, broadcast.value.path))
+          } else {
+            None
+          }
+        }
+        val server = new EncryptedPythonBroadcastServer(env, idsAndFiles)
+        dataOut.writeInt(server.port)
+        logTrace(s"broadcast decryption server setup on ${server.port}")
+        PythonRDD.writeUTF(server.secret, dataOut)
+        sendBidsToRemove()
+        idsAndFiles.foreach { case (id, _) =>
+          // send new broadcast
+          dataOut.writeLong(id)
+        }
+        dataOut.flush()
+        logTrace("waiting for python to read decrypted broadcast data from server")
+        server.waitTillBroadcastDataSent()
+        logTrace("done sending decrypted data to python")
+      } else {
+        sendBidsToRemove()
+        for (broadcast <- broadcastVars) {
+          if (!oldBids.contains(broadcast.id)) {
+            // send new broadcast
+            dataOut.writeLong(broadcast.id)
+            PythonRDD.writeUTF(broadcast.value.path, dataOut)
+            oldBids.add(broadcast.id)
+          }
+        }
+      }
+      dataOut.flush()
+
+      dataOut.writeInt(evalType)
+      writeCommand(dataOut)
+
+      // write number of geometry fields
+      dataOut.writeInt(geometryFields.length)
+      // write geometry field indices and their SRIDs
+      geometryFields.foreach { case (index, srid) =>
+        dataOut.writeInt(index)
+        dataOut.writeInt(srid)
+      }
+
+      writeIteratorToStream(dataOut)
+
+      dataOut.writeInt(SpecialLengths.END_OF_STREAM)
+      dataOut.flush()
+    } catch {
+      case t: Throwable if (NonFatal(t) || t.isInstanceOf[Exception]) =>
+        if (context.isCompleted || context.isInterrupted) {
+          logDebug("Exception/NonFatal Error thrown after task completion (likely due to " +
+            "cleanup)", t)
+          if (!worker.isClosed) {
+            Utils.tryLog(worker.shutdownOutput())
+          }
+        } else {
+          // We must avoid throwing exceptions/NonFatals here, because the thread uncaught
+          // exception handler will kill the whole executor (see
+          // org.apache.spark.executor.Executor).
+          _exception = t
+          if (!worker.isClosed) {
+            Utils.tryLog(worker.shutdownOutput())
+          }
+        }
+    }
+  }
+
+  /**
+   * Gateway to call BarrierTaskContext methods.
+   */
+  def barrierAndServe(requestMethod: Int, sock: Socket, message: String = ""): Unit = {
+    require(
+      serverSocket.isDefined,
+      "No available ServerSocket to redirect the BarrierTaskContext method call."
+    )
+    val out = new DataOutputStream(new BufferedOutputStream(sock.getOutputStream))
+    try {
+      val messages = requestMethod match {
+        case BarrierTaskContextMessageProtocol.BARRIER_FUNCTION =>
+          context.asInstanceOf[BarrierTaskContext].barrier()
+          Array(BarrierTaskContextMessageProtocol.BARRIER_RESULT_SUCCESS)
+        case BarrierTaskContextMessageProtocol.ALL_GATHER_FUNCTION =>
+          context.asInstanceOf[BarrierTaskContext].allGather(message)
+      }
+      out.writeInt(messages.length)
+      messages.foreach(writeUTF(_, out))
+    } catch {
+      case e: SparkException =>
+        writeUTF(e.getMessage, out)
+    } finally {
+      out.close()
+    }
+  }
+
+  def writeUTF(str: String, dataOut: DataOutputStream): Unit = {
+    val bytes = str.getBytes(UTF_8)
+    dataOut.writeInt(bytes.length)
+    dataOut.write(bytes)
+  }
+}
+
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
index 391587e586..c92af92cdd 100644
--- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
@@ -20,8 +20,8 @@ package org.apache.spark.sql.udf
 
 import org.apache.sedona.spark.SedonaContext
 import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.functions.{col, expr}
-import org.apache.spark.sql.udf.ScalarUDF.{geometryToGeometryFunction}
+import org.apache.spark.sql.functions.{col, expr, lit}
+import org.apache.spark.sql.udf.ScalarUDF.geometryToGeometryFunction
 import org.locationtech.jts.io.WKTReader
 import org.scalatest.funsuite.AnyFunSuite
 import org.scalatest.matchers.should.Matchers
@@ -47,14 +47,29 @@ class StrategySuite extends AnyFunSuite with Matchers {
 //    spark.sql("select 1").show()
     val df = spark.read.format("geoparquet")
       .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings")
-      .selectExpr("ST_Centroid(geometry) AS geometry")
+      .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')"))
+
+    df.show()
 
       df
       .select(
+        col("id"),
+        col("version"),
+        col("bbox"),
 //        geometryToNonGeometryFunction(col("geometry")),
-        geometryToGeometryFunction(col("geometry")),
+        geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"),
 //        nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
-      ).show(10, false)
+      ).show(10)
+
+    println(df
+      .select(
+        col("id"),
+        col("version"),
+        col("bbox"),
+        //        geometryToNonGeometryFunction(col("geometry")),
+        geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"),
+        //        nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+      ).count())
 
 //    df.show()
     1 shouldBe 1

