This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch arrow-worker
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 6ce7b4c7785e805be105087a622a0f7e3f457036
Author: pawelkocinski <[email protected]>
AuthorDate: Mon Dec 29 22:12:52 2025 +0100

    add code so far
---
 .../sql/execution/python/EvalPythonExec.scala      | 102 +++++++
 .../execution/python/SedonaArrowPythonRunner.scala |  12 +-
 .../sql/execution/python/SedonaArrowStrategy.scala |  59 +++-
 .../execution/python/SedonaBasePythonRunner.scala  |  48 ++--
 .../execution/python/SedonaDBWorkerFactory.scala   |   6 +-
 .../execution/python/SedonaPythonArrowInput.scala  | 118 +++++---
 .../execution/python/SedonaPythonArrowOutput.scala | 251 ++++++++++++++++
 .../spark/sql/execution/python/WorkerContext.scala |   9 +-
 .../org/apache/sedona/sql/TestBaseScala.scala      |   2 +-
 .../org/apache/spark/sql/udf/StrategySuite.scala   | 315 ++++++++++-----------
 .../apache/spark/sql/udf/TestScalarPandasUDF.scala | 170 +++++------
 11 files changed, 764 insertions(+), 328 deletions(-)

diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
new file mode 100644
index 0000000000..11cc8c121f
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/EvalPythonExec.scala
@@ -0,0 +1,102 @@
+package org.apache.spark.sql.execution.python
+
+import org.apache.sedona.common.geometrySerde.GeometrySerde
+import org.apache.sedona.sql.utils.GeometrySerializer
+import org.apache.spark.{ContextAwareIterator, SparkEnv, TaskContext}
+import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.rdd.RDD
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, AttributeSet, Expression, JoinedRow, MutableProjection, PythonUDF, UnsafeProjection, UnsafeRow}
+import org.apache.spark.sql.execution.UnaryExecNode
+import org.apache.spark.sql.types.{DataType, StructField, StructType}
+import org.apache.spark.util.Utils
+
+import java.io.File
+import scala.collection.mutable.ArrayBuffer
+
+trait EvalPythonExec extends UnaryExecNode {
+  def udfs: Seq[PythonUDF]
+
+  def resultAttrs: Seq[Attribute]
+
+  override def output: Seq[Attribute] = child.output ++ resultAttrs
+
+  override def producedAttributes: AttributeSet = AttributeSet(resultAttrs)
+
+  private def collectFunctions(udf: PythonUDF): (ChainedPythonFunctions, Seq[Expression]) = {
+    udf.children match {
+      case Seq(u: PythonUDF) =>
+        val (chained, children) = collectFunctions(u)
+        (ChainedPythonFunctions(chained.funcs ++ Seq(udf.func)), children)
+      case children =>
+        // There should not be any other UDFs, or the children can't be evaluated directly.
+        assert(children.forall(!_.exists(_.isInstanceOf[PythonUDF])))
+        (ChainedPythonFunctions(Seq(udf.func)), udf.children)
+    }
+  }
+
+  protected def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow]
+
+  protected override def doExecute(): RDD[InternalRow] = {
+    val inputRDD = child.execute().map(_.copy())
+
+    inputRDD.mapPartitions { iter =>
+      val context = TaskContext.get()
+      val contextAwareIterator = new ContextAwareIterator(context, iter)
+
+      // The queue used to buffer input rows so we can drain it to
+      // combine input with output from Python.
+ val queue = HybridRowQueue(context.taskMemoryManager(), + new File(Utils.getLocalDir(SparkEnv.get.conf)), child.output.length) +// context.addTaskCompletionListener[Unit] { ctx => +// queue.close() +// } + + val (pyFuncs, inputs) = udfs.map(collectFunctions).unzip + + // flatten all the arguments + val allInputs = new ArrayBuffer[Expression] + val dataTypes = new ArrayBuffer[DataType] + val argOffsets = inputs.map { input => + input.map { e => + if (allInputs.exists(_.semanticEquals(e))) { + allInputs.indexWhere(_.semanticEquals(e)) + } else { + allInputs += e + dataTypes += e.dataType + allInputs.length - 1 + } + }.toArray + }.toArray + val projection = MutableProjection.create(allInputs.toSeq, child.output) + projection.initialize(context.partitionId()) + val schema = StructType(dataTypes.zipWithIndex.map { case (dt, i) => + StructField(s"_$i", dt) + }.toArray) + + // Add rows to queue to join later with the result. + val projectedRowIter = contextAwareIterator.map { inputRow => + queue.add(inputRow.asInstanceOf[UnsafeRow]) + val proj = projection(inputRow) + proj + } + + val materializedResult = projectedRowIter.toSeq + + val outputRowIterator = evaluate( + pyFuncs, argOffsets, materializedResult.toIterator, schema, context) + + val joined = new JoinedRow + val resultProj = UnsafeProjection.create(output, output) + + outputRowIterator.map { outputRow => + resultProj(joined(queue.remove(), outputRow)) + } + } + } +} diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala index 598cd830e9..e6f7f6ddf9 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowPythonRunner.scala @@ -44,7 +44,7 @@ class SedonaArrowPythonRunner( argOffsets, jobArtifactUUID) with SedonaBasicPythonArrowInput - with BasicPythonArrowOutput { + with SedonaBasicPythonArrowOutput { override val pythonExec: String = SQLConf.get.pysparkWorkerPythonExecutable.getOrElse(funcs.head.funcs.head.pythonExec) @@ -53,9 +53,9 @@ class SedonaArrowPythonRunner( override val simplifiedTraceback: Boolean = SQLConf.get.pysparkSimplifiedTraceback - override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize - require( - bufferSize >= 4, - "Pandas execution requires more than 4 bytes. Please set higher buffer. " + - s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") +// override val bufferSize: Int = SQLConf.get.pandasUDFBufferSize +// require( +// bufferSize >= 4, +// "Pandas execution requires more than 4 bytes. Please set higher buffer. 
" + +// s"Please change '${SQLConf.PANDAS_UDF_BUFFER_SIZE.key}'.") } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala index c025653029..05bed6a138 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaArrowStrategy.scala @@ -29,6 +29,9 @@ import org.apache.spark.sql.execution.SparkPlan import org.apache.spark.sql.types.StructType import org.apache.spark.sql.udf.SedonaArrowEvalPython import org.apache.spark.{JobArtifactSet, TaskContext} +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.vectorized.ColumnarBatch +import org.apache.spark.sql.catalyst.InternalRow import scala.collection.JavaConverters.asScalaIteratorConverter @@ -67,6 +70,8 @@ case class SedonaArrowEvalPythonExec( schema: StructType, context: TaskContext): Iterator[InternalRow] = { + val outputTypes = output.drop(child.output.length).map(_.dataType) + val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter) evalType match { @@ -82,9 +87,45 @@ case class SedonaArrowEvalPythonExec( pythonMetrics, jobArtifactUUID).compute(batchIter, context.partitionId(), context) - columnarBatchIter.flatMap { batch => +// val size = columnarBatchIter.size +// val iter = columnarBatchIter.foreach { batch => +// processBatch(batch) +// } +// +// println("sss") +// val data = columnarBatchIter.flatMap { batch => +// batch.rowIterator.asScala +// } +// +// val seqData = data.toSeq +// +// val seqDataSize = seqData.size +// val seqDataLength = seqData.length +// println("ssss") + +// columnarBatchIter.flatMap { batch => +// batch.rowIterator.asScala +// } + + val result = columnarBatchIter.flatMap { batch => +// val actualDataTypes = (0 until batch.numCols()).map(i => batch.column(i).dataType()) +// assert(outputTypes == actualDataTypes, "Invalid schema from pandas_udf: " + +// s"expected ${outputTypes.mkString(", ")}, got ${actualDataTypes.mkString(", ")}") batch.rowIterator.asScala } +// +// try{ +// val first = result.next().toSeq(schema) +// } catch { +// case e: Exception => { +// println("No data returned from Sedona DB UDF") +// } +// } +// +// val first = result.next().toSeq(schema) + + println("ssss") + return result case SQL_SCALAR_SEDONA_UDF => val columnarBatchIter = new ArrowPythonRunner( @@ -98,12 +139,26 @@ case class SedonaArrowEvalPythonExec( pythonMetrics, jobArtifactUUID).compute(batchIter, context.partitionId(), context) - columnarBatchIter.flatMap { batch => + val iter = columnarBatchIter.flatMap { batch => batch.rowIterator.asScala } +// +// iter.map( +// row => { +// processBatch(row) +// } +// ) +// +// val seqData = iter.toList +// println(seqData.head.getClass) + + println("SedonaArrowEvalPythonExec: Executing Sedona DB UDF") +// iter + iter } } override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan = copy(child = newChild) } + diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala index b3e0878c91..c61a8addc6 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala +++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaBasePythonRunner.scala @@ -31,6 +31,8 @@ import org.apache.spark.internal.Logging import org.apache.spark.internal.config.EXECUTOR_CORES import org.apache.spark.internal.config.Python._ import org.apache.spark.resource.ResourceProfile.{EXECUTOR_CORES_LOCAL_PROPERTY, PYSPARK_MEMORY_LOCAL_PROPERTY} +import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.sql.vectorized.ColumnarBatch import org.apache.spark.util._ private object SedonaBasePythonRunner { @@ -66,6 +68,8 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( mem.map(_ / cores) } + import java.io._ + override def compute( inputIterator: Iterator[IN], partitionIndex: Int, @@ -85,7 +89,7 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( } envVars.put("SPARK_LOCAL_DIRS", localdir) // it's also used in monitor thread if (reuseWorker) { - envVars.put("SPARK_REUSE_WORKER", "1") + envVars.put("SPARK_REUSE_WORKER", "-1") } if (simplifiedTraceback) { envVars.put("SPARK_SIMPLIFIED_TRACEBACK", "1") @@ -105,15 +109,18 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( envVars.put("SPARK_JOB_ARTIFACT_UUID", jobArtifactUUID.getOrElse("default")) - val (worker: Socket, pid: Option[Int]) = + val (worker: Socket, pid: Option[Int]) = { WorkerContext.createPythonWorker(pythonExec, envVars.asScala.toMap) + } + + println("Sedona worker port: " + worker.getPort()) // Whether is the worker released into idle pool or closed. When any codes try to release or // close a worker, they should use `releasedOrClosed.compareAndSet` to flip the state to make // sure there is only one winner that is going to release or close the worker. val releasedOrClosed = new AtomicBoolean(false) // Start a thread to feed the process input from our parent's iterator - val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) + val writerThread = newWriterThread(env, worker, inputIterator, partitionIndex, context) context.addTaskCompletionListener[Unit] { _ => writerThread.shutdownOnTaskCompletion() @@ -128,21 +135,29 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( } writerThread.start() - new SedonaMonitorThread(SparkEnv.get, worker, writerThread, context).start() - if (reuseWorker) { - val key = (worker, context.taskAttemptId) - // SPARK-35009: avoid creating multiple monitor threads for the same python worker - // and task context - if (PythonRunner.runningMonitorThreads.add(key)) { - new MonitorThread(SparkEnv.get, worker, context).start() - } - } else { - new MonitorThread(SparkEnv.get, worker, context).start() - } +// 305996 +// 305997 +// new SedonaMonitorThread(SparkEnv.get, worker, writerThread, context).start() +// if (reuseWorker) { +// val key = (worker, context.taskAttemptId) +// // SPARK-35009: avoid creating multiple monitor threads for the same python worker +// // and task context +// if (PythonRunner.runningMonitorThreads.add(key)) { +// new MonitorThread(SparkEnv.get, worker, context).start() +// } +// } else { +// new MonitorThread(SparkEnv.get, worker, context).start() +// } // Return an iterator that read lines from the process's stdout - val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) +// if (writerThread.isAlive) { +// +// } + val stream = new DataInputStream(new BufferedInputStream(worker.getInputStream, bufferSize)) +// println("worker is closed : " + worker.isClosed) + // write to a file for debug +// 
writeDataInputStreamToFile(stream, s"/Users/pawelkocinski/Desktop/projects/sedona_java_11/sedona/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/sedona_python_output_${context.taskAttemptId}.bin") val stdoutIterator = newReaderIterator( stream, writerThread, @@ -152,7 +167,8 @@ private[spark] abstract class SedonaBasePythonRunner[IN, OUT]( pid, releasedOrClosed, context) - new InterruptibleIterator(context, stdoutIterator) +// new InterruptibleIterator(context, stdoutIterator) + stdoutIterator } private class SedonaMonitorThread( diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala index db46ff6d8c..6dd23930c8 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaDBWorkerFactory.scala @@ -89,9 +89,9 @@ class SedonaDBWorkerFactory(pythonExec: String, envVars: Map[String, String]) if (pid < 0) { throw new IllegalStateException("Python failed to launch worker with code " + pid) } - self.synchronized { - simpleWorkers.put(socket, worker) - } +// self.synchronized { +// simpleWorkers.put(socket, worker) +// } return (socket, Some(pid)) } catch { case e: Exception => diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala index fee3c22e64..91d7a024e9 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowInput.scala @@ -46,7 +46,7 @@ import org.apache.spark.sql.util.ArrowUtils import org.apache.spark.util.Utils import org.apache.spark.{SparkEnv, TaskContext} -import java.io.DataOutputStream +import java.io.{DataOutputStream, FileOutputStream} import java.net.Socket /** @@ -67,50 +67,56 @@ private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { handleMetadataBeforeExec(dataOut) writeUDF(dataOut, funcs, argOffsets) - val toReadCRS = inputIterator.buffered.headOption.flatMap(el => - el.asInstanceOf[Iterator[IN]].buffered.headOption) - - val row = toReadCRS match { - case Some(value) => - value match { - case row: GenericInternalRow => - Some(row) - } - case None => None - } - - val geometryFields = schema.zipWithIndex - .filter { case (field, index) => - field.dataType == GeometryUDT - } - .map { case (field, index) => - if (row.isEmpty || row.get.values(index) == null) (index, 0) - else { - val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] - val preambleByte = geom(0) & 0xff - val hasSrid = (preambleByte & 0x01) != 0 - - var srid = 0 - if (hasSrid) { - val srid2 = (geom(1) & 0xff) << 16 - val srid1 = (geom(2) & 0xff) << 8 - val srid0 = geom(3) & 0xff - srid = srid2 | srid1 | srid0 - } - (index, srid) - } - } +// val toReadCRS = inputIterator.buffered.headOption.flatMap(el => +// el.asInstanceOf[Iterator[IN]].buffered.headOption) + +// val row = toReadCRS match { +// case Some(value) => +// value match { +// case row: GenericInternalRow => +// Some(row) +// } +// case None => None +// } + +// val geometryFields = schema.zipWithIndex +// .filter { case (field, index) => +// field.dataType == GeometryUDT +// } +// .map { case (field, index) => +// if 
(row.isEmpty || row.get.values(index) == null) (index, 0) +// else { +// val geom = row.get.get(index, GeometryUDT).asInstanceOf[Array[Byte]] +// val preambleByte = geom(0) & 0xff +// val hasSrid = (preambleByte & 0x01) != 0 +// +// var srid = 0 +// if (hasSrid) { +// val srid2 = (geom(1) & 0xff) << 16 +// val srid1 = (geom(2) & 0xff) << 8 +// val srid0 = geom(3) & 0xff +// srid = srid2 | srid1 | srid0 +// } +// (index, srid) +// } +// } // write number of geometry fields - dataOut.writeInt(geometryFields.length) +// dataOut.writeInt(geometryFields.length) + dataOut.writeInt(0) // write geometry field indices and their SRIDs - geometryFields.foreach { case (index, srid) => - dataOut.writeInt(index) - dataOut.writeInt(srid) - } +// geometryFields.foreach { case (index, srid) => +// dataOut.writeInt(index) +// dataOut.writeInt(srid) +// } } protected override def writeIteratorToStream(dataOut: DataOutputStream): Unit = { +// val fileOut = new FileOutputStream("/Users/pawelkocinski/Desktop/projects/sedona_java_11/sedona/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/output.dat") + + // 2. Wrap it with DataOutputStream +// val dataOut = new DataOutputStream(fileOut) + val arrowSchema = ArrowUtils.toArrowSchema(schema, timeZoneId, errorOnDuplicatedFieldNames, largeVarTypes) val allocator = ArrowUtils.rootAllocator.newChildAllocator( @@ -122,6 +128,37 @@ private[python] trait SedonaPythonArrowInput[IN] extends PythonArrowInput[IN] { Utils.tryWithSafeFinally { val writer = new ArrowStreamWriter(root, null, dataOut) writer.start() +// val buffered = inputIterator.buffered +// var allValues = 0 +// while (buffered.hasNext) { +// val value = buffered.next() +// val itenralRow = value.asInstanceOf[Iterator[InternalRow]] +// +// val bufferedAll = itenralRow.buffered +// while (bufferedAll.hasNext) { +// val row = bufferedAll.next() +// allValues += 1 +// } +// } +// +// println("Total number of values: " + allValues) +// println("ssss") +// +// for (i <- 0 until inputIterator.length) { +// val value = inputIterator.next() +// val itenralRow = value.asInstanceOf[Iterator[InternalRow]] +// val firstElement = itenralRow.next() +// for (j <- 0 until value.asInstanceOf[Iterator[InternalRow]].length) { +// val row = value.asInstanceOf[Iterator[InternalRow]].next() +// val vector = root.getVector(i) +// println(s"Vector $i: ${vector.getClass.getSimpleName}, name: ${vector.getName}") +// println(s"Row $j: ${row}") +// } +// println("sss") +//// println(value) +//// val vector = root.getVector(i) +//// println(s"Vector $i: ${vector.getClass.getSimpleName}, name: ${vector.getName}") +// } writeIteratorToArrowStream(root, writer, dataOut, inputIterator) @@ -158,6 +195,7 @@ private[python] trait SedonaBasicPythonArrowInput dataOut: DataOutputStream, inputIterator: Iterator[Iterator[InternalRow]]): Unit = { val arrowWriter = ArrowWriter.create(root) + var record = 0 while (inputIterator.hasNext) { val startData = dataOut.size() @@ -172,6 +210,8 @@ private[python] trait SedonaBasicPythonArrowInput arrowWriter.reset() val deltaData = dataOut.size() - startData pythonMetrics("pythonDataSent") += deltaData + record += 1 + println("Written batch number: " + record) } } } diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala new file mode 100644 index 0000000000..50ee3cf17a --- /dev/null +++ 
b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/SedonaPythonArrowOutput.scala @@ -0,0 +1,251 @@ + +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.spark.sql.execution.python + +import com.univocity.parsers.common.input.EOFException + +import java.io.{DataInputStream, File, FileInputStream, FileOutputStream} +import java.net.Socket +import java.util.concurrent.atomic.AtomicBoolean +import scala.collection.JavaConverters._ +import org.apache.arrow.vector.VectorSchemaRoot +import org.apache.arrow.vector.ipc.ArrowStreamReader +import org.apache.spark.{SparkEnv, TaskContext} +import org.apache.spark.api.python.{BasePythonRunner, SpecialLengths} +import org.apache.spark.sql.execution.metric.SQLMetric +import org.apache.spark.sql.types.StructType +import org.apache.spark.sql.util.ArrowUtils +import org.apache.spark.sql.vectorized.{ArrowColumnVector, ColumnVector, ColumnarBatch} + +/** + * A trait that can be mixed-in with [[BasePythonRunner]]. It implements the logic from + * Python (Arrow) to JVM (output type being deserialized from ColumnarBatch). 
+ */ +private[python] trait SedonaPythonArrowOutput[OUT <: AnyRef] { self: BasePythonRunner[_, OUT] => + + protected def pythonMetrics: Map[String, SQLMetric] + + val openedFile = new File("/Users/pawelkocinski/Desktop/projects/sedonaworker/sedonaworker/output_batch_data.arrow") + val stream2 = new FileInputStream(openedFile) + protected def handleMetadataAfterExec(stream: DataInputStream): Unit = { } + + protected def deserializeColumnarBatch(batch: ColumnarBatch, schema: StructType): OUT + var numberOfReads = 0 + var numberOfReadsData = 1 + + def writeToFile(in: DataInputStream, file: File): Unit = { + val out = new FileOutputStream(file) + try { + val buffer = new Array[Byte](8192) + var bytesRead = 0 + while ({ + bytesRead = in.read(buffer) + bytesRead != -1 + }) { + out.write(buffer, 0, bytesRead) + } + } finally { + out.close() + in.close() + } + } + + protected def newReaderIterator( + stream: DataInputStream, + writerThread: WriterThread, + startTime: Long, + env: SparkEnv, + worker: Socket, + pid: Option[Int], + releasedOrClosed: AtomicBoolean, + context: TaskContext): Iterator[OUT] = { + + new ReaderIterator( + stream, writerThread, startTime, env, worker, pid, releasedOrClosed, context) { + + private val allocator = ArrowUtils.rootAllocator.newChildAllocator( + s"stdin reader for $pythonExec", 0, Long.MaxValue) + + private var reader: ArrowStreamReader = _ + private var root: VectorSchemaRoot = _ + private var schema: StructType = _ + private var vectors: Array[ColumnVector] = _ + private var totalNumberOfRows: Long = 0L + + context.addTaskCompletionListener[Unit] { _ => + if (reader != null) { + reader.close(false) + } + allocator.close() + } + + private var batchLoaded = true + + protected override def handleEndOfDataSection(): Unit = { +// handleMetadataAfterExec(stream) +// super.handleEndOfDataSection() +// worker.close() + WorkerContext.destroyPythonWorker(pythonExec = pythonExec, envVars = envVars.asScala.toMap, worker = worker) + } + + override def hasNext: Boolean = { + val value = numberOfReadsData + + numberOfReadsData -= 1 + value > 0 + } + + override def next(): OUT = { + val result = read() + if (result == null) { + throw new NoSuchElementException("End of stream") + } + result + } + + protected override def read(): OUT = { + println("sssss") + reader = new ArrowStreamReader(stream2, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + + val bytesReadStart = reader.bytesRead() + batchLoaded = reader.loadNextBatch() + val batch = new ColumnarBatch(vectors) + val rowCount = root.getRowCount +// totalNumberOfRows += rowCount + println("Total number of rows: " + totalNumberOfRows) + batch.setNumRows(root.getRowCount) + + val out = deserializeColumnarBatch(batch, schema) + +// reader.close(false) +// worker.close() +// reader.s + return out + + numberOfReads += 1 + if (numberOfReads > 5) { + handleEndOfDataSection() + return null.asInstanceOf[OUT] + } + println("worker is closed : " + worker.isInputShutdown) +// while (writerThread.isAlive) { +// Thread.sleep(10) +// println("waiting for writer to finish...") +// } +// +// while (stream.available() == 0) { +// Thread.sleep(10) +// println("waiting for data...") +// } +// +// println(stream.available()) +// +// return null.asInstanceOf[OUT] + if (writerThread.exception.isDefined) { + throw writerThread.exception.get + } + try { + if (reader != null && 
batchLoaded) { + val bytesReadStart = reader.bytesRead() + batchLoaded = reader.loadNextBatch() + println("ssss") + + if (batchLoaded) { + val batch = new ColumnarBatch(vectors) + val rowCount = root.getRowCount + totalNumberOfRows += rowCount + println("Total number of rows: " + totalNumberOfRows) + batch.setNumRows(root.getRowCount) + val bytesReadEnd = reader.bytesRead() + // 1_571_296 + // 24_133_432 + // 48 264 788 + // 48 264 720 + // 41076056 + + pythonMetrics("pythonNumRowsReceived") += rowCount + pythonMetrics("pythonDataReceived") += bytesReadEnd - bytesReadStart + val out = deserializeColumnarBatch(batch, schema) + out + } else { + reader.close(false) + allocator.close() + // Reach end of stream. Call `read()` again to read control data. + read() + } + } else { +// if (stream.available() == 0) { +// println("ssss") +// throw handleException +// return null.asInstanceOf[OUT] +// } + val streamType = try { + stream.readInt() + } catch { + case e: Throwable => + SpecialLengths.END_OF_DATA_SECTION + } + + streamType match { + case SpecialLengths.START_ARROW_STREAM => +// file input stream + if (numberOfReads > 2) { + return null.asInstanceOf[OUT] + } + + val stream2 = new FileInputStream(new File("/Users/pawelkocinski/Desktop/projects/sedonaworker/sedonaworker/output_batch_data.arrow")) + reader = new ArrowStreamReader(stream2, allocator) + root = reader.getVectorSchemaRoot() + schema = ArrowUtils.fromArrowSchema(root.getSchema()) + vectors = root.getFieldVectors().asScala.map { vector => + new ArrowColumnVector(vector) + }.toArray[ColumnVector] + read() +// case SpecialLengths.TIMING_DATA => +// handleTimingData() +// read() +// case SpecialLengths.PYTHON_EXCEPTION_THROWN => +// throw handlePythonException() + case SpecialLengths.END_OF_DATA_SECTION => + handleEndOfDataSection() + null.asInstanceOf[OUT] + case _ => + handleEndOfDataSection() + null.asInstanceOf[OUT] + } + } + } + } + } + } +} + +private[python] trait SedonaBasicPythonArrowOutput extends SedonaPythonArrowOutput[ColumnarBatch] { + self: BasePythonRunner[_, ColumnarBatch] => + + protected def deserializeColumnarBatch( + batch: ColumnarBatch, + schema: StructType): ColumnarBatch = batch +} + diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala index c1193cb7fa..5066516a8b 100644 --- a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala +++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/execution/python/WorkerContext.scala @@ -37,12 +37,15 @@ object WorkerContext { envVars: Map[String, String], worker: Socket): Unit = { synchronized { - val key = (pythonExec, envVars) - pythonWorkers.get(key).foreach(_.stopWorker(worker)) + worker.close() +// val key = (pythonExec, envVars) +// pythonWorkers.get(key).foreach(workerFactory => { +// workerFactory.stopWorker(worker) +// }) } } - private val pythonWorkers = + private var pythonWorkers = mutable.HashMap[(String, Map[String, String]), SedonaDBWorkerFactory]() } diff --git a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala index 28943ff11d..4eb48e4ca7 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/sedona/sql/TestBaseScala.scala @@ -40,7 +40,7 @@ trait TestBaseScala extends FunSpec with BeforeAndAfterAll { val 
warehouseLocation = System.getProperty("user.dir") + "/target/" val sparkSession = SedonaContext .builder() - .master("local[*]") + .master("local[1]") .appName("sedonasqlScalaTest") .config("spark.sql.warehouse.dir", warehouseLocation) // We need to be explicit about broadcasting in tests. diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala index 9dc6677035..e638a05f25 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala @@ -18,16 +18,17 @@ */ package org.apache.spark.sql.udf -import org.apache.sedona.spark.SedonaContext import org.apache.sedona.sql.TestBaseScala +import org.apache.spark.SparkEnv +import org.apache.spark.security.SocketAuthHelper import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.execution.python.WorkerContext -import org.apache.spark.sql.functions.{col, expr, lit} -import org.apache.spark.sql.udf.ScalarUDF.{geometryToGeometryFunction, nonGeometryVectorizedUDF, nonGeometryVectorizedUDF2} +import org.apache.spark.sql.functions.{col, expr} +import org.apache.spark.sql.udf.ScalarUDF.{geoPandasScalaFunction, nonGeometryVectorizedUDF, sedonaDBGeometryToGeometryFunction} import org.locationtech.jts.io.WKTReader -import org.scalatest.funsuite.AnyFunSuite import org.scalatest.matchers.should.Matchers +import java.net.{InetAddress, ServerSocket, Socket} + class StrategySuite extends TestBaseScala with Matchers { val wktReader = new WKTReader() @@ -39,183 +40,171 @@ class StrategySuite extends TestBaseScala with Matchers { import spark.implicits._ it("sedona geospatial UDF - geopandas") { - val df = Seq( - (1, "value", wktReader.read("POINT(21 52)")), - (2, "value1", wktReader.read("POINT(20 50)")), - (3, "value2", wktReader.read("POINT(20 49)")), - (4, "value3", wktReader.read("POINT(20 48)")), - (5, "value4", wktReader.read("POINT(20 47)"))) - .toDF("id", "value", "geom") - .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) - - df.count shouldEqual 5 - - df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") - .as[String] - .collect() should contain theSameElementsAs Seq( - "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", - "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", - "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", - "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", - "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") - } - - it("sedona geospatial UDF") { -// spark.sql("select 1").show() val df = spark.read .format("geoparquet") .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") - .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')")) + .withColumn("geom_buffer", geoPandasScalaFunction(col("geometry")) ) - df.show() - - df - .select( - col("id"), - col("version"), - col("bbox"), - // nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"), - geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"), - nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"), - // nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"), - // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") - // geometryToNonGeometryFunction(col("geometry")), - // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") - // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - ) - .show(10) + df.printSchema() - df - .select( - 
col("id"), - col("version"), - col("bbox"), - // nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"), - geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"), - nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"), - // nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"), - // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") - // geometryToNonGeometryFunction(col("geometry")), - // geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") - // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - ) - .explain(true) + df.show() // -// df -// .select( -// col("id"), -// col("version"), -// col("bbox"), -//// nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"), -// geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom"), -//// nonGeometryVectorizedUDF(col("bbox.xmin")).alias("xmin"), -//// nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"), -//// nonGeometryVectorizedUDF2(col("bbox.xmin")).alias("xmin"), -//// geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") -//// geometryToNonGeometryFunction(col("geometry")), -//// geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") -//// nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), -// ) -// .show(10) +// val df = Seq( +// (1, "value", wktReader.read("POINT(21 52)")), +// (2, "value1", wktReader.read("POINT(20 50)")), +// (3, "value2", wktReader.read("POINT(20 49)")), +// (4, "value3", wktReader.read("POINT(20 48)")), +// (5, "value4", wktReader.read("POINT(20 47)"))) +// .toDF("id", "value", "geom") +// .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) + +// df.count shouldEqual 5 +// +// df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") +// .as[String] +// .collect() should contain theSameElementsAs Seq( +// "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", +// "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", +// "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", +// "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", +// "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") } + it("sedona geospatial UDF - sedona db") { +// val df = Seq( +// (1, "value", wktReader.read("POINT(21 52)")), +// (2, "value1", wktReader.read("POINT(20 50)")), +// (3, "value2", wktReader.read("POINT(20 49)")), +// (4, "value3", wktReader.read("POINT(20 48)")), +// (5, "value4", wktReader.read("POINT(20 47)"))) +// .toDF("id", "value", "geometry") + +// df.cache() +// df.count() +// .select( +// sedonaDBGeometryToGeometryFunction(col("geometry")).alias("geom"), +// nonGeometryVectorizedUDF(col("id")).alias("id_increased"), +// ) - it("sedona db 1 geospatial UDF") { - // spark.sql("select 1").show() +// spark.read +// .format("geoparquet") +// .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") +// .limit(10000) +// .write.format("geoparquet") +// .save("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings_2") val df = spark.read .format("geoparquet") - .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") - .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')")) + .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings_2") + .select("geometry") - df.show() + df.cache() + df.count() +// .limit(100) - df.printSchema() - df +// println(df.count()) + +// df.cache() +// +// df.count() + + val dfVectorized = df + 
.withColumn("geometry", expr("ST_SetSRID(geometry, '4326')")) .select( - col("id"), - col("version"), - col("bbox"), - // geometryToNonGeometryFunction(col("geometry")), - geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") - // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), +// col("id"), +// col("version"), +// col("bbox"), + sedonaDBGeometryToGeometryFunction(col("geometry")).alias("geom"), +// nonGeometryVectorizedUDF(col("id")).alias("id_increased"), ) - .show(10) - - println( - df - .select( - col("id"), - col("version"), - col("bbox"), - // geometryToNonGeometryFunction(col("geometry")), - geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") - // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - ) - .count()) - - WorkerContext - - // df.show() - 1 shouldBe 1 - - // val df = Seq( - // (1, "value", wktReader.read("POINT(21 52)")), - // (2, "value1", wktReader.read("POINT(20 50)")), - // (3, "value2", wktReader.read("POINT(20 49)")), - // (4, "value3", wktReader.read("POINT(20 48)")), - // (5, "value4", wktReader.read("POINT(20 47)"))) - // .toDF("id", "value", "geom") - // .withColumn("geom_buffer", geoPandasScalaFunction(col("geom"))) - - // df.count shouldEqual 5 - - // df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))") - // .as[String] - // .collect() should contain theSameElementsAs Seq( - // "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))", - // "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))", - // "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))", - // "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))", - // "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))") + + dfVectorized.show() +// dfVectorized.selectExpr("ST_X(ST_Centroid(geom)) AS x").selectExpr("sum(x)").show() +// val processingContext = df.queryExecution.explainString(mode = ExplainMode.fromString("extended")) + +// println(processingContext) } - it("sedona db geospatial UDF") { -// spark.sql("select 1").show() - val df = spark.read - .format("geoparquet") - .load("/Users/pawelkocinski/Desktop/projects/sedona-production/apache-sedona-book/data/warehouse/buildings") - .withColumn("geometry", expr("ST_SetSRID(geometry, '4326')")) + it("should properly start socket server") { + val authHelper = new SocketAuthHelper(SparkEnv.get.conf) + val serverSocket = new ServerSocket(5356, 1, InetAddress.getLoopbackAddress()) - df.show() +// serverSocket.setSoTimeout(15000) + println(serverSocket.getLocalPort) + val socket = serverSocket.accept() - df - .select( - col("id"), - col("version"), - col("bbox"), - // geometryToNonGeometryFunction(col("geometry")), - geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") - // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - ) - .show(10) - - println( - df - .select( - col("id"), - col("version"), - col("bbox"), - // geometryToNonGeometryFunction(col("geometry")), - geometryToGeometryFunction(col("geometry"), lit(1)).alias("geom") - // nonGeometryToGeometryFunction(expr("ST_AsText(geometry)")), - ) - .count()) - - WorkerContext - - // df.show() - 1 shouldBe 1 - } + println("socket accepted") +// +// val acceptThread = new Thread(() => { +// println("Waiting for client...") +// val socket = serverSocket.accept() // BLOCKS HERE +// println("Client connected!") +// }, "accept-thread") +// +// acceptThread.start() +// +// println("Main thread continues immediately") + +// val t = new Thread() { +// override def run(): Unit = { +// println("starting client") +// socket = serverSocket.accept() +// println("client 
connected") +// +// } +// } + +// t.start() +// val socket = serverSocket.accept() +// authHelper.authClient(socket) + + println(authHelper.secret) +// Thread.sleep(10000) +// authHelper.authClient(socket) + +// var socket: Socket = null + +// new Thread() { +// socket = serverSocket.accept() +// } +// +// val t2 = new Thread() { +// override def run(): Unit = { +// println("accepted connection") +// val socket = serverSocket.accept() +// val in = socket.getInputStream +// val buffer = new Array[Byte](1024) +// println("accepted connection") +// var bytesRead = in.read(buffer) +// while (bytesRead != -1) { +// val received = new String(buffer, 0, bytesRead) +// println(s"Received: $received") +// bytesRead = in.read(buffer) +// } +// in.close() +// socket.close() +// } +// } +// +// t2.start() +// +// val thread2 = new Thread() { +// override def run(): Unit = { +// println("starting client") +// serverSocket.close() +// } +// } +// +// val t = new Thread(() => { +// println("hello from thread") +// }) +// +// t.start() +// t.join() // <-- this IS valid +// t2.join() + +// Thread.sleep(30000) + + } } diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala index cc83f5f852..89afa10986 100644 --- a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala +++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala @@ -22,7 +22,7 @@ import org.apache.sedona.sql.UDF import org.apache.spark.{SparkEnv, TestUtils} import org.apache.spark.api.python._ import org.apache.spark.broadcast.Broadcast -import org.apache.spark.internal.config.Python.{PYTHON_DAEMON_MODULE, PYTHON_USE_DAEMON, PYTHON_WORKER_MODULE} +import org.apache.spark.internal.config.Python.{PYTHON_USE_DAEMON, PYTHON_WORKER_MODULE} import org.apache.spark.sql.execution.python.UserDefinedPythonFunction import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.types.FloatType @@ -45,6 +45,9 @@ object ScalarUDF { } } + SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false) +// SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "org.apache.sedona.python.SedonaPythonWorker") + private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "") protected lazy val sparkHome: String = { sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME")) @@ -74,30 +77,6 @@ object ScalarUDF { val additionalModule = "spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf" - val vectorizedFunction2: Array[Byte] = { - var binaryPandasFunc: Array[Byte] = null - withTempPath { path => - Process( - Seq( - pythonExec, - "-c", - f""" - |from pyspark.sql.types import FloatType - |from pyspark.serializers import CloudPickleSerializer - |f = open('$path', 'wb'); - | - |def apply_function_on_number(x): - | return x + 3.0 - |f.write(CloudPickleSerializer().dumps((apply_function_on_number, FloatType()))) - |""".stripMargin), - None, - "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! 
- binaryPandasFunc = Files.readAllBytes(path.toPath) - } - assert(binaryPandasFunc != null) - binaryPandasFunc - } - val vectorizedFunction: Array[Byte] = { var binaryPandasFunc: Array[Byte] = null withTempPath { path => @@ -122,7 +101,7 @@ object ScalarUDF { binaryPandasFunc } - val geopandasGeometryToGeometryFunction: Array[Byte] = { + val sedonaDBGeometryToGeometryFunctionBytes: Array[Byte] = { var binaryPandasFunc: Array[Byte] = null withTempPath { path => Process( @@ -156,70 +135,84 @@ object ScalarUDF { assert(binaryPandasFunc != null) binaryPandasFunc } - // - // val geopandasNonGeometryToGeometryFunction: Array[Byte] = { - // var binaryPandasFunc: Array[Byte] = null - // withTempPath { path => - // Process( - // Seq( - // pythonExec, - // "-c", - // f""" - // |from sedona.sql.types import GeometryType - // |from shapely.wkt import loads - // |from pyspark.serializers import CloudPickleSerializer - // |f = open('$path', 'wb'); - // |def apply_geopandas(x): - // | return x.apply(lambda wkt: loads(wkt).buffer(1)) - // |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType()))) - // |""".stripMargin), - // None, - // "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! - // binaryPandasFunc = Files.readAllBytes(path.toPath) - // } - // assert(binaryPandasFunc != null) - // binaryPandasFunc - // } + + val geopandasNonGeometryToGeometryFunction: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + Process( + Seq( + pythonExec, + "-c", + f""" + |from sedona.sql.types import GeometryType + |from shapely.wkt import loads + |from pyspark.serializers import CloudPickleSerializer + |f = open('$path', 'wb'); + |def apply_geopandas(x): + | return x.apply(lambda wkt: loads(wkt).buffer(1)) + |f.write(CloudPickleSerializer().dumps((apply_geopandas, GeometryType()))) + |""".stripMargin), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! 
+ binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } private val workerEnv = new java.util.HashMap[String, String]() - workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath") - SparkEnv.get.conf.set(PYTHON_WORKER_MODULE, "sedonaworker.initialworker") -// SparkEnv.get.conf.set(PYTHON_USE_DAEMON, false) - // - // val geometryToNonGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( - // name = "geospatial_udf", - // func = SimplePythonFunction( - // command = geopandasGeometryToNonGeometry, - // envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], - // pythonIncludes = List.empty[String].asJava, - // pythonExec = pythonExec, - // pythonVer = pythonVer, - // broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, - // accumulator = null), - // dataType = FloatType, - // pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, - // udfDeterministic = true) - val nonGeometryVectorizedUDF: UserDefinedPythonFunction = UserDefinedPythonFunction( - name = "vectorized_udf", + val pandasFunc: Array[Byte] = { + var binaryPandasFunc: Array[Byte] = null + withTempPath { path => + println(path) + Process( + Seq( + pythonExec, + "-c", + f""" + |from pyspark.sql.types import IntegerType + |from shapely.geometry import Point + |from sedona.sql.types import GeometryType + |from pyspark.serializers import CloudPickleSerializer + |from sedona.utils import geometry_serde + |from shapely import box + |f = open('$path', 'wb'); + |def w(x): + | def apply_function(w): + | geom, offset = geometry_serde.deserialize(w) + | bounds = geom.buffer(1).bounds + | x = box(*bounds) + | return geometry_serde.serialize(x) + | return x.apply(apply_function) + |f.write(CloudPickleSerializer().dumps((w, GeometryType()))) + |""".stripMargin), + None, + "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!! 
+ binaryPandasFunc = Files.readAllBytes(path.toPath) + } + assert(binaryPandasFunc != null) + binaryPandasFunc + } + + val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( + name = "geospatial_udf", func = SimplePythonFunction( - command = vectorizedFunction, + command = pandasFunc, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec, pythonVer = pythonVer, broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, - accumulator = null - ), - dataType = FloatType, - pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, - udfDeterministic = true - ) + accumulator = null), + dataType = GeometryUDT, + pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, + udfDeterministic = true) - val nonGeometryVectorizedUDF2: UserDefinedPythonFunction = UserDefinedPythonFunction( + val nonGeometryVectorizedUDF: UserDefinedPythonFunction = UserDefinedPythonFunction( name = "vectorized_udf", func = SimplePythonFunction( - command = vectorizedFunction2, + command = vectorizedFunction, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec, @@ -229,13 +222,13 @@ object ScalarUDF { ), dataType = FloatType, pythonEvalType = PythonEvalType.SQL_SCALAR_PANDAS_UDF, - udfDeterministic = true + udfDeterministic = false ) - val geometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( + val sedonaDBGeometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( name = "geospatial_udf", func = SimplePythonFunction( - command = geopandasGeometryToGeometryFunction, + command = sedonaDBGeometryToGeometryFunctionBytes, envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], pythonIncludes = List.empty[String].asJava, pythonExec = pythonExec, @@ -245,18 +238,5 @@ object ScalarUDF { dataType = GeometryUDT, pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_DB_UDF, udfDeterministic = false) - // - // val nonGeometryToGeometryFunction: UserDefinedPythonFunction = UserDefinedPythonFunction( - // name = "geospatial_udf", - // func = SimplePythonFunction( - // command = geopandasNonGeometryToGeometryFunction, - // envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]], - // pythonIncludes = List.empty[String].asJava, - // pythonExec = pythonExec, - // pythonVer = pythonVer, - // broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava, - // accumulator = null), - // dataType = GeometryUDT, - // pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF, - // udfDeterministic = true) + }
