This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit b4a344398a0ea90e27e575c4a69dfdb55f02fd8b
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Aug 3 00:39:23 2025 +0200

    SEDONA-738 Fix unit tests.
---
 .../common/geometrySerde/GeometrySerializer.java   |  28 +++
 .../org/apache/sedona/spark/SedonaContext.scala    |   2 +-
 .../spark/api/python/SedonaPythonRunner.scala      | 130 -----------
 .../sql/execution/python/SedonaArrowStrategy.scala | 255 +++++++++++++++++++++
 .../execution/python/SedonaPythonArrowOutput.scala |   3 +-
 .../apache/spark/sql/udf/SedonaArrowStrategy.scala |  89 -------
 .../org/apache/spark/sql/udf/StrategySuite.scala   |   3 +-
 7 files changed, 288 insertions(+), 222 deletions(-)

diff --git 
a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
 
b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
index 325098c6ac..ba135aa6a1 100644
--- 
a/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
+++ 
b/common/src/main/java/org/apache/sedona/common/geometrySerde/GeometrySerializer.java
@@ -63,6 +63,29 @@ public class GeometrySerializer {
 //    return buffer.toByteArray();
   }
 
+  /**
+   * Encodes a geometry with the {@link GeometryBuffer}-based binary layout, i.e. the
+   * type-specific {@code serializeXxx} writers, rather than the WKB path.
+   *
+   * <p>Dispatch is by runtime type. In JTS, MultiPoint, MultiLineString and MultiPolygon are
+   * subclasses of GeometryCollection, so the generic GeometryCollection branch must stay last.
+   *
+   * @param geometry the geometry to encode; a null argument is not handled and results in a
+   *     NullPointerException
+   * @return the encoded bytes
+   * @throws UnsupportedOperationException if the concrete geometry type has no legacy writer
+   */
+  public static byte[] serializeLegacy(Geometry geometry) {
+    if (geometry instanceof Point) {
+      return serializePoint((Point) geometry).toByteArray();
+    }
+    if (geometry instanceof MultiPoint) {
+      return serializeMultiPoint((MultiPoint) geometry).toByteArray();
+    }
+    if (geometry instanceof LineString) {
+      return serializeLineString((LineString) geometry).toByteArray();
+    }
+    if (geometry instanceof MultiLineString) {
+      return serializeMultiLineString((MultiLineString) geometry).toByteArray();
+    }
+    if (geometry instanceof Polygon) {
+      return serializePolygon((Polygon) geometry).toByteArray();
+    }
+    if (geometry instanceof MultiPolygon) {
+      return serializeMultiPolygon((MultiPolygon) geometry).toByteArray();
+    }
+    if (geometry instanceof GeometryCollection) {
+      return serializeGeometryCollection((GeometryCollection) geometry).toByteArray();
+    }
+    throw new UnsupportedOperationException(
+        "Geometry type is not supported: " + geometry.getClass().getSimpleName());
+  }
+
   public static Geometry deserialize(byte[] bytes) {
     WKBReader reader = new WKBReader();
     try {
@@ -74,6 +97,11 @@ public class GeometrySerializer {
 //    return deserialize(buffer);
   }
 
+  /**
+   * Decodes bytes by wrapping them in a {@link GeometryBuffer} and delegating to
+   * {@link #deserialize(GeometryBuffer)}; presumably the counterpart of
+   * {@code serializeLegacy} (verify the formats stay in sync).
+   *
+   * @param bytes the encoded geometry
+   * @return the decoded geometry
+   */
+  public static Geometry deserializeLegacy(byte[] bytes) {
+    return deserialize(GeometryBufferFactory.wrap(bytes));
+  }
+
   public static Geometry deserialize(GeometryBuffer buffer) {
     return deserialize(buffer, null);
   }
diff --git 
a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala 
b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index b0e46cf6e9..c9e8497f7e 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -72,7 +72,7 @@ object SedonaContext {
 
     val sedonaArrowStrategy = Try(
       Class
-        .forName("org.apache.spark.sql.udf.SedonaArrowStrategy")
+        .forName("org.apache.spark.sql.execution.python.SedonaArrowStrategy")
         .getDeclaredConstructor()
         .newInstance()
         .asInstanceOf[SparkStrategy])
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
index fb01e62b5e..6656d85f5c 100644
--- 
a/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/api/python/SedonaPythonRunner.scala
@@ -19,8 +19,6 @@ package org.apache.spark.api.python
 
 import org.apache.spark._
 import org.apache.spark.SedonaSparkEnv
-import org.apache.spark.api.python.PythonRDD.writeUTF
-import org.apache.spark.input.PortableDataStream
 import org.apache.spark.internal.Logging
 import org.apache.spark.internal.config.Python._
 import org.apache.spark.internal.config.{BUFFER_SIZE, EXECUTOR_CORES}
@@ -33,7 +31,6 @@ import java.net._
 import java.nio.charset.StandardCharsets
 import java.nio.charset.StandardCharsets.UTF_8
 import java.nio.file.{Path, Files => JavaFiles}
-import java.util.concurrent.ConcurrentHashMap
 import java.util.concurrent.atomic.AtomicBoolean
 import scala.collection.JavaConverters._
 import scala.util.control.NonFatal
@@ -712,130 +709,3 @@ private[spark] abstract class SedonaBasePythonRunner[IN, 
OUT](
     }
   }
 }
-
-private[spark] object PythonRunner {
-
-  // already running worker monitor threads for worker and task attempts ID 
pairs
-  val runningMonitorThreads = ConcurrentHashMap.newKeySet[(Socket, Long)]()
-
-  private var printPythonInfo: AtomicBoolean = new AtomicBoolean(true)
-
-  def apply(func: PythonFunction, jobArtifactUUID: Option[String]): 
PythonRunner = {
-    if (printPythonInfo.compareAndSet(true, false)) {
-      PythonUtils.logPythonInfo(func.pythonExec)
-    }
-    new PythonRunner(Seq(ChainedPythonFunctions(Seq(func))), jobArtifactUUID)
-  }
-}
-
-/**
- * A helper class to run Python mapPartition in Spark.
- */
-private[spark] class PythonRunner(
-                                   funcs: Seq[ChainedPythonFunctions], 
jobArtifactUUID: Option[String])
-  extends BasePythonRunner[Array[Byte], Array[Byte]](
-    funcs, PythonEvalType.NON_UDF, Array(Array(0)), jobArtifactUUID) {
-
-  protected override def newWriterThread(
-                                          env: SparkEnv,
-                                          worker: Socket,
-                                          inputIterator: Iterator[Array[Byte]],
-                                          partitionIndex: Int,
-                                          context: TaskContext): WriterThread 
= {
-    new WriterThread(env, worker, inputIterator, partitionIndex, context) {
-
-      protected override def writeCommand(dataOut: DataOutputStream): Unit = {
-        val command = funcs.head.funcs.head.command
-        dataOut.writeInt(command.length)
-        dataOut.write(command.toArray)
-      }
-
-      protected override def writeIteratorToStream(dataOut: DataOutputStream): 
Unit = {
-        GeoArrowWriter.writeIteratorToStream(inputIterator, dataOut)
-        dataOut.writeInt(SpecialLengths.END_OF_DATA_SECTION)
-      }
-    }
-  }
-
-  protected override def newReaderIterator(
-                                            stream: DataInputStream,
-                                            writerThread: WriterThread,
-                                            startTime: Long,
-                                            env: SparkEnv,
-                                            worker: Socket,
-                                            pid: Option[Int],
-                                            releasedOrClosed: AtomicBoolean,
-                                            context: TaskContext): 
Iterator[Array[Byte]] = {
-    new ReaderIterator(
-      stream, writerThread, startTime, env, worker, pid, releasedOrClosed, 
context) {
-
-      protected override def read(): Array[Byte] = {
-        if (writerThread.exception.isDefined) {
-          throw writerThread.exception.get
-        }
-        try {
-          stream.readInt() match {
-            case length if length > 0 =>
-              val obj = new Array[Byte](length)
-              stream.readFully(obj)
-              obj
-            case 0 => Array.emptyByteArray
-            case SpecialLengths.TIMING_DATA =>
-              handleTimingData()
-              read()
-            case SpecialLengths.PYTHON_EXCEPTION_THROWN =>
-              throw handlePythonException()
-            case SpecialLengths.END_OF_DATA_SECTION =>
-              handleEndOfDataSection()
-              null
-          }
-        } catch handleException
-      }
-    }
-  }
-}
-
-private[spark] object SpecialLengths {
-  val END_OF_DATA_SECTION = -1
-  val PYTHON_EXCEPTION_THROWN = -2
-  val TIMING_DATA = -3
-  val END_OF_STREAM = -4
-  val NULL = -5
-  val START_ARROW_STREAM = -6
-  val END_OF_MICRO_BATCH = -7
-}
-
-private[spark] object BarrierTaskContextMessageProtocol {
-  val BARRIER_FUNCTION = 1
-  val ALL_GATHER_FUNCTION = 2
-  val BARRIER_RESULT_SUCCESS = "success"
-  val ERROR_UNRECOGNIZED_FUNCTION = "Not recognized function call from python 
side."
-}
-
-object GeoArrowWriter extends Logging {
-  def writeIteratorToStream[T](iter: Iterator[T], dataOut: DataOutputStream): 
Unit = {
-
-    def write(obj: Any): Unit = obj match {
-      case null =>
-        dataOut.writeInt(SpecialLengths.NULL)
-      case arr: Array[Byte] =>
-        logError("some random array")
-        dataOut.writeInt(arr.length)
-        dataOut.write(arr)
-      case str: String =>
-        logError("some random string")
-        writeUTF(str, dataOut)
-      case stream: PortableDataStream =>
-        logError("some random stream")
-        write(stream.toArray())
-      case (key, value) =>
-        logError("some random key value")
-        write(key)
-        write(value)
-      case other =>
-        throw new SparkException("Unexpected element type " + other.getClass)
-    }
-
-    iter.foreach(write)
-  }
-}
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
new file mode 100644
index 0000000000..3869ab24b8
--- /dev/null
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
@@ -0,0 +1,255 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.execution.python
+
+import org.apache.sedona.sql.UDF.PythonEvalType
+import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.InternalRow.copyValue
+import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, 
UnsafeProjection}
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection.createObject
+import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.sql.vectorized.{ColumnarBatchRow, ColumnarRow}
+//import 
org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
+//import 
org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
+import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, 
CodeGeneratorWithInterpretedFallback, Expression, InterpretedUnsafeProjection, 
JoinedRow, MutableProjection, Projection, PythonUDF, UnsafeRow}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.util.GenericArrayData
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.internal.SQLConf
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.udf.SedonaArrowEvalPython
+import org.apache.spark.util.Utils
+import org.apache.spark.{ContextAwareIterator, JobArtifactSet, SparkEnv, 
TaskContext}
+import org.locationtech.jts.io.WKTReader
+import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder
+import java.io.File
+import scala.collection.JavaConverters.asScalaIteratorConverter
+import scala.collection.mutable.ArrayBuffer
+
+/**
+ * Planner strategy that maps the logical [[SedonaArrowEvalPython]] node onto the physical
+ * [[SedonaArrowEvalPythonExec]]. A dedicated strategy is used to avoid Apache Spark's assertions
+ * on UDF types; it could be extended to support other engines working with Arrow data.
+ * Every other plan node yields no physical plan here and is left to the remaining strategies.
+ */
+class SedonaArrowStrategy extends Strategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case SedonaArrowEvalPython(pythonUdfs, udfOutput, childPlan, udfEvalType) =>
+      Seq(SedonaArrowEvalPythonExec(pythonUdfs, udfOutput, planLater(childPlan), udfEvalType))
+    case _ =>
+      Seq.empty
+  }
+}
+
+/**
+ * Factory for [[UnsafeProjection]] that always uses the code-generated implementation
+ * (no interpreted fallback).
+ */
+object SedonaUnsafeProjection {
+
+  /**
+   * Binds `exprs` against `inputSchema` and generates an unsafe projection for them.
+   *
+   * @param exprs       expressions to project
+   * @param inputSchema attributes of the rows the projection will be applied to
+   * @return the generated projection; subexpression elimination follows the session's
+   *         `SQLConf.subexpressionEliminationEnabled` setting
+   */
+  def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
+    val boundExprs = bindReferences(exprs, inputSchema)
+    val eliminateSubexprs = SQLConf.get.subexpressionEliminationEnabled
+    GenerateUnsafeProjection.generate(boundExprs, eliminateSubexprs)
+  }
+}
+/**
+ * Physical operator evaluating Sedona vectorized (Arrow) Python UDFs.
+ *
+ * It is a modification of Apache Spark's `ArrowEvalPythonExec`: the check on the UDF types is
+ * removed so that Sedona geometry types can be exchanged with the Python worker. This is an
+ * initial version; it could be extended to support other engines working with Arrow data.
+ *
+ * @param udfs        the Python UDFs to evaluate
+ * @param resultAttrs the attributes produced by the UDFs
+ * @param child       the plan providing the UDF input rows
+ * @param evalType    Sedona eval type; `PythonEvalType.SEDONA_UDF_TYPE_CONSTANT` is subtracted
+ *                    from it before it is passed to the Python runner
+ */
+case class SedonaArrowEvalPythonExec(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: SparkPlan,
+    evalType: Int)
+    extends EvalPythonExec
+    with PythonSQLMetrics {
+
+  private val batchSize = conf.arrowMaxRecordsPerBatch
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  /**
+   * Sends the projected UDF arguments to the Python worker in Arrow batches and returns the
+   * worker's result rows, flattened out of the returned columnar batches.
+   */
+  protected override def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow] = {
+
+    // A non-positive batch size means "send the whole partition as a single batch".
+    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
+
+    val columnarBatchIter = new SedonaArrowPythonRunner(
+      funcs,
+      evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT,
+      argOffsets,
+      schema,
+      sessionLocalTimeZone,
+      largeVarTypes,
+      pythonRunnerConf,
+      pythonMetrics,
+      jobArtifactUUID).compute(batchIter, context.partitionId(), context)
+
+    columnarBatchIter.flatMap { batch =>
+      batch.rowIterator.asScala
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+
+  /**
+   * Collapses a chain of directly nested Python UDFs (udf(udf(...))) into one
+   * [[ChainedPythonFunctions]], innermost function first, together with the non-UDF argument
+   * expressions of the innermost UDF.
+   */
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be evaluated directly.
+        assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF])))
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+    }
+  }
+
+  override def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute().map(_.copy())
+
+    inputRDD.mapPartitions { iter =>
+      val context = TaskContext.get()
+      val contextAwareIterator = new ContextAwareIterator(context, iter)
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+      val queue = HybridRowQueue(
+        context.taskMemoryManager(),
+        new File(Utils.getLocalDir(SparkEnv.get.conf)),
+        child.output.length)
+      context.addTaskCompletionListener[Unit] { _ =>
+        queue.close()
+      }
+
+      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
+
+      // Flatten all the arguments, de-duplicating semantically equal expressions so each
+      // distinct input column is sent to Python only once.
+      val allInputs = new ArrayBuffer[Expression]
+      val dataTypes = new ArrayBuffer[DataType]
+      val argOffsets = inputs.map { input =>
+        input.map { e =>
+          val existing = allInputs.indexWhere(_.semanticEquals(e))
+          if (existing >= 0) {
+            existing
+          } else {
+            allInputs += e
+            dataTypes += e.dataType
+            allInputs.length - 1
+          }
+        }.toArray
+      }.toArray
+      val projection = MutableProjection.create(allInputs.toSeq, child.output)
+      projection.initialize(context.partitionId())
+      val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
+        StructField(s"_$i", dt)
+      }.toArray)
+
+      // Add rows to queue to join later with the result.
+      val projectedRowIter = contextAwareIterator.map { inputRow =>
+        queue.add(inputRow.asInstanceOf[UnsafeRow])
+        projection(inputRow)
+      }
+
+      val outputRowIterator = evaluate(pyFuncs, argOffsets, projectedRowIter, schema, context)
+
+      // Created once per partition, not once per row. It copies the (input row, Python result
+      // row) pair into a single UnsafeRow laid out as `output` (EvalPythonExec defines this as
+      // child.output ++ resultAttrs), so every input column and every UDF result is kept.
+      // As in Spark's ArrowEvalPythonExec, the returned row's buffer is reused between calls.
+      //
+      // Geometry result columns are passed through untouched: they already hold WKB bytes,
+      // which GeometrySerializer.deserialize reads in this change. Re-encoding them with a
+      // default JTS WKBWriter (2D, no SRID) would silently drop Z/M ordinates and SRIDs.
+      val joined = new JoinedRow
+      val resultProj = SedonaUnsafeProjection.create(output, output)
+
+      outputRowIterator.map { outputRow =>
+        resultProj(joined(queue.remove(), outputRow))
+      }
+    }
+  }
+}
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
index 91e840da58..f2c8543537 100644
--- 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
+++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala
@@ -93,7 +93,8 @@ private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] 
{ self: SedonaBaseP
               val bytesReadEnd = reader.bytesRead()
               pythonMetrics("pythonNumRowsReceived") += rowCount
               pythonMetrics("pythonDataReceived") += bytesReadEnd - 
bytesReadStart
-              deserializeColumnarBatch(batch, schema)
+              val result = deserializeColumnarBatch(batch, schema)
+              result
             } else {
               reader.close(false)
               allocator.close()
diff --git 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
deleted file mode 100644
index 5883fd905d..0000000000
--- 
a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
+++ /dev/null
@@ -1,89 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-package org.apache.spark.sql.udf
-
-import org.apache.sedona.sql.UDF.PythonEvalType
-import org.apache.spark.api.python.ChainedPythonFunctions
-import org.apache.spark.{JobArtifactSet, TaskContext}
-import org.apache.spark.sql.Strategy
-import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
-import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.execution.SparkPlan
-import org.apache.spark.sql.execution.python.{ArrowPythonRunner, 
BatchIterator, EvalPythonExec, PythonSQLMetrics, SedonaArrowPythonRunner}
-import org.apache.spark.sql.types.StructType
-
-import scala.collection.JavaConverters.asScalaIteratorConverter
-
-// We use custom Strategy to avoid Apache Spark assert on types, we
-// can consider extending this to support other engines working with
-// arrow data
-class SedonaArrowStrategy extends Strategy {
-  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
-    case SedonaArrowEvalPython(udfs, output, child, evalType) =>
-      SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: 
Nil
-    case _ => Nil
-  }
-}
-
-// It's modification og Apache Spark's ArrowEvalPythonExec, we remove the 
check on the types to allow geometry types
-// here, it's initial version to allow the vectorized udf for Sedona geometry 
types. We can consider extending this
-// to support other engines working with arrow data
-case class SedonaArrowEvalPythonExec(
-    udfs: Seq[PythonUDF],
-    resultAttrs: Seq[Attribute],
-    child: SparkPlan,
-    evalType: Int)
-    extends EvalPythonExec
-    with PythonSQLMetrics {
-
-  private val batchSize = conf.arrowMaxRecordsPerBatch
-  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
-  private val largeVarTypes = conf.arrowUseLargeVarTypes
-  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
-  private[this] val jobArtifactUUID = 
JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
-
-  protected override def evaluate(
-      funcs: Seq[ChainedPythonFunctions],
-      argOffsets: Array[Array[Int]],
-      iter: Iterator[InternalRow],
-      schema: StructType,
-      context: TaskContext): Iterator[InternalRow] = {
-
-    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else 
Iterator(iter)
-
-    val columnarBatchIter = new SedonaArrowPythonRunner(
-      funcs,
-      evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT,
-      argOffsets,
-      schema,
-      sessionLocalTimeZone,
-      largeVarTypes,
-      pythonRunnerConf,
-      pythonMetrics,
-      jobArtifactUUID).compute(batchIter, context.partitionId(), context)
-
-    columnarBatchIter.flatMap { batch =>
-      batch.rowIterator.asScala
-    }
-  }
-
-  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
-    copy(child = newChild)
-}
diff --git 
a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala 
b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
index 350e4a515b..77ab4abbb8 100644
--- 
a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
+++ 
b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
@@ -44,13 +44,14 @@ class StrategySuite extends AnyFunSuite with Matchers {
   import spark.implicits._
 
   test("sedona geospatial UDF") {
-    spark.sql("select 1").show()
+//    spark.sql("select 1").show()
     val df = spark.read.format("parquet")
       
.load("/Users/pawelkocinski/Desktop/projects/sedona-book/apache-sedona-book/book/chapter10/data/buildings/partitioned")
       .select(
         geometryToNonGeometryFunction(col("geometry")),
         geometryToGeometryFunction(col("geometry")),
         nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")),
+        col("geohash")
       )
 
     df.show()

Reply via email to