This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 3874a5ab5f743a8010d06c773fff13a44b95b6bc
Author: pawelkocinski <[email protected]>
AuthorDate: Thu Nov 13 22:19:41 2025 +0100

    SEDONA-748 add working example
---
 .../sql/execution/python/SedonaArrowStrategy.scala | 171 +--------------------
 .../sql/execution/python/SedonaArrowUtils.scala    |  64 +-------
 .../execution/python/SedonaPythonArrowInput.scala  |   1 +
 3 files changed, 6 insertions(+), 230 deletions(-)

diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
index 3869ab24b8..ff3c027c5d 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala
@@ -20,32 +20,19 @@ package org.apache.spark.sql.execution.python
 
 import org.apache.sedona.sql.UDF.PythonEvalType
 import org.apache.spark.api.python.ChainedPythonFunctions
-import org.apache.spark.rdd.RDD
 import org.apache.spark.sql.Strategy
 import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.InternalRow.copyValue
 import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
-import org.apache.spark.sql.catalyst.expressions.{GenericInternalRow, UnsafeProjection}
-import org.apache.spark.sql.catalyst.expressions.UnsafeProjection.createObject
+import org.apache.spark.sql.catalyst.expressions.UnsafeProjection
 import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.vectorized.{ColumnarBatchRow, ColumnarRow}
-//import org.apache.spark.sql.catalyst.expressions.BindReferences.bindReferences
-//import org.apache.spark.sql.catalyst.expressions.codegen.GenerateUnsafeProjection
-import org.apache.spark.sql.catalyst.expressions.{Attribute, BoundReference, CodeGeneratorWithInterpretedFallback, Expression, InterpretedUnsafeProjection, JoinedRow, MutableProjection, Projection, PythonUDF, UnsafeRow}
+import org.apache.spark.sql.catalyst.expressions.{Attribute, Expression, PythonUDF}
 import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
-import org.apache.spark.sql.catalyst.util.GenericArrayData
 import org.apache.spark.sql.execution.SparkPlan
 import org.apache.spark.sql.internal.SQLConf
-import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.udf.SedonaArrowEvalPython
-import org.apache.spark.util.Utils
-import org.apache.spark.{ContextAwareIterator, JobArtifactSet, SparkEnv, TaskContext}
-import org.locationtech.jts.io.WKTReader
-import org.apache.spark.sql.catalyst.expressions.codegen.BufferHolder
-import java.io.File
+import org.apache.spark.{JobArtifactSet, TaskContext}
 import scala.collection.JavaConverters.asScalaIteratorConverter
-import scala.collection.mutable.ArrayBuffer
 
 // We use a custom Strategy to avoid Apache Spark's assert on types; we
 // can consider extending this to support other engines working with
@@ -58,16 +45,6 @@ class SedonaArrowStrategy extends Strategy {
   }
 }
 
-/**
- * The factory object for `UnsafeProjection`.
- */
-object SedonaUnsafeProjection {
-
-  def create(exprs: Seq[Expression], inputSchema: Seq[Attribute]): UnsafeProjection = {
-    GenerateUnsafeProjection.generate(bindReferences(exprs, inputSchema), SQLConf.get.subexpressionEliminationEnabled)
-//    createObject(bindReferences(exprs, inputSchema))
-  }
-}
 // It's a modification of Apache Spark's ArrowEvalPythonExec; we remove the check on the types to allow geometry types
 // here. It's an initial version to allow the vectorized UDF for Sedona geometry types. We can consider extending this
 // to support other engines working with arrow data
@@ -112,144 +89,4 @@ case class SedonaArrowEvalPythonExec(
 
   override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
     copy(child = newChild)
-
-  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
-    udf.children match {
-      case Seq(u: PythonUDF) =>
-        val (chained, children) = collectFunctions(u)
-        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
-      case children =>
-        // There should not be any other UDFs, or the children can't be evaluated directly.
-        assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF])))
-        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
-    }
-  }
-
-  override def doExecute(): RDD[InternalRow] = {
-
-      val customProjection = new Projection with Serializable{
-        def apply(row: InternalRow): InternalRow = {
-          row match {
-            case joinedRow: JoinedRow =>
-              val arrowField = joinedRow.getRight.asInstanceOf[ColumnarBatchRow]
-              val left = joinedRow.getLeft
-
-
-//              resultAttrs.zipWithIndex.map {
-//                case (x, y) =>
-//                  if (x.dataType.isInstanceOf[GeometryUDT]) {
-//                    val wkbReader = new org.locationtech.jts.io.WKBReader()
-//                    wkbReader.read(left.getBinary(y))
-//
-//                    println("ssss")
-//                  }
-//                  GeometryUDT
-//                  left.getByte(y)
-//
-//                  left.setByte(y, 1.toByte)
-//
-//                  println(left.getByte(y))
-//              }
-//
-//              println("ssss")
-//              arrowField.
-              row
-              // We need to convert JoinedRow to UnsafeRow
-//              val leftUnsafe = left.asInstanceOf[UnsafeRow]
-//              val rightUnsafe = right.asInstanceOf[UnsafeRow]
-//              val joinedUnsafe = new UnsafeRow(leftUnsafe.numFields + rightUnsafe.numFields)
-//              joinedUnsafe.pointTo(
-//                leftUnsafe.getBaseObject, leftUnsafe.getBaseOffset,
-//                leftUnsafe.getSizeInBytes + rightUnsafe.getSizeInBytes)
-//              joinedUnsafe.setLeft(rightUnsafe)
-//              joinedUnsafe.setRight(leftUnsafe)
-//              joinedUnsafe
-//              val wktReader = new WKTReader()
-              val resultProj = SedonaUnsafeProjection.create(output, output)
-//              val WKBWriter = new org.locationtech.jts.io.WKBWriter()
-              resultProj(new JoinedRow(left, arrowField))
-            case _ =>
-              println(row.getClass)
-              throw new UnsupportedOperationException("Unsupported row type")
-          }
-        }
-      }
-    val inputRDD = child.execute().map(_.copy())
-
-    inputRDD.mapPartitions { iter =>
-      val context = TaskContext.get()
-      val contextAwareIterator = new ContextAwareIterator(context, iter)
-
-      // The queue used to buffer input rows so we can drain it to
-      // combine input with output from Python.
-      val queue = HybridRowQueue(context.taskMemoryManager(),
-        new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length)
-      context.addTaskCompletionListener[Unit] { ctx =>
-        queue.close()
-      }
-
-      val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip
-
-      // flatten all the arguments
-      val allInputs = new ArrayBuffer[Expression]
-      val dataTypes = new ArrayBuffer[DataType]
-      val argOffsets = inputs.map { input =>
-        input.map { e =>
-          if (allInputs.exists(_.semanticEquals(e))) {
-            allInputs.indexWhere(_.semanticEquals(e))
-          } else {
-            allInputs += e
-            dataTypes += e.dataType
-            allInputs.length - 1
-          }
-        }.toArray
-      }.toArray
-      val projection = MutableProjection.create(allInputs.toSeq, child.output)
-      projection.initialize(context.partitionId())
-      val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) =>
-        StructField(s"_$i", dt)
-      }.toArray)
-
-      // Add rows to queue to join later with the result.
-      val projectedRowIter = contextAwareIterator.map { inputRow =>
-        queue.add(inputRow.asInstanceOf[UnsafeRow])
-        projection(inputRow)
-      }
-
-      val outputRowIterator = evaluate(
-        pyFuncs, argOffsets, projectedRowIter, schema, context)
-
-      val joined = new JoinedRow
-
-      outputRowIterator.map { outputRow =>
-        val joinedRow = joined(queue.remove(), outputRow)
-
-        val projected = customProjection(joinedRow)
-
-        val numFields = projected.numFields
-        val startField = numFields - resultAttrs.length
-        println(resultAttrs.length)
-
-        val row = new GenericInternalRow(numFields)
-
-        resultAttrs.zipWithIndex.map {
-          case (attr, index) =>
-            if (attr.dataType.isInstanceOf[GeometryUDT]) {
-              // Convert the geometry type to WKB
-              val wkbReader = new org.locationtech.jts.io.WKBReader()
-              val wkbWriter = new org.locationtech.jts.io.WKBWriter()
-              val geom = wkbReader.read(projected.getBinary(startField + index))
-
-              row.update(startField + index, wkbWriter.write(geom))
-
-              println("ssss")
-            }
-        }
-
-        println("ssss")
-//        3.2838116E-8
-        row
-      }
-    }
-  }
 }
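For reference, a planner Strategy like SedonaArrowStrategy above only takes effect once it is registered with Spark's query planner, and that wiring is not part of this commit. The sketch below is a minimal, assumed example that uses Spark's experimental extraStrategies hook (Sedona may instead register the strategy through SparkSessionExtensions); the object name RegisterSedonaArrowStrategy is illustrative only.

// Sketch only: assumed wiring, not part of this commit.
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.execution.python.SedonaArrowStrategy

object RegisterSedonaArrowStrategy {
  def apply(spark: SparkSession): Unit = {
    // Prepend the strategy so SedonaArrowEvalPython logical nodes are planned
    // by SedonaArrowEvalPythonExec instead of tripping Spark's type assert.
    spark.experimental.extraStrategies =
      new SedonaArrowStrategy() +: spark.experimental.extraStrategies
  }
}

Strategies in extraStrategies are consulted before Spark's built-in ones, so the Arrow UDF node is picked up by the custom exec first.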
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
index 58166d173d..ec4f7c00d0 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowUtils.scala
@@ -21,71 +21,16 @@ import java.util.concurrent.atomic.AtomicInteger
 import scala.collection.JavaConverters._
 import org.apache.arrow.memory.RootAllocator
 import org.apache.arrow.vector.complex.MapVector
-import org.apache.arrow.vector.types.{DateUnit, FloatingPointPrecision, IntervalUnit, TimeUnit}
 import org.apache.arrow.vector.types.pojo.{ArrowType, Field, FieldType, Schema}
 import org.apache.spark.sql.errors.ExecutionErrors
 import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types._
+import org.apache.spark.sql.util.ArrowUtils.{fromArrowType, toArrowType}
 
 private[sql] object SedonaArrowUtils {
 
   val rootAllocator = new RootAllocator(Long.MaxValue)
 
-  // todo: support more types.
-
-  /** Maps data type from Spark to Arrow. NOTE: timeZoneId required for TimestampTypes */
-  def toArrowType(
-                   dt: DataType, timeZoneId: String, largeVarTypes: Boolean = false): ArrowType = dt match {
-    case BooleanType => ArrowType.Bool.INSTANCE
-    case ByteType => new ArrowType.Int(8, true)
-    case ShortType => new ArrowType.Int(8 * 2, true)
-    case IntegerType => new ArrowType.Int(8 * 4, true)
-    case LongType => new ArrowType.Int(8 * 8, true)
-    case FloatType => new ArrowType.FloatingPoint(FloatingPointPrecision.SINGLE)
-    case DoubleType => new ArrowType.FloatingPoint(FloatingPointPrecision.DOUBLE)
-    case StringType if !largeVarTypes => ArrowType.Utf8.INSTANCE
-    case BinaryType if !largeVarTypes => ArrowType.Binary.INSTANCE
-    case StringType if largeVarTypes => ArrowType.LargeUtf8.INSTANCE
-    case BinaryType if largeVarTypes => ArrowType.LargeBinary.INSTANCE
-    case DecimalType.Fixed(precision, scale) => new ArrowType.Decimal(precision, scale)
-    case DateType => new ArrowType.Date(DateUnit.DAY)
-    case TimestampType if timeZoneId == null =>
-      throw new IllegalStateException("Missing timezoneId where it is mandatory.")
-    case TimestampType => new ArrowType.Timestamp(TimeUnit.MICROSECOND, timeZoneId)
-    case TimestampNTZType =>
-      new ArrowType.Timestamp(TimeUnit.MICROSECOND, null)
-    case NullType => ArrowType.Null.INSTANCE
-    case _: YearMonthIntervalType => new ArrowType.Interval(IntervalUnit.YEAR_MONTH)
-    case _: DayTimeIntervalType => new ArrowType.Duration(TimeUnit.MICROSECOND)
-    case _ =>
-      throw ExecutionErrors.unsupportedDataTypeError(dt)
-  }
-
-  def fromArrowType(dt: ArrowType): DataType = dt match {
-    case ArrowType.Bool.INSTANCE => BooleanType
-    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 => ByteType
-    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 2 => ShortType
-    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 4 => IntegerType
-    case int: ArrowType.Int if int.getIsSigned && int.getBitWidth == 8 * 8 => LongType
-    case float: ArrowType.FloatingPoint
-      if float.getPrecision() == FloatingPointPrecision.SINGLE => FloatType
-    case float: ArrowType.FloatingPoint
-      if float.getPrecision() == FloatingPointPrecision.DOUBLE => DoubleType
-    case ArrowType.Utf8.INSTANCE => StringType
-    case ArrowType.Binary.INSTANCE => BinaryType
-    case ArrowType.LargeUtf8.INSTANCE => StringType
-    case ArrowType.LargeBinary.INSTANCE => BinaryType
-    case d: ArrowType.Decimal => DecimalType(d.getPrecision, d.getScale)
-    case date: ArrowType.Date if date.getUnit == DateUnit.DAY => DateType
-    case ts: ArrowType.Timestamp
-      if ts.getUnit == TimeUnit.MICROSECOND && ts.getTimezone == null => TimestampNTZType
-    case ts: ArrowType.Timestamp if ts.getUnit == TimeUnit.MICROSECOND => TimestampType
-    case ArrowType.Null.INSTANCE => NullType
-    case yi: ArrowType.Interval if yi.getUnit == IntervalUnit.YEAR_MONTH => YearMonthIntervalType()
-    case di: ArrowType.Duration if di.getUnit == TimeUnit.MICROSECOND => DayTimeIntervalType()
-    case _ => throw ExecutionErrors.unsupportedArrowTypeError(dt)
-  }
-
   /** Maps field from Spark to Arrow. NOTE: timeZoneId required for TimestampType */
   def toArrowField(
                     name: String,
@@ -172,13 +117,6 @@ private[sql] object SedonaArrowUtils {
     }.asJava)
   }
 
-  def fromArrowSchema(schema: Schema): StructType = {
-    StructType(schema.getFields.asScala.map { field =>
-      val dt = fromArrowField(field)
-      StructField(field.getName, dt, field.isNullable)
-    }.toArray)
-  }
-
   private def deduplicateFieldNames(
                                      dt: DataType, errorOnDuplicatedFieldNames: Boolean): DataType = dt match {
     case geometryType: GeometryUDT => geometryType
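The removals above drop Sedona's hand-copied Spark-to-Arrow type tables in favour of the stock converters now imported from org.apache.spark.sql.util.ArrowUtils. The committed toArrowField body is not shown in this hunk, so the snippet below is only a hedged sketch of that delegation pattern: GeometryAwareArrowTypes is a hypothetical helper, and the assumption that geometries cross the Arrow boundary as WKB bytes comes from the WKBReader/WKBWriter usage removed in the first file.

// Hypothetical sketch (not the committed code): keep a local case for
// GeometryUDT and defer every built-in type to Spark's own ArrowUtils.
// Placed in a spark.sql subpackage because ArrowUtils is private[sql].
package org.apache.spark.sql.execution.python

import org.apache.arrow.vector.types.pojo.ArrowType
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types.DataType
import org.apache.spark.sql.util.ArrowUtils

object GeometryAwareArrowTypes {
  // Assumed encoding: geometry values are carried as WKB binary buffers.
  def toArrowType(dt: DataType, timeZoneId: String): ArrowType = dt match {
    case _: GeometryUDT => ArrowType.Binary.INSTANCE
    case other => ArrowUtils.toArrowType(other, timeZoneId)
  }
}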
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
index 6791015ae9..8a5e241c51 100644
--- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala
@@ -26,6 +26,7 @@ import org.apache.spark.sql.execution.arrow.ArrowWriter
 import org.apache.spark.sql.execution.metric.SQLMetric
 import org.apache.spark.sql.types.StructType
 import org.apache.spark.sql.util.ArrowUtils
+import org.apache.spark.sql.util.ArrowUtils.toArrowSchema
 import org.apache.spark.util.Utils
 import org.apache.spark.{SparkEnv, TaskContext}
 
