This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 1798df23fa [SEDONA-721] Add Sedona vectorized udf for Python (#1859)
1798df23fa is described below
commit 1798df23fa0cbe8460979a41df09e6129a87d8a9
Author: Paweł Tokaj <[email protected]>
AuthorDate: Wed Apr 2 03:35:19 2025 +0200
[SEDONA-721] Add Sedona vectorized udf for Python (#1859)
* SEDONA-721 Add Sedona vectorized udf.
* SEDONA-721 Add documentation
* SEDONA-721 Add documentation
* SEDONA-721 Add documentation
* Update .github/workflows/java.yml
Co-authored-by: Kristin Cowalcijk <[email protected]>
* SEDONA-721 Apply requested changes.
* SEDONA-721 Apply requested changes.
* SEDONA-721 Apply requested changes.
* SEDONA-721 Apply requested changes.
* SEDONA-721 Apply requested changes.
* SEDONA-721 Apply requested changes.
* SEDONA-721 Apply requested changes.
* SEDONA-721 Apply requested changes.
---------
Co-authored-by: Kristin Cowalcijk <[email protected]>
---
.github/workflows/java.yml | 8 +-
docs/tutorial/sql.md | 67 ++++++
python/sedona/sql/functions.py | 144 +++++++++++++
python/tests/utils/test_pandas_arrow_udf.py | 231 +++++++++++++++++++++
.../org/apache/sedona/spark/SedonaContext.scala | 28 ++-
.../org/apache/sedona/sql/UDF/PythonEvalType.scala | 29 +++
.../strategies/SedonaArrowEvalPython.scala | 32 +++
.../spark/sql/udf/ExtractSedonaUDFRule.scala | 168 +++++++++++++++
.../spark/sql/udf/SedonaArrowEvalPython.scala | 32 +++
.../apache/spark/sql/udf/SedonaArrowStrategy.scala | 89 ++++++++
.../org/apache/spark/sql/udf/StrategySuite.scala | 67 ++++++
.../apache/spark/sql/udf/TestScalarPandasUDF.scala | 122 +++++++++++
12 files changed, 1015 insertions(+), 2 deletions(-)
diff --git a/.github/workflows/java.yml b/.github/workflows/java.yml
index f8e4f6b204..921102cb36 100644
--- a/.github/workflows/java.yml
+++ b/.github/workflows/java.yml
@@ -97,7 +97,7 @@ jobs:
java-version: ${{ matrix.jdk }}
- uses: actions/setup-python@v5
with:
- python-version: '3.7'
+ python-version: '3.10'
- name: Cache Maven packages
uses: actions/cache@v3
with:
@@ -110,6 +110,12 @@ jobs:
SKIP_TESTS: ${{ matrix.skipTests }}
run: |
SPARK_COMPAT_VERSION=${SPARK_VERSION:0:3}
+
+ if [ "${SPARK_VERSION}" == "3.5.0" ]; then
+ pip install pyspark==3.5.0 pandas shapely apache-sedona pyarrow
+            export SPARK_HOME=$(python -c "import pyspark; print(pyspark.__path__[0])")
+ fi
+
mvn -q clean install -Dspark=${SPARK_COMPAT_VERSION} -Dscala=${SCALA_VERSION:0:4} -Dspark.version=${SPARK_VERSION} ${SKIP_TESTS}
- run: mkdir staging
- run: cp spark-shaded/target/sedona-*.jar staging
diff --git a/docs/tutorial/sql.md b/docs/tutorial/sql.md
index 821ea37d9b..b835084757 100644
--- a/docs/tutorial/sql.md
+++ b/docs/tutorial/sql.md
@@ -1195,6 +1195,73 @@ Output:
+------------------------------+--------+--------------------------------------------------+-----------------+
```
+## Spatial vectorized UDFs (Python only)
+
+By default, user-defined functions created in Python are not vectorized:
+they are called row by row, which can be slow.
+To speed them up, you can use a vectorized UDF, which is called in batch mode
+using Apache Arrow.
+
+To create a vectorized UDF, use the `sedona_vectorized_udf` decorator.
+Currently only scalar UDFs are supported. Vectorized UDFs can be
+significantly faster than normal UDFs, in some cases around 2x.
+
+!!!note
+    When a geometry is used as an input, annotate the parameter with a shapely
+    `BaseGeometry` subtype such as `Point`, or with a geopandas `GeoSeries`
+    when you use the GEO_SERIES vectorized UDF.
+    That's how Sedona infers the type and knows whether the data should be cast.
+
+The decorator signature looks as follows:
+
+```python
+def sedona_vectorized_udf(return_type: DataType, udf_type: SedonaUDFType = SedonaUDFType.SHAPELY_SCALAR)
+```
+
+where `udf_type` selects the kind of UDF; the currently supported types are:
+
+- SHAPELY_SCALAR
+- GEO_SERIES
+
+The main difference is the input data your function receives.
+Let's analyze the two examples below, which create buffers from
+a given geometry.
+
+### Shapely scalar UDF
+
+```python
+import shapely.geometry.base as b
+from sedona.sql.functions import sedona_vectorized_udf
+from sedona.sql.types import GeometryType
+
+@sedona_vectorized_udf(return_type=GeometryType())
+def vectorized_buffer(geom: b.BaseGeometry) -> b.BaseGeometry:
+ return geom.buffer(0.1)
+```
+
+### GeoSeries UDF
+
+```python
+import geopandas as gpd
+from sedona.sql.functions import sedona_vectorized_udf, SedonaUDFType
+from sedona.sql.types import GeometryType
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=GeometryType())
+def vectorized_geo_series_buffer(series: gpd.GeoSeries) -> gpd.GeoSeries:
+ buffered = series.buffer(0.1)
+
+ return buffered
+```
+
+To call the UDFs, you can use the following code:
+
+```python
+# Shapely scalar UDF
+df.withColumn("buffered", vectorized_buffer(df.geom)).show()
+
+# GeoSeries UDF
+df.withColumn("buffered", vectorized_geo_series_buffer(df.geom)).show()
+```
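+
+As a fuller sketch, here is a minimal end-to-end example reusing
+`vectorized_buffer` from above (the session setup is standard; the
+single-point DataFrame is hypothetical sample data):
+
+```python
+from sedona.spark import SedonaContext
+
+config = SedonaContext.builder().getOrCreate()
+sedona = SedonaContext.create(config)
+
+# Hypothetical sample data; any DataFrame with a geometry column works.
+df = sedona.sql("SELECT ST_GeomFromText('POINT(21 52)') AS geom")
+df.withColumn("buffered", vectorized_buffer(df.geom)).show()
+```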
+
## Save to permanent storage
To save a Spatial DataFrame to some permanent storage such as Hive tables and
HDFS, you can simply convert each geometry in the Geometry type column back to
a plain String and save the plain DataFrame to wherever you want.
diff --git a/python/sedona/sql/functions.py b/python/sedona/sql/functions.py
new file mode 100644
index 0000000000..83648ff53f
--- /dev/null
+++ b/python/sedona/sql/functions.py
@@ -0,0 +1,144 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+import inspect
+from enum import Enum
+
+import pandas as pd
+
+from sedona.sql.types import GeometryType
+from sedona.utils import geometry_serde
+from pyspark.sql.udf import UserDefinedFunction
+from pyspark.sql.types import DataType
+from shapely.geometry.base import BaseGeometry
+
+
+SEDONA_SCALAR_EVAL_TYPE = 5200
+SEDONA_PANDAS_ARROW_NAME = "SedonaPandasArrowUDF"
+
+
+class SedonaUDFType(Enum):
+ SHAPELY_SCALAR = "ShapelyScalar"
+ GEO_SERIES = "GeoSeries"
+
+
+class InvalidSedonaUDFType(Exception):
+ pass
+
+
+sedona_udf_to_eval_type = {
+ SedonaUDFType.SHAPELY_SCALAR: SEDONA_SCALAR_EVAL_TYPE,
+ SedonaUDFType.GEO_SERIES: SEDONA_SCALAR_EVAL_TYPE,
+}
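+
+# Both UDF flavors currently share eval type 5200; the JVM side subtracts
+# 5000 to recover Spark's scalar Pandas UDF eval type (200).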
+
+
+def sedona_vectorized_udf(
+    return_type: DataType, udf_type: SedonaUDFType = SedonaUDFType.SHAPELY_SCALAR
+):
+ import geopandas as gpd
+
+ def apply_fn(fn):
+ function_signature = inspect.signature(fn)
+ serialize_geom = False
+ deserialize_geom = False
+
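+        # Whether to (de)serialize geometries is inferred from the UDF's type
+        # annotations: geometry inputs arrive as Sedona-serialized bytes, and
+        # geometry outputs must be serialized back before returning.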
+ if isinstance(return_type, GeometryType):
+ serialize_geom = True
+
+ if issubclass(function_signature.return_annotation, BaseGeometry):
+ serialize_geom = True
+
+ if issubclass(function_signature.return_annotation, gpd.GeoSeries):
+ serialize_geom = True
+
+ for param in function_signature.parameters.values():
+ if issubclass(param.annotation, BaseGeometry):
+ deserialize_geom = True
+
+ if issubclass(param.annotation, gpd.GeoSeries):
+ deserialize_geom = True
+
+ if udf_type == SedonaUDFType.SHAPELY_SCALAR:
+ return _apply_shapely_series_udf(
+ fn, return_type, serialize_geom, deserialize_geom
+ )
+
+ if udf_type == SedonaUDFType.GEO_SERIES:
+ return _apply_geo_series_udf(
+ fn, return_type, serialize_geom, deserialize_geom
+ )
+
+ raise InvalidSedonaUDFType(f"Invalid UDF type: {udf_type}")
+
+ return apply_fn
+
+
+def _apply_shapely_series_udf(
+ fn, return_type: DataType, serialize_geom: bool, deserialize_geom: bool
+):
+ def apply(series: pd.Series) -> pd.Series:
+ applied = series.apply(
+ lambda x: (
+                fn(geometry_serde.deserialize(x)[0]) if deserialize_geom else fn(x)
+ )
+ )
+
+ return applied.apply(
+ lambda x: geometry_serde.serialize(x) if serialize_geom else x
+ )
+
+ udf = UserDefinedFunction(
+        apply, return_type, SEDONA_PANDAS_ARROW_NAME, evalType=SEDONA_SCALAR_EVAL_TYPE
+ )
+
+ return udf
+
+
+def _apply_geo_series_udf(
+ fn, return_type: DataType, serialize_geom: bool, deserialize_geom: bool
+):
+ import geopandas as gpd
+
+ def apply(series: pd.Series) -> pd.Series:
+ series_data = series
+ if deserialize_geom:
+ series_data = gpd.GeoSeries(
+ series.apply(lambda x: geometry_serde.deserialize(x)[0])
+ )
+
+ return fn(series_data).apply(
+ lambda x: geometry_serde.serialize(x) if serialize_geom else x
+ )
+
+ return UserDefinedFunction(
+        apply, return_type, SEDONA_PANDAS_ARROW_NAME, evalType=SEDONA_SCALAR_EVAL_TYPE
+ )
+
+
+def deserialize_geometry_if_geom(data):
+ if isinstance(data, BaseGeometry):
+ return geometry_serde.deserialize(data)[0]
+
+ return data
+
+
+def serialize_to_geometry_if_geom(data, return_type: DataType):
+ if isinstance(return_type, GeometryType):
+ return geometry_serde.serialize(data)
+
+ return data
diff --git a/python/tests/utils/test_pandas_arrow_udf.py b/python/tests/utils/test_pandas_arrow_udf.py
new file mode 100644
index 0000000000..c0d723d214
--- /dev/null
+++ b/python/tests/utils/test_pandas_arrow_udf.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+
+from sedona.sql.types import GeometryType
+from sedona.sql.functions import sedona_vectorized_udf, SedonaUDFType
+from tests import chicago_crimes_input_location
+from tests.test_base import TestBase
+import pyspark.sql.functions as f
+import shapely.geometry.base as b
+import geopandas as gpd
+import pytest
+import pyspark
+import pandas as pd
+from pyspark.sql.functions import pandas_udf
+from pyspark.sql.types import IntegerType, FloatType
+from shapely.geometry import Point
+from shapely.wkt import loads
+
+
+def non_vectorized_buffer_udf(geom: b.BaseGeometry) -> b.BaseGeometry:
+ return geom.buffer(0.1)
+
+
+@sedona_vectorized_udf(return_type=GeometryType())
+def vectorized_buffer_udf(geom: b.BaseGeometry) -> b.BaseGeometry:
+ return geom.buffer(0.1)
+
+
+@sedona_vectorized_udf(return_type=FloatType())
+def vectorized_geom_to_numeric_udf(geom: b.BaseGeometry) -> float:
+ return geom.area
+
+
+@sedona_vectorized_udf(return_type=FloatType())
+def vectorized_geom_to_numeric_udf_child_geom(geom: Point) -> float:
+ return geom.x
+
+
+@sedona_vectorized_udf(return_type=GeometryType())
+def vectorized_numeric_to_geom(x: float) -> b.BaseGeometry:
+ return Point(x, x)
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=FloatType())
+def vectorized_series_to_numeric_udf(series: gpd.GeoSeries) -> pd.Series:
+    x_values = series.x
+
+    return x_values
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=GeometryType())
+def vectorized_series_string_to_geom(x: pd.Series) -> b.BaseGeometry:
+ return x.apply(lambda x: loads(str(x)))
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=GeometryType())
+def vectorized_series_string_to_geom_2(x: pd.Series):
+ return x.apply(lambda x: loads(str(x)))
+
+
+@sedona_vectorized_udf(udf_type=SedonaUDFType.GEO_SERIES, return_type=GeometryType())
+def vectorized_series_buffer_udf(series: gpd.GeoSeries) -> gpd.GeoSeries:
+ buffered = series.buffer(0.1)
+
+ return buffered
+
+
+@pandas_udf(IntegerType())
+def squared_udf(s: pd.Series) -> pd.Series:
+ return s**2 # Perform vectorized operation
+
+
+buffer_distanced_udf = f.udf(non_vectorized_buffer_udf, GeometryType())
+
+
+class TestSedonaArrowUDF(TestBase):
+
+ def get_area(self, df, udf_fn):
+ return (
+ df.select(udf_fn(f.col("geom")).alias("buffer"))
+ .selectExpr("SUM(ST_Area(buffer))")
+ .collect()[0][0]
+ )
+
+ @pytest.mark.skipif(
+ pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+ )
+ def test_pandas_arrow_udf(self):
+ df = (
+ self.spark.read.option("header", "true")
+ .format("csv")
+ .load(chicago_crimes_input_location)
+ .selectExpr("ST_Point(x, y) AS geom")
+ )
+
+ area1 = self.get_area(df, vectorized_buffer_udf)
+ assert area1 > 478
+
+ @pytest.mark.skipif(
+ pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+ )
+ def test_pandas_udf_shapely_geometry_and_numeric(self):
+ df = (
+ self.spark.read.option("header", "true")
+ .format("csv")
+ .load(chicago_crimes_input_location)
+ .selectExpr("ST_Point(x, y) AS geom", "x")
+ .select(
+ vectorized_geom_to_numeric_udf(f.col("geom")).alias("area"),
+ vectorized_geom_to_numeric_udf_child_geom(f.col("geom")).alias(
+ "x_coordinate"
+ ),
+ vectorized_numeric_to_geom(f.col("x").cast("float")).alias(
+ "geom_second"
+ ),
+ )
+ )
+
+ assert df.select(f.sum("area")).collect()[0][0] == 0.0
+        assert -1339276 > df.select(f.sum("x_coordinate")).collect()[0][0] > -1339277
+ assert (
+ -1339276
+ > df.selectExpr("ST_X(geom_second) AS x_coordinate")
+ .select(f.sum("x_coordinate"))
+ .collect()[0][0]
+ > -1339277
+ )
+
+ @pytest.mark.skipif(
+ pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+ )
+ def test_pandas_udf_geoseries_geometry_and_numeric(self):
+ df = (
+ self.spark.read.option("header", "true")
+ .format("csv")
+ .load(chicago_crimes_input_location)
+ .selectExpr(
+ "ST_Point(x, y) AS geom",
+ "CONCAT('POINT(', x, ' ', y, ')') AS wkt",
+ )
+ .select(
+                vectorized_series_to_numeric_udf(f.col("geom")).alias("x_coordinate"),
+                vectorized_series_string_to_geom(f.col("wkt")).alias("geom"),
+                vectorized_series_string_to_geom_2(f.col("wkt")).alias("geom_2"),
+ )
+ )
+
+        assert -1339276 > df.select(f.sum("x_coordinate")).collect()[0][0] > -1339277
+ assert (
+ -1339276
+ > df.selectExpr("ST_X(geom) AS x_coordinate")
+ .select(f.sum("x_coordinate"))
+ .collect()[0][0]
+ > -1339277
+ )
+ assert (
+ -1339276
+ > df.selectExpr("ST_X(geom_2) AS x_coordinate")
+ .select(f.sum("x_coordinate"))
+ .collect()[0][0]
+ > -1339277
+ )
+
+ @pytest.mark.skipif(
+ pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+ )
+ def test_pandas_udf_numeric_to_geometry(self):
+ df = (
+ self.spark.read.option("header", "true")
+ .format("csv")
+ .load(chicago_crimes_input_location)
+ .selectExpr("ST_Point(y, x) AS geom")
+ )
+
+ area1 = self.get_area(df, vectorized_buffer_udf)
+ assert area1 > 478
+
+ @pytest.mark.skipif(
+ pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+ )
+ def test_pandas_udf_numeric_and_numeric_to_geometry(self):
+ df = (
+ self.spark.read.option("header", "true")
+ .format("csv")
+ .load(chicago_crimes_input_location)
+ .selectExpr("ST_Point(y, x) AS geom")
+ )
+
+ area1 = self.get_area(df, vectorized_buffer_udf)
+ assert area1 > 478
+
+ @pytest.mark.skipif(
+ pyspark.__version__ < "3.5", reason="requires Spark 3.5 or higher"
+ )
+ def test_geo_series_udf(self):
+ df = (
+ self.spark.read.option("header", "true")
+ .format("csv")
+ .load(chicago_crimes_input_location)
+ .selectExpr("ST_Point(y, x) AS geom")
+ )
+
+ area = self.get_area(df, vectorized_series_buffer_udf)
+
+ assert area > 478
+
+ def test_pandas_arrow_udf_compatibility(self):
+ df = (
+ self.spark.read.option("header", "true")
+ .format("csv")
+ .load(chicago_crimes_input_location)
+ .selectExpr("CAST(x AS INT) AS x")
+ )
+
+ sum_value = df.select(f.sum(squared_udf(f.col("x")))).collect()[0][0]
+ assert sum_value == 115578630
diff --git a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
index 7cfb8670be..d38ad5e1b6 100644
--- a/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
+++ b/spark/common/src/main/scala/org/apache/sedona/spark/SedonaContext.scala
@@ -24,10 +24,12 @@ import org.apache.sedona.sql.RasterRegistrator
import org.apache.sedona.sql.UDF.Catalog
import org.apache.sedona.sql.UDT.UdtRegistrator
import org.apache.spark.serializer.KryoSerializer
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.catalyst.rules.Rule
import org.apache.spark.sql.sedona_sql.optimization._
import org.apache.spark.sql.sedona_sql.strategy.join.JoinQueryDetector
import org.apache.spark.sql.sedona_sql.strategy.physical.function.EvalPhysicalFunctionStrategy
-import org.apache.spark.sql.{SQLContext, SparkSession}
+import org.apache.spark.sql.{SQLContext, SparkSession, Strategy}
import scala.annotation.StaticAnnotation
import scala.util.Try
@@ -50,6 +52,7 @@ object SedonaContext {
/**
* This is the entry point of the entire Sedona system
+ *
* @param sparkSession
* @return
*/
@@ -64,6 +67,28 @@ object SedonaContext {
sparkSession.experimental.extraStrategies ++= Seq(new JoinQueryDetector(sparkSession))
}
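+    // The Arrow strategy and the UDF-extraction rule live in a version-specific
+    // module (currently spark-3.5), so they are loaded reflectively and
+    // registered only when both are present on the classpath.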
+ val sedonaArrowStrategy = Try(
+ Class
+ .forName("org.apache.spark.sql.udf.SedonaArrowStrategy")
+ .getDeclaredConstructor()
+ .newInstance()
+ .asInstanceOf[Strategy])
+
+ val extractSedonaUDFRule =
+ Try(
+ Class
+ .forName("org.apache.spark.sql.udf.ExtractSedonaUDFRule")
+ .getDeclaredConstructor()
+ .newInstance()
+ .asInstanceOf[Rule[LogicalPlan]])
+
+ if (sedonaArrowStrategy.isSuccess && extractSedonaUDFRule.isSuccess) {
+ sparkSession.experimental.extraStrategies =
+ sparkSession.experimental.extraStrategies :+ sedonaArrowStrategy.get
+ sparkSession.experimental.extraOptimizations =
+        sparkSession.experimental.extraOptimizations :+ extractSedonaUDFRule.get
+ }
+
customOptimizationsWithSession(sparkSession).foreach { opt =>
if (!sparkSession.experimental.extraOptimizations.exists {
case _: opt.type => true
@@ -95,6 +120,7 @@ object SedonaContext {
* This method adds the basic Sedona configurations to the SparkSession Usually the user does
* not need to call this method directly This is only needed when the user needs to manually
* configure Sedona
+ *
* @return
*/
def builder(): SparkSession.Builder = {
diff --git a/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala
new file mode 100644
index 0000000000..aece26267d
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/sedona/sql/UDF/PythonEvalType.scala
@@ -0,0 +1,29 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.sedona.sql.UDF
+
+// Sedona UDF eval types are offset by the constant 5000; 200 is Apache Spark's scalar Pandas UDF eval type
+object PythonEvalType {
+ val SQL_SCALAR_SEDONA_UDF = 5200
+ val SEDONA_UDF_TYPE_CONSTANT = 5000
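+  // For example: SQL_SCALAR_SEDONA_UDF - SEDONA_UDF_TYPE_CONSTANT = 5200 - 5000 = 200,
+  // which is Spark's SQL_SCALAR_PANDAS_UDF eval type.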
+
+ def toString(pythonEvalType: Int): String = pythonEvalType match {
+ case SQL_SCALAR_SEDONA_UDF => "SQL_SCALAR_GEO_UDF"
+ }
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala
new file mode 100644
index 0000000000..78e00871ad
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/strategies/SedonaArrowEvalPython.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.strategies
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan}
+
+case class SedonaArrowEvalPython(
+ udfs: Seq[PythonUDF],
+ resultAttrs: Seq[Attribute],
+ child: LogicalPlan,
+ evalType: Int)
+ extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython =
+ copy(child = newChild)
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala
new file mode 100644
index 0000000000..03e10a1602
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/ExtractSedonaUDFRule.scala
@@ -0,0 +1,168 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.sedona.sql.UDF.PythonEvalType
+import org.apache.spark.sql.catalyst.expressions.{AttributeReference, Expression, ExpressionSet, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{LogicalPlan, Project, Subquery}
+import org.apache.spark.sql.catalyst.rules.Rule
+import org.apache.spark.sql.catalyst.trees.TreePattern.PYTHON_UDF
+
+import scala.collection.mutable
+
+// This rule extracts scalar Python UDFs; currently Apache Spark has an
+// assert on types which blocks using vectorized UDFs with the geometry type
+class ExtractSedonaUDFRule extends Rule[LogicalPlan] {
+
+ private def hasScalarPythonUDF(e: Expression): Boolean = {
+ e.exists(PythonUDF.isScalarPythonUDF)
+ }
+
+ @scala.annotation.tailrec
+ private def canEvaluateInPython(e: PythonUDF): Boolean = {
+ e.children match {
+      case Seq(u: PythonUDF) => e.evalType == u.evalType && canEvaluateInPython(u)
+ case children => !children.exists(hasScalarPythonUDF)
+ }
+ }
+
+ def isScalarPythonUDF(e: Expression): Boolean = {
+ e.isInstanceOf[PythonUDF] && e
+ .asInstanceOf[PythonUDF]
+ .evalType == PythonEvalType.SQL_SCALAR_SEDONA_UDF
+ }
+
+ private def collectEvaluableUDFsFromExpressions(
+ expressions: Seq[Expression]): Seq[PythonUDF] = {
+
+ var firstVisitedScalarUDFEvalType: Option[Int] = None
+
+ def canChainUDF(evalType: Int): Boolean = {
+ evalType == firstVisitedScalarUDFEvalType.get
+ }
+
+ def collectEvaluableUDFs(expr: Expression): Seq[PythonUDF] = expr match {
+ case udf: PythonUDF
+ if isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+ && firstVisitedScalarUDFEvalType.isEmpty =>
+ firstVisitedScalarUDFEvalType = Some(udf.evalType)
+ Seq(udf)
+ case udf: PythonUDF
+ if isScalarPythonUDF(udf) && canEvaluateInPython(udf)
+ && canChainUDF(udf.evalType) =>
+ Seq(udf)
+ case e => e.children.flatMap(collectEvaluableUDFs)
+ }
+
+ expressions.flatMap(collectEvaluableUDFs)
+ }
+
+ def apply(plan: LogicalPlan): LogicalPlan = plan match {
+ case s: Subquery if s.correlated => plan
+
+ case _ =>
+ plan.transformUpWithPruning(_.containsPattern(PYTHON_UDF)) {
+ case p: SedonaArrowEvalPython => p
+
+ case plan: LogicalPlan => extract(plan)
+ }
+ }
+
+ private def canonicalizeDeterministic(u: PythonUDF) = {
+ if (u.deterministic) {
+ u.canonicalized.asInstanceOf[PythonUDF]
+ } else {
+ u
+ }
+ }
+
+ private def extract(plan: LogicalPlan): LogicalPlan = {
+    val udfs = ExpressionSet(collectEvaluableUDFsFromExpressions(plan.expressions))
+ .filter(udf => udf.references.subsetOf(plan.inputSet))
+ .toSeq
+ .asInstanceOf[Seq[PythonUDF]]
+
+ udfs match {
+ case Seq() => plan
+ case _ => resolveUDFs(plan, udfs)
+ }
+ }
+
+ def resolveUDFs(plan: LogicalPlan, udfs: Seq[PythonUDF]): LogicalPlan = {
+ val attributeMap = mutable.HashMap[PythonUDF, Expression]()
+
+ val newChildren = adjustAttributeMap(plan, udfs, attributeMap)
+
+    udfs.map(canonicalizeDeterministic).filterNot(attributeMap.contains).foreach { udf =>
+      throw new IllegalStateException(
+        s"Invalid PythonUDF $udf, requires attributes from more than one child.")
+ }
+
+    val rewritten = plan.withNewChildren(newChildren).transformExpressions { case p: PythonUDF =>
+ attributeMap.getOrElse(canonicalizeDeterministic(p), p)
+ }
+
+ val newPlan = extract(rewritten)
+ if (newPlan.output != plan.output) {
+ Project(plan.output, newPlan)
+ } else {
+ newPlan
+ }
+ }
+
+ def adjustAttributeMap(
+ plan: LogicalPlan,
+ udfs: Seq[PythonUDF],
+      attributeMap: mutable.HashMap[PythonUDF, Expression]): Seq[LogicalPlan] = {
+ plan.children.map { child =>
+ val validUdfs = udfs.filter { udf =>
+ udf.references.subsetOf(child.outputSet)
+ }
+
+ if (validUdfs.nonEmpty) {
+ require(
+ validUdfs.forall(isScalarPythonUDF),
+ "Can only extract scalar vectorized udf or sql batch udf")
+
+ val resultAttrs = validUdfs.zipWithIndex.map { case (u, i) =>
+ AttributeReference(s"pythonUDF$i", u.dataType)()
+ }
+
+ val evalTypes = validUdfs.map(_.evalType).toSet
+ if (evalTypes.size != 1) {
+ throw new IllegalStateException(
+            "Expected udfs have the same evalType but got different evalTypes: " +
+ evalTypes.mkString(","))
+ }
+ val evalType = evalTypes.head
+ val evaluation = evalType match {
+ case PythonEvalType.SQL_SCALAR_SEDONA_UDF =>
+ SedonaArrowEvalPython(validUdfs, resultAttrs, child, evalType)
+ case _ =>
+ throw new IllegalStateException("Unexpected UDF evalType")
+ }
+
+        attributeMap ++= validUdfs.map(canonicalizeDeterministic).zip(resultAttrs)
+ evaluation
+ } else {
+ child
+ }
+ }
+ }
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala
new file mode 100644
index 0000000000..7600ece507
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowEvalPython.scala
@@ -0,0 +1,32 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.{BaseEvalPython, LogicalPlan}
+
+case class SedonaArrowEvalPython(
+ udfs: Seq[PythonUDF],
+ resultAttrs: Seq[Attribute],
+ child: LogicalPlan,
+ evalType: Int)
+ extends BaseEvalPython {
+  override protected def withNewChildInternal(newChild: LogicalPlan): SedonaArrowEvalPython =
+ copy(child = newChild)
+}
diff --git a/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
new file mode 100644
index 0000000000..a403fa6b9e
--- /dev/null
+++ b/spark/spark-3.5/src/main/scala/org/apache/spark/sql/udf/SedonaArrowStrategy.scala
@@ -0,0 +1,89 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.sedona.sql.UDF.PythonEvalType
+import org.apache.spark.api.python.ChainedPythonFunctions
+import org.apache.spark.{JobArtifactSet, TaskContext}
+import org.apache.spark.sql.Strategy
+import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.sql.catalyst.expressions.{Attribute, PythonUDF}
+import org.apache.spark.sql.catalyst.plans.logical.LogicalPlan
+import org.apache.spark.sql.execution.SparkPlan
+import org.apache.spark.sql.execution.python.{ArrowPythonRunner, BatchIterator, EvalPythonExec, PythonSQLMetrics}
+import org.apache.spark.sql.types.StructType
+
+import scala.collection.JavaConverters.asScalaIteratorConverter
+
+// We use a custom Strategy to avoid Apache Spark's assert on types; we
+// can consider extending this to support other engines working with
+// Arrow data
+class SedonaArrowStrategy extends Strategy {
+ override def apply(plan: LogicalPlan): Seq[SparkPlan] = plan match {
+ case SedonaArrowEvalPython(udfs, output, child, evalType) =>
+      SedonaArrowEvalPythonExec(udfs, output, planLater(child), evalType) :: Nil
+ case _ => Nil
+ }
+}
+
+// This is a modification of Apache Spark's ArrowEvalPythonExec: we removed the type check to allow
+// geometry types here. It's an initial version enabling vectorized UDFs for Sedona geometry types.
+// We can consider extending this to support other engines working with Arrow data.
+case class SedonaArrowEvalPythonExec(
+ udfs: Seq[PythonUDF],
+ resultAttrs: Seq[Attribute],
+ child: SparkPlan,
+ evalType: Int)
+ extends EvalPythonExec
+ with PythonSQLMetrics {
+
+ private val batchSize = conf.arrowMaxRecordsPerBatch
+ private val sessionLocalTimeZone = conf.sessionLocalTimeZone
+ private val largeVarTypes = conf.arrowUseLargeVarTypes
+ private val pythonRunnerConf = ArrowPythonRunner.getPythonRunnerConfMap(conf)
+  private[this] val jobArtifactUUID = JobArtifactSet.getCurrentJobArtifactState.map(_.uuid)
+
+ protected override def evaluate(
+ funcs: Seq[ChainedPythonFunctions],
+ argOffsets: Array[Array[Int]],
+ iter: Iterator[InternalRow],
+ schema: StructType,
+ context: TaskContext): Iterator[InternalRow] = {
+
+    val batchIter = if (batchSize > 0) new BatchIterator(iter, batchSize) else Iterator(iter)
+
+ val columnarBatchIter = new ArrowPythonRunner(
+ funcs,
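+      // Recover Spark's eval type from the Sedona one: 5200 - 5000 = 200 (SQL_SCALAR_PANDAS_UDF)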
+ evalType - PythonEvalType.SEDONA_UDF_TYPE_CONSTANT,
+ argOffsets,
+ schema,
+ sessionLocalTimeZone,
+ largeVarTypes,
+ pythonRunnerConf,
+ pythonMetrics,
+ jobArtifactUUID).compute(batchIter, context.partitionId(), context)
+
+ columnarBatchIter.flatMap { batch =>
+ batch.rowIterator.asScala
+ }
+ }
+
+ override protected def withNewChildInternal(newChild: SparkPlan): SparkPlan =
+ copy(child = newChild)
+}
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
new file mode 100644
index 0000000000..adbb97819f
--- /dev/null
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/StrategySuite.scala
@@ -0,0 +1,67 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.sedona.spark.SedonaContext
+import org.apache.spark.sql.SparkSession
+import org.apache.spark.sql.functions.col
+import org.apache.spark.sql.udf.ScalarUDF.geoPandasScalaFunction
+import org.locationtech.jts.io.WKTReader
+import org.scalatest.funsuite.AnyFunSuite
+import org.scalatest.matchers.should.Matchers
+
+class StrategySuite extends AnyFunSuite with Matchers {
+ val wktReader = new WKTReader()
+
+ val spark: SparkSession = {
+ val builder = SedonaContext
+ .builder()
+ .master("local[*]")
+ .appName("sedonasqlScalaTest")
+
+ val spark = SedonaContext.create(builder.getOrCreate())
+
+ spark.sparkContext.setLogLevel("ALL")
+ spark
+ }
+
+ import spark.implicits._
+
+ test("sedona geospatial UDF") {
+ val df = Seq(
+ (1, "value", wktReader.read("POINT(21 52)")),
+ (2, "value1", wktReader.read("POINT(20 50)")),
+ (3, "value2", wktReader.read("POINT(20 49)")),
+ (4, "value3", wktReader.read("POINT(20 48)")),
+ (5, "value4", wktReader.read("POINT(20 47)")))
+ .toDF("id", "value", "geom")
+ .withColumn("geom_buffer", geoPandasScalaFunction(col("geom")))
+
+ df.count shouldEqual 5
+
+ df.selectExpr("ST_AsText(ST_ReducePrecision(geom_buffer, 2))")
+ .as[String]
+ .collect() should contain theSameElementsAs Seq(
+ "POLYGON ((20 51, 20 53, 22 53, 22 51, 20 51))",
+ "POLYGON ((19 49, 19 51, 21 51, 21 49, 19 49))",
+ "POLYGON ((19 48, 19 50, 21 50, 21 48, 19 48))",
+ "POLYGON ((19 47, 19 49, 21 49, 21 47, 19 47))",
+ "POLYGON ((19 46, 19 48, 21 48, 21 46, 19 46))")
+ }
+}
diff --git a/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
new file mode 100644
index 0000000000..c0a2d8f260
--- /dev/null
+++ b/spark/spark-3.5/src/test/scala/org/apache/spark/sql/udf/TestScalarPandasUDF.scala
@@ -0,0 +1,122 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.udf
+
+import org.apache.sedona.sql.UDF
+import org.apache.spark.TestUtils
+import org.apache.spark.api.python._
+import org.apache.spark.broadcast.Broadcast
+import org.apache.spark.sql.execution.python.UserDefinedPythonFunction
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
+import org.apache.spark.util.Utils
+
+import java.io.File
+import java.nio.file.{Files, Paths}
+import scala.sys.process.Process
+import scala.jdk.CollectionConverters._
+
+object ScalarUDF {
+
+ val pythonExec: String = {
+ val pythonExec =
+    sys.env.getOrElse("PYSPARK_DRIVER_PYTHON", sys.env.getOrElse("PYSPARK_PYTHON", "python3"))
+ if (TestUtils.testCommandAvailable(pythonExec)) {
+ pythonExec
+ } else {
+ "python"
+ }
+ }
+
+ private[spark] lazy val pythonPath = sys.env.getOrElse("PYTHONPATH", "")
+ protected lazy val sparkHome: String = {
+ sys.props.getOrElse("spark.test.home", sys.env("SPARK_HOME"))
+ }
+
+ private lazy val py4jPath =
+    Paths.get(sparkHome, "python", "lib", PythonUtils.PY4J_ZIP_NAME).toAbsolutePath
+ private[spark] lazy val pysparkPythonPath = s"$py4jPath"
+
+  private lazy val isPythonAvailable: Boolean = TestUtils.testCommandAvailable(pythonExec)
+
+ lazy val pythonVer: String = if (isPythonAvailable) {
+ Process(
+      Seq(pythonExec, "-c", "import sys; print('%d.%d' % sys.version_info[:2])"),
+ None,
+ "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!.trim()
+ } else {
+    throw new RuntimeException(s"Python executable [$pythonExec] is unavailable.")
+ }
+
+ protected def withTempPath(f: File => Unit): Unit = {
+ val path = Utils.createTempDir()
+ path.delete()
+ try f(path)
+ finally Utils.deleteRecursively(path)
+ }
+
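+  // Builds the pickled UDF payload the way PySpark does internally: a Python
+  // subprocess cloudpickles the (function, return type) pair into a temp file,
+  // whose bytes become the command for SimplePythonFunction below.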
+ val pandasFunc: Array[Byte] = {
+ var binaryPandasFunc: Array[Byte] = null
+ withTempPath { path =>
+ println(path)
+ Process(
+ Seq(
+ pythonExec,
+ "-c",
+ f"""
+ |from pyspark.sql.types import IntegerType
+ |from shapely.geometry import Point
+ |from sedona.sql.types import GeometryType
+ |from pyspark.serializers import CloudPickleSerializer
+ |from sedona.utils import geometry_serde
+ |from shapely import box
+ |f = open('$path', 'wb');
+ |def w(x):
+ | def apply_function(w):
+ | geom, offset = geometry_serde.deserialize(w)
+ | bounds = geom.buffer(1).bounds
+ | x = box(*bounds)
+ | return geometry_serde.serialize(x)
+ | return x.apply(apply_function)
+ |f.write(CloudPickleSerializer().dumps((w, GeometryType())))
+ |""".stripMargin),
+ None,
+ "PYTHONPATH" -> s"$pysparkPythonPath:$pythonPath").!!
+ binaryPandasFunc = Files.readAllBytes(path.toPath)
+ }
+ assert(binaryPandasFunc != null)
+ binaryPandasFunc
+ }
+
+ private val workerEnv = new java.util.HashMap[String, String]()
+ workerEnv.put("PYTHONPATH", s"$pysparkPythonPath:$pythonPath")
+
+  val geoPandasScalaFunction: UserDefinedPythonFunction = UserDefinedPythonFunction(
+ name = "geospatial_udf",
+ func = SimplePythonFunction(
+ command = pandasFunc,
+ envVars = workerEnv.clone().asInstanceOf[java.util.Map[String, String]],
+ pythonIncludes = List.empty[String].asJava,
+ pythonExec = pythonExec,
+ pythonVer = pythonVer,
+ broadcastVars = List.empty[Broadcast[PythonBroadcast]].asJava,
+ accumulator = null),
+ dataType = GeometryUDT,
+ pythonEvalType = UDF.PythonEvalType.SQL_SCALAR_SEDONA_UDF,
+ udfDeterministic = true)
+}