This is an automated email from the ASF dual-hosted git repository.

imbruced pushed a commit to branch sedona-arrow-udf-example
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 98728457796c01f61cc5f9fd38ffa9d90d54ddb1
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Feb 23 21:53:47 2025 +0100

    SEDONA-721 Add Sedona vectorized udf.
---
 pom.xml                                            |  12 +-
 python/sedona/sql/udf.py                           |  19 +++
 python/sedona/utils/geoarrow.py                    |   3 +-
 python/tests/utils/test_pandas_arrow_udf.py        |  47 ++++++
 .../org/apache/sedona/sql/RasterRegistrator.scala  |   7 +
 .../sedona_sql/strategies/ExtractSedonaUDF.scala   | 165 +++++++++++++++++++++
 .../sql/sedona_sql/strategies/PythonEvalType.scala |  28 ++++
 .../strategies/SedonaArrowEvalPython.scala         |  32 ++++
 .../strategies/SedonaArrowStrategy.scala           |  82 ++++++++++
 .../sql/sedona_sql/strategies/StrategySuite.scala  |  67 +++++++++
 .../strategies/TestScalarPandasUDF.scala           | 118 +++++++++++++++
 11 files changed, 572 insertions(+), 8 deletions(-)

diff --git a/pom.xml b/pom.xml
index 08d4ff646a..da87637d32 100644
--- a/pom.xml
+++ b/pom.xml
@@ -18,12 +18,12 @@
 -->
 <project xmlns="http://maven.apache.org/POM/4.0.0" xmlns:xsi="http://www.w3.org/2001/XMLSchema-instance" xsi:schemaLocation="http://maven.apache.org/POM/4.0.0 http://maven.apache.org/maven-v4_0_0.xsd">
     <modelVersion>4.0.0</modelVersion>
-    <parent>
-        <groupId>org.apache</groupId>
-        <artifactId>apache</artifactId>
-        <version>23</version>
-        <relativePath />
-    </parent>
+<!--    <parent>-->
+<!--        <groupId>org.apache</groupId>-->
+<!--        <artifactId>apache</artifactId>-->
+<!--        <version>23</version>-->
+<!--        <relativePath />-->
+<!--    </parent>-->
     <groupId>org.apache.sedona</groupId>
     <artifactId>sedona-parent</artifactId>
     <version>1.8.0-SNAPSHOT</version>

diff --git a/python/sedona/sql/udf.py b/python/sedona/sql/udf.py
new file mode 100644
index 0000000000..6e723d0e10
--- /dev/null
+++ b/python/sedona/sql/udf.py
@@ -0,0 +1,19 @@
+import pandas as pd
+
+from sedona.sql.types import GeometryType
+from sedona.utils import geometry_serde
+from pyspark.sql.udf import UserDefinedFunction
+
+
+SEDONA_SCALAR_EVAL_TYPE = 5200
+
+
+def sedona_vectorized_udf(fn):
+    def apply(series: pd.Series) -> pd.Series:
+        geo_series = series.apply(lambda x: fn(geometry_serde.deserialize(x)[0]))
+
+        return geo_series.apply(lambda x: geometry_serde.serialize(x))
+
+    return UserDefinedFunction(
+        apply, GeometryType(), "SedonaPandasArrowUDF", evalType=SEDONA_SCALAR_EVAL_TYPE
+    )

diff --git a/python/sedona/utils/geoarrow.py b/python/sedona/utils/geoarrow.py
index b4a539dfa4..b0f708af9c 100644
--- a/python/sedona/utils/geoarrow.py
+++ b/python/sedona/utils/geoarrow.py
@@ -323,12 +323,11 @@ def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> Data
     step = spark._jconf.arrowMaxRecordsPerBatch()
     step = step if step > 0 else len(gdf)
     pdf_slices = (gdf.iloc[start : start + step] for start in range(0, len(gdf), step))
-    spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
     arrow_data = [
         [
             (c, to_arrow_type(t) if t is not None else None, t)
-            for (_, c), t in zip(pdf_slice.items(), spark_types)
+            for (_, c), t in zip(pdf_slice.items(), schema.fields)
         ]
         for pdf_slice in pdf_slices
     ]
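For readers trying the new decorator, the sketch below shows how sedona_vectorized_udf is meant to be called from PySpark. The session setup, DataFrame, and column names are illustrative assumptions, not part of this patch; only the decorator and its geometry-in/geometry-out contract come from python/sedona/sql/udf.py above.

    # Illustrative usage sketch (assumes a Sedona-enabled session).
    import pyspark.sql.functions as f
    from sedona.spark import SedonaContext
    from sedona.sql.udf import sedona_vectorized_udf

    spark = SedonaContext.create(SedonaContext.builder().getOrCreate())

    @sedona_vectorized_udf
    def small_buffer(geom):
        # The wrapped function receives a deserialized shapely geometry and
        # must return one; the wrapper handles (de)serialization and reports
        # GeometryType() as the return type to Spark.
        return geom.buffer(0.001)

    df = spark.sql("SELECT ST_Point(1.0, 2.0) AS geom")
    df.withColumn("buffered", small_buffer(f.col("geom"))).show()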
diff --git a/python/tests/utils/test_pandas_arrow_udf.py b/python/tests/utils/test_pandas_arrow_udf.py
new file mode 100644
index 0000000000..7d59a93283
--- /dev/null
+++ b/python/tests/utils/test_pandas_arrow_udf.py
@@ -0,0 +1,47 @@
+from sedona.sql.types import GeometryType
+from sedona.sql.udf import sedona_vectorized_udf
+from tests import chicago_crimes_input_location
+from tests.test_base import TestBase
+import pyspark.sql.functions as f
+import shapely.geometry.base as b
+from time import time
+
+
+def non_vectorized_buffer_udf(geom: b.BaseGeometry) -> b.BaseGeometry:
+    return geom.buffer(0.001)
+
+
+@sedona_vectorized_udf
+def vectorized_buffer(geom: b.BaseGeometry) -> b.BaseGeometry:
+    return geom.buffer(0.001)
+
+
+buffer_distanced_udf = f.udf(non_vectorized_buffer_udf, GeometryType())
+
+
+class TestSedonaArrowUDF(TestBase):
+
+    def test_pandas_arrow_udf(self):
+        df = (
+            self.spark.read.option("header", "true")
+            .format("csv")
+            .load(chicago_crimes_input_location)
+            .selectExpr("ST_Point(y, x) AS geom")
+        )
+
+        vectorized_times = []
+        non_vectorized_times = []
+
+        for _ in range(10):
+            start = time()
+            df = df.withColumn("buffer", vectorized_buffer(f.col("geom")))
+            df.count()
+            vectorized_times.append(time() - start)
+
+            start = time()
+            df = df.withColumn("buffer", buffer_distanced_udf(f.col("geom")))
+            df.count()
+            non_vectorized_times.append(time() - start)
+
+        for v, nv in zip(vectorized_times, non_vectorized_times):
+            assert v < nv, "Vectorized UDF is slower than non-vectorized UDF"

diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
index ee7aa8b0be..bcacb2ab29 100644
--- a/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/RasterRegistrator.scala
@@ -22,6 +22,7 @@ import org.apache.sedona.sql.UDF.RasterUdafCatalog
 import org.apache.sedona.sql.utils.GeoToolsCoverageAvailability.{gridClassName, isGeoToolsAvailable}
 import org.apache.spark.sql.catalyst.FunctionIdentifier
 import org.apache.spark.sql.sedona_sql.UDT.RasterUdtRegistratorWrapper
+import org.apache.spark.sql.sedona_sql.strategies.{ExtractSedonaUDF, SedonaArrowStrategy}
 import org.apache.spark.sql.{SparkSession, functions}
 import org.slf4j.{Logger, LoggerFactory}

@@ -29,6 +30,12 @@ object RasterRegistrator {
   val logger: Logger = LoggerFactory.getLogger(getClass)

   def registerAll(sparkSession: SparkSession): Unit = {
+
+    sparkSession.experimental.extraStrategies =
+      sparkSession.experimental.extraStrategies :+ new SedonaArrowStrategy()
+    sparkSession.experimental.extraOptimizations =
+      sparkSession.experimental.extraOptimizations :+ ExtractSedonaUDF
+
     if (isGeoToolsAvailable) {
       RasterUdtRegistratorWrapper.registerAll(gridClassName)
       sparkSession.udf.register(
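Because RasterRegistrator now appends ExtractSedonaUDF and SedonaArrowStrategy to the session's experimental hooks, the plan rewrite can be observed from Python. A quick check, reusing the names from the usage sketch above (illustrative; the exact plan string depends on the Spark version):

    # After registration, a query using the decorated UDF should plan through
    # the new Arrow node rather than ordinary BatchEvalPython.
    out = df.withColumn("buffered", small_buffer(f.col("geom")))
    plan = out._jdf.queryExecution().executedPlan().toString()
    assert "SedonaArrowEvalPython" in plan  # physical node name, "Exec" suffix stripped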
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/ExtractSedonaUDF.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/ExtractSedonaUDF.scala
new file mode 100644
index 0000000000..be34fa5fcc
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/ExtractSedonaUDF.scala
@@ -0,0 +1,165 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionSet, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subquery}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_UDF
+
+import scala.collection.mutable
+
+object ExtractSedonaUDF extends Rule[LogicalPlan] {
+
+  private def hasScalarPythonUDF(e: Expression): Boolean = {
+    e.exists(PythonUDF.isScalarPythonUDF)
+  }
+
+  @scala.annotation.tailrec
+  private def canEvaluateInPython(e: PythonUDF): Boolean = {
+    e.children match {
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
+      case children => !children.exists(hasScalarPythonUDF)
+    }
+  }
+
+  def isScalarPythonUDF(e: Expression): Boolean = {
+    e.isInstanceOf[PythonUDF] && e
+      .asInstanceOf[PythonUDF]
+      .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF
+  }
+
+  private def collectEvaluableUDFsFromExpressions(
+      expressions: Seq[Expression]): Seq[PythonUDF] = {
+
+    var firstVisitedScalarUDFEvalType: Option[Int] = None
+
+    def canChainUDF(evalType: Int): Boolean = {
+      evalType == firstVisitedScalarUDFEvalType.get
+    }
+
+    def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+      case udf: PythonUDF
+          if isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+            && firstVisitedScalarUDFEvalType.isEmpty =>
+        firstVisitedScalarUDFEvalType = Some(udf.evalType)
+        Seq(udf)
+      case udf: PythonUDF
+          if isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+            && canChainUDF(udf.evalType) =>
+        Seq(udf)
+      case e => e.children.flatMap(collectEvaluableUDFs)
+    }
+
+    expressions.flatMap(collectEvaluableUDFs)
+  }
+
+  def apply(plan: LogicalPlan): LogicalPlan = plan match {
+    case s: Subquery if s.correlated => plan
+
+    case _ =>
+      plan.transformUpWithPruning(_.containsPattern(PYTHON_UDF)) {
+        case p: SedonaArrowEvalPython => p
+
+        case plan: LogicalPlan => extract(plan)
+      }
+  }
+
+  private def canonicalizeDeterministic(u: PythonUDF) = {
+    if (u.deterministic) {
+      u.canonicalized.asInstanceOf[PythonUDF]
+    } else {
+      u
+    }
+  }
+
+  private def extract(plan: LogicalPlan): LogicalPlan = {
+    val udfs = ExpressionSet(collectEvaluableUDFsFromExpressions(plan.expressions))
+      .filter(udf => udf.references.subsetOf(plan.inputSet))
+      .toSeq
+      .asInstanceOf[Seq[PythonUDF]]
+
+    udfs match {
+      case Seq() => plan
+      case _ => resolveUDFs(plan, udfs)
+    }
+  }
+
+  def resolveUDFs(plan: LogicalPlan, udfs: Seq[PythonUDF]): LogicalPlan = {
+    val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+
+    val newChildren = adjustAttributeMap(plan, udfs, attributeMap)
+
+    udfs.map(canonicalizeDeterministic).filterNot(attributeMap.contains).foreach { udf =>
+      throw new IllegalStateException(
+        s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+    }
+
+    val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF =>
+      attributeMap.getOrElse(canonicalizeDeterministic(p), p)
+    }
+
+    val newPlan = extract(rewritten)
+    if (newPlan.output != plan.output) {
+      Project(plan.output, newPlan)
+    } else {
+      newPlan
+    }
+  }
+
+  def adjustAttributeMap(
+      plan: LogicalPlan,
+      udfs: Seq[PythonUDF],
+      attributeMap: mutable.HashMap[PythonUDF, Expression]): Seq[LogicalPlan] = {
+    plan.children.map { child =>
+      val validUdfs = udfs.filter { udf =>
+        udf.references.subsetOf(child.outputSet)
+      }
+
+      if (validUdfs.nonEmpty) {
+        require(
+          validUdfs.forall(isScalarPythonUDF),
+          "Can only extract scalar vectorized udf or sql batch udf")
+
+        val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
+          AttributeReference(s"pythonUDF$i", u.dataType)()
+        }
+
+        val evalTypes = validUdfs.map(_.evalType).toSet
+        if (evalTypes.size != 1) {
+          throw new IllegalStateException(
+            "Expected UDFs to have the same evalType but got different evalTypes: " +
+              evalTypes.mkString(","))
+        }
+        val evalType = evalTypes.head
+        val evaluation = evalType match {
+          case PythonEvalType.SQL_SCALAR_SEDONA_UDF =>
+            SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+          case _ =>
+            throw new IllegalStateException("Unexpected UDF evalType")
+        }
+
+        attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs)
+        evaluation
+      } else {
+        child
+      }
+    }
+  }
+}
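The rule above mirrors Spark's own ExtractPythonUDFs: it walks each expression tree, takes the first scalar Sedona UDF it encounters, and then only collects further UDFs with the same evalType, so that one physical node can evaluate the whole group in a single Python round trip. A toy Python model of that traversal (no Spark dependency; Expr is a stand-in for Catalyst expressions, purely illustrative):

    from dataclasses import dataclass, field
    from typing import List, Optional

    SQL_SCALAR_SEDONA_UDF = 5200

    @dataclass
    class Expr:
        name: str
        eval_type: Optional[int] = None  # set only on UDF nodes
        children: List["Expr"] = field(default_factory=list)

    def collect_evaluable_udfs(expressions):
        first_eval_type = None
        collected = []

        def walk(e):
            nonlocal first_eval_type
            is_sedona_udf = e.eval_type == SQL_SCALAR_SEDONA_UDF
            if is_sedona_udf and first_eval_type in (None, e.eval_type):
                # First Sedona UDF seen fixes the eval type; later ones must match.
                first_eval_type = e.eval_type
                collected.append(e)
            else:
                # Not an evaluable Sedona UDF at this node: recurse into children.
                for child in e.children:
                    walk(child)

        for expr in expressions:
            walk(expr)
        return collected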
"Can only extract scalar vectorized udf or sql batch udf") + + val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) => + AttributeReference(s"pythonUDF$i", u.dataType)() + } + + val evalTypes = validUdfs.map(_.evalType).toSet + if (evalTypes.size != 1) { + throw new IllegalStateException( + "Expected udfs have the same evalType but got different evalTypes: " + + evalTypes.mkString(",")) + } + val evalType = evalTypes.head + val evaluation = evalType match { + case PythonEvalType.SQL_SCALAR_SEDONA_UDF => + SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType) + case _ => + throw new IllegalStateException("Unexpected UDF evalType") + } + + attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs) + evaluation + } else { + child + } + } + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/PythonEvalType.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/PythonEvalType.scala new file mode 100644 index 0000000000..0a8904edb4 --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/PythonEvalType.scala @@ -0,0 +1,28 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. + */ +package org.apache.spark.sql.sedona_sql.strategies + +object PythonEvalType { + val SQL_SCALAR_SEDONA_UDF = 5200 + val SEDONA_UDF_TYPE_CONSTANT = 5000 + + def toString(pythonEvalType: Int): String = pythonEvalType match { + case SQL_SCALAR_SEDONA_UDF => "SQL_SCALAR_GEO_UDF" + } +} diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala new file mode 100644 index 0000000000..78e00871ad --- /dev/null +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala @@ -0,0 +1,32 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one + * or more contributor license agreements. See the NOTICE file + * distributed with this work for additional information + * regarding copyright ownership. The ASF licenses this file + * to you under the Apache License, Version 2.0 (the + * "License"); you may not use this file except in compliance + * with the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, + * software distributed under the License is distributed on an + * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY + * KIND, either express or implied. See the License for the + * specific language governing permissions and limitations + * under the License. 
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala
new file mode 100644
index 0000000000..78e00871ad
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan}
+
+case class SedonaArrowEvalPython(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: LogicalPlan,
+    evalType: Int)
+    extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython =
+    copy(child = newChild)
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowStrategy.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowStrategy.scala
new file mode 100644
index 0000000000..f5a0d1c95f
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowStrategy.scala
@@ -0,0 +1,82 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.{JobArtifactSet, TaskContext}
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, EvalPythonExec, PythonSQLMetrics}
+import org.apache.spark.sql.types.StructType
+
+import scala.jdk.CollectionConverters.asScalaIteratorConverter
+
+class SedonaArrowStrategy extends Strategy {
+  override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+    case SedonaArrowEvalPython(udfs, output, child, evalType) =>
+      SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
+    case _ => Nil
+  }
+}
+
+case class SedonaArrowEvalPythonExec(
+    udfs: Seq[PythonUDF],
+    resultAttrs: Seq[Attribute],
+    child: SparkPlan,
+    evalType: Int)
+    extends EvalPythonExec
+    with PythonSQLMetrics {
+
+  private val batchSize = conf.arrowMaxRecordsPerBatch
+  private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+  private val largeVarTypes = conf.arrowUseLargeVarTypes
+  private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+  protected override def evaluate(
+      funcs: Seq[ChainedPythonFunctions],
+      argOffsets: Array[Array[Int]],
+      iter: Iterator[InternalRow],
+      schema: StructType,
+      context: TaskContext): Iterator[InternalRow] = {
+
+    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
+
+    val columnarBatchIter = new ArrowPythonRunner(
+      funcs,
+      evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT,
+      argOffsets,
+      schema,
+      sessionLocalTimeZone,
+      largeVarTypes,
+      pythonRunnerConf,
+      pythonMetrics,
+      jobArtifactUUID).compute(batchIter, context.partitionId(), context)
+
+    columnarBatchIter.flatMap { batch =>
+      batch.rowIterator.asScala
+    }
+  }
+
+  override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+    copy(child = newChild)
+}
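SedonaArrowEvalPythonExec batches input rows with BatchIterator before shipping them to the Python worker, and the batch size comes from Spark's standard Arrow setting, so the usual knob applies to Sedona UDFs as well. A short sketch, reusing the names from the Python usage sketch earlier:

    # Smaller batches trade throughput for worker memory; 10000 is Spark's default.
    spark.conf.set("spark.sql.execution.arrow.maxRecordsPerBatch", "5000")
    df.withColumn("buffered", small_buffer(f.col("geom"))).count()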
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/StrategySuite.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/StrategySuite.scala
new file mode 100644
index 0000000000..52d8ea8bac
--- /dev/null
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/StrategySuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.sedona.spark.SedonaContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.sedona_sql.strategies.ScalarUDF.geoPandasScalaFunction
+import org.locationtech.jts.io.WKTReader
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+class StrategySuite extends AnyFunSuite with Matchers {
+  val wktReader = new WKTReader()
+
+  val spark: SparkSession = {
+    val builder = SedonaContext
+      .builder()
+      .master("local[*]")
+      .appName("sedonasqlScalaTest")
+
+    val spark = SedonaContext.create(builder.getOrCreate())
+
+    spark.sparkContext.setLogLevel("ALL")
+    spark
+  }
+
+  import spark.implicits._
+
+  test("Chained Scalar Pandas UDFs should be combined to a single physical node") {
+    val df = Seq(
+      (1, "value", wktReader.read("POINT(21 52)")),
+      (2, "value1", wktReader.read("POINT(20 50)")),
+      (3, "value2", wktReader.read("POINT(20 49)")),
+      (4, "value3", wktReader.read("POINT(20 48)")),
+      (5, "value4", wktReader.read("POINT(20 47)")))
+      .toDF("id", "value", "geom")
+      .withColumn("geom_buffer", geoPandasScalaFunction(col("geom")))
+
+    df.count shouldEqual 5
+
+    df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))")
+      .as[String]
+      .collect() should contain theSameElementsAs Seq(
+      "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))",
+      "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))",
+      "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))",
+      "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))",
+      "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))")
+  }
+}
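The expected WKT in the suite is just the envelope of a unit buffer around each input point, which is easy to re-derive with shapely (sketch; the ring's starting vertex and orientation may differ from the suite's strings, which go through ST_ReducePrecision on the JTS side):

    from shapely.geometry import Point, box

    # Envelope of a unit buffer around POINT(21 52) -- the first expected row.
    envelope = box(*Point(21, 52).buffer(1).bounds)
    print(envelope.wkt)  # same ring as POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))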
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/TestScalarPandasUDF.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/TestScalarPandasUDF.scala
new file mode 100644
index 0000000000..925115b5e8
--- /dev/null
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/strategies/TestScalarPandasUDF.scala
@@ -0,0 +1,118 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.TestUtils
+import org.apache.spark.api.python._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.util.Utils
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+import scala.jdk.CollectionConverters.seqAsJavaListConverter
+import scala.sys.process.Process
+
+object ScalarUDF {
+
+  val pythonExec: String = {
+    val pythonExec =
+      sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3"))
+    if (TestUtils.testCommandAvailable(pythonExec)) {
+      pythonExec
+    } else {
+      "python"
+    }
+  }
+
+  private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
+  protected lazy val sparkHome: String = {
+    sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+  }
+
+  private lazy val py4jPath =
+    Paths.get(sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
+  private[spark] lazy val pysparkPythonPath = s"$py4jPath"
+
+  private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec)
+
+  lazy val pythonVer: String = if (isPythonAvailable) {
+    Process(
+      Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"),
+      None,
+      "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!.trim()
+  } else {
+    throw new RuntimeException(s"Python executable [$pythonExec] is unavailable.")
+  }
+
+  protected def withTempPath(f: File => Unit): Unit = {
+    val path = Utils.createTempDir()
+    path.delete()
+    try f(path)
+    finally Utils.deleteRecursively(path)
+  }
+
+  val pandasFunc: Array[Byte] = {
+    var binaryPandasFunc: Array[Byte] = null
+    withTempPath { path =>
+      Process(
+        Seq(
+          pythonExec,
+          "-c",
+          f"""
+            |from sedona.sql.types import GeometryType
+            |from pyspark.serializers import CloudPickleSerializer
+            |from sedona.utils import geometry_serde
+            |from shapely import box
+            |f = open('$path', 'wb')
+            |def w(series):
+            |    def apply_function(geom_bytes):
+            |        geom, offset = geometry_serde.deserialize(geom_bytes)
+            |        bounds = geom.buffer(1).bounds
+            |        boxed = box(*bounds)
+            |        return geometry_serde.serialize(boxed)
+            |    return series.apply(apply_function)
+            |f.write(CloudPickleSerializer().dumps((w, GeometryType())))
+            |""".stripMargin),
+        None,
+        "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+      binaryPandasFunc = Files.readAllBytes(path.toPath)
+    }
+    assert(binaryPandasFunc != null)
+    binaryPandasFunc
+  }
+
+  private val workerEnv = new java.util.HashMap[String, String]()
+  workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
+
+  val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+    name = "geospatial_udf",
+    func = SimplePythonFunction(
+      command = pandasFunc,
+      envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+      pythonIncludes = List.empty[String].asJava,
+      pythonExec = pythonExec,
+      pythonVer = pythonVer,
+      broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+      accumulator = null),
+    dataType = GeometryUDT,
+    pythonEvalType = PythonEvalType.SQL_SCALAR_SEDONA_UDF,
+    udfDeterministic = true)
}
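For reference, the function the Scala helper cloud-pickles above is an ordinary pandas-style scalar UDF. Written directly in Python, the same logic looks like this (sketch of the pickled function outside the Scala heredoc):

    import pandas as pd
    from sedona.utils import geometry_serde
    from shapely import box

    def w(series: pd.Series) -> pd.Series:
        def apply_function(geom_bytes):
            # Deserialize Sedona's binary geometry, buffer it, and return the
            # serialized bounding box of the buffer.
            geom, offset = geometry_serde.deserialize(geom_bytes)
            return geometry_serde.serialize(box(*geom.buffer(1).bounds))

        return series.apply(apply_function)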