This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new ebd6f67c50 [SEDONA-714] Add geopandas to spark arrow conversion. (#1825)
ebd6f67c50 is described below
commit ebd6f67c5029554055db6b81ffba5d3ecaad3b62
Author: Paweł Tokaj <[email protected]>
AuthorDate: Wed Feb 26 01:31:59 2025 +0100
[SEDONA-714] Add geopandas to spark arrow conversion. (#1825)
* SEDONA-714 Add geopandas to spark arrow conversion.
* SEDONA-714 Add geopandas to spark arrow conversion.
* SEDONA-714 Add geopandas to spark arrow conversion.
* SEDONA-714 Add geopandas to spark arrow conversion.
* SEDONA-714 Add geopandas to spark arrow conversion.
* Update python/sedona/utils/geoarrow.py
Co-authored-by: Dewey Dunnington <[email protected]>
* SEDONA-714 Add geopandas to spark arrow conversion.
* SEDONA-714 Add docs.
* SEDONA-714 Add docs.
---------
Co-authored-by: Dewey Dunnington <[email protected]>
---
Makefile | 4 +
docker/docs/Dockerfile | 8 ++
docs/tutorial/geopandas-shapely.md | 19 +++
python/sedona/utils/geoarrow.py | 139 ++++++++++++++++++++-
.../test_arrow_conversion_geopandas_to_sedona.py | 94 ++++++++++++++
5 files changed, 263 insertions(+), 1 deletion(-)
diff --git a/Makefile b/Makefile
index 274707936a..4034de4d7f 100644
--- a/Makefile
+++ b/Makefile
@@ -65,3 +65,7 @@ clean:
rm -rf __pycache__
rm -rf .mypy_cache
rm -rf .pytest_cache
+
+run-docs:
+ docker build -f docker/docs/Dockerfile -t mkdocs-sedona .
+ docker run --rm -it -p 8000:8000 -v ${PWD}:/docs mkdocs-sedona
diff --git a/docker/docs/Dockerfile b/docker/docs/Dockerfile
new file mode 100644
index 0000000000..a97ae8d226
--- /dev/null
+++ b/docker/docs/Dockerfile
@@ -0,0 +1,8 @@
+FROM squidfunk/mkdocs-material:9.6
+
+RUN apk update
+RUN apk add gcc musl-dev linux-headers
+RUN pip install mkdocs-macros-plugin \
+ mkdocs-git-revision-date-localized-plugin \
+ mkdocs-jupyter \
+ mike
diff --git a/docs/tutorial/geopandas-shapely.md b/docs/tutorial/geopandas-shapely.md
index bc286229b3..c3fcbf2dc3 100644
--- a/docs/tutorial/geopandas-shapely.md
+++ b/docs/tutorial/geopandas-shapely.md
@@ -67,6 +67,25 @@ This query will show the following outputs:
```
+To leverage Arrow optimization and speed up the conversion, you can use the `create_spatial_dataframe` function,
+which takes a SparkSession and a GeoDataFrame as parameters and returns a Sedona DataFrame.
+
+```python
+def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame
+```
+
+- spark: SparkSession
+- gdf: gpd.GeoDataFrame
+- return: DataFrame
+
+Example:
+
+```python
+from sedona.utils.geoarrow import create_spatial_dataframe
+
+create_spatial_dataframe(spark, gdf)
+```
+
### From Sedona DataFrame to GeoPandas
Reading data with Spark and converting to GeoPandas
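For readers skimming this patch, here is a minimal end-to-end sketch of the conversion path documented in the hunk above. It assumes an existing Sedona-enabled SparkSession named `spark`; the column names and coordinates are illustrative only and mirror the shape used in the new tests.

```python
import geopandas as gpd

from sedona.utils.geoarrow import create_spatial_dataframe

# Hypothetical input data; any GeoDataFrame with at least one geometry column works.
gdf = gpd.GeoDataFrame(
    {
        "name": ["Sedona", "Apache"],
        "geometry": gpd.points_from_xy([0, 1], [0, 1]),
    }
)

# Arrow-backed conversion added by this commit; the returned Spark DataFrame
# carries the geometry column as Sedona's GeometryType.
df = create_spatial_dataframe(spark, gdf)  # `spark` is an existing SparkSession
df.show()
```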
diff --git a/python/sedona/utils/geoarrow.py b/python/sedona/utils/geoarrow.py
index b8ade8528b..353b4ff7f8 100644
--- a/python/sedona/utils/geoarrow.py
+++ b/python/sedona/utils/geoarrow.py
@@ -14,13 +14,24 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import itertools
+from typing import List, Callable
# We may be able to achieve streaming rather than complete materialization by using
# with the ArrowStreamSerializer (instead of the ArrowCollectSerializer)
-from sedona.sql.types import GeometryType
from sedona.sql.st_functions import ST_AsEWKB
+from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame
+from pyspark.sql.types import StructType, StructField, DataType, ArrayType, MapType
+
+from sedona.sql.types import GeometryType
+import geopandas as gpd
+from pyspark.sql.pandas.types import (
+ from_arrow_type,
+)
+from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
def dataframe_to_arrow(df, crs=None):
@@ -186,3 +197,129 @@ def unique_srid_from_ewkb(obj):
import pyproj
return pyproj.CRS(f"EPSG:{epsg_code}")
+
+
+def _dedup_names(names: List[str]) -> List[str]:
+ if len(set(names)) == len(names):
+ return names
+ else:
+
+ def _gen_dedup(_name: str) -> Callable[[], str]:
+ _i = itertools.count()
+ return lambda: f"{_name}_{next(_i)}"
+
+ def _gen_identity(_name: str) -> Callable[[], str]:
+ return lambda: _name
+
+ gen_new_name = {
+            name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
+ for name, group in itertools.groupby(sorted(names))
+ }
+ return [gen_new_name[name]() for name in names]
+
+
+# Backport from Spark 4.0
+# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/types.py#L1385
+def _deduplicate_field_names(dt: DataType) -> DataType:
+ if isinstance(dt, StructType):
+ dedup_field_names = _dedup_names(dt.names)
+
+ return StructType(
+ [
+ StructField(
+ dedup_field_names[i],
+ _deduplicate_field_names(field.dataType),
+ nullable=field.nullable,
+ )
+ for i, field in enumerate(dt.fields)
+ ]
+ )
+ elif isinstance(dt, ArrayType):
+ return ArrayType(
+            _deduplicate_field_names(dt.elementType), containsNull=dt.containsNull
+ )
+ elif isinstance(dt, MapType):
+ return MapType(
+ _deduplicate_field_names(dt.keyType),
+ _deduplicate_field_names(dt.valueType),
+ valueContainsNull=dt.valueContainsNull,
+ )
+ else:
+ return dt
+
+
+def infer_schema(gdf: gpd.GeoDataFrame) -> StructType:
+ import pyarrow as pa
+
+ fields = gdf.dtypes.reset_index().values.tolist()
+ geom_fields = []
+ index = 0
+ for name, dtype in fields:
+ if dtype == "geometry":
+ geom_fields.append((index, name))
+ continue
+
+ index += 1
+
+ if not geom_fields:
+ raise ValueError("No geometry field found in the GeoDataFrame")
+
+ pa_schema = pa.Schema.from_pandas(
+ gdf.drop([name for _, name in geom_fields], axis=1)
+ )
+
+ spark_schema = []
+
+ for field in pa_schema:
+ field_type = field.type
+ spark_type = from_arrow_type(field_type)
+ spark_schema.append(StructField(field.name, spark_type, True))
+
+ for index, geom_field in geom_fields:
+        spark_schema.insert(index, StructField(geom_field, GeometryType(), True))
+
+ return StructType(spark_schema)
+
+
+# Modified backport from Spark 4.0
+# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/conversion.py#L632
+def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame:
+ from pyspark.sql.pandas.types import (
+ to_arrow_type,
+ )
+
+ def reader_func(temp_filename):
+ return spark._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)
+
+ def create_iter_server():
+ return spark._jvm.ArrowIteratorServer()
+
+ schema = infer_schema(gdf)
+ timezone = spark._jconf.sessionLocalTimeZone()
+ step = spark._jconf.arrowMaxRecordsPerBatch()
+ step = step if step > 0 else len(gdf)
+    pdf_slices = (gdf.iloc[start : start + step] for start in range(0, len(gdf), step))
+ spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
+
+ arrow_data = [
+ [
+ (c, to_arrow_type(t) if t is not None else None, t)
+ for (_, c), t in zip(pdf_slice.items(), spark_types)
+ ]
+ for pdf_slice in pdf_slices
+ ]
+
+ safecheck = spark._jconf.arrowSafeTypeConversion()
+ ser = ArrowStreamPandasSerializer(timezone, safecheck)
+ jiter = spark._sc._serialize_to_jvm(
+ arrow_data, ser, reader_func, create_iter_server
+ )
+
+ jsparkSession = spark._jsparkSession
+    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
+
+ df = DataFrame(jdf, spark)
+
+ df._schema = schema
+
+ return df
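For orientation, a small sketch of what the new `infer_schema` helper is expected to produce: non-geometry columns are typed via pyarrow and `from_arrow_type`, while geometry columns are re-inserted at their original positions as `GeometryType()`. The commented schema is indicative, not captured output.

```python
import geopandas as gpd

from sedona.utils.geoarrow import infer_schema

# Geometry placed first to show that column positions survive schema inference.
gdf = gpd.GeoDataFrame(
    {
        "geometry": gpd.points_from_xy([0, 1], [0, 1]),
        "name": ["Sedona", "Apache"],
    }
)

schema = infer_schema(gdf)
print(schema)
# Roughly: StructType([StructField('geometry', GeometryType(), True),
#                      StructField('name', StringType(), True)])
```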
diff --git a/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py b/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
new file mode 100644
index 0000000000..46410acbfa
--- /dev/null
+++ b/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
@@ -0,0 +1,94 @@
+import pytest
+
+from sedona.sql.types import GeometryType
+from sedona.utils.geoarrow import create_spatial_dataframe
+from tests.test_base import TestBase
+import geopandas as gpd
+import pyspark
+
+
+class TestGeopandasToSedonaWithArrow(TestBase):
+
+ @pytest.mark.skipif(
+ not pyspark.__version__.startswith("3.5"),
+ reason="It's only working with Spark 3.5",
+ )
+ def test_conversion_dataframe(self):
+ gdf = gpd.GeoDataFrame(
+ {
+ "name": ["Sedona", "Apache"],
+ "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+ }
+ )
+
+ df = create_spatial_dataframe(self.spark, gdf)
+
+ assert df.count() == 2
+ assert df.columns == ["name", "geometry"]
+ assert df.schema["geometry"].dataType == GeometryType()
+
+ @pytest.mark.skipif(
+ not pyspark.__version__.startswith("3.5"),
+ reason="It's only working with Spark 3.5",
+ )
+ def test_different_geometry_positions(self):
+ gdf = gpd.GeoDataFrame(
+ {
+ "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+ "name": ["Sedona", "Apache"],
+ }
+ )
+
+ gdf2 = gpd.GeoDataFrame(
+ {
+ "name": ["Sedona", "Apache"],
+ "name1": ["Sedona", "Apache"],
+ "name2": ["Sedona", "Apache"],
+ "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+ }
+ )
+
+ df1 = create_spatial_dataframe(self.spark, gdf)
+ df2 = create_spatial_dataframe(self.spark, gdf2)
+
+ assert df1.count() == 2
+ assert df1.columns == ["geometry", "name"]
+ assert df1.schema["geometry"].dataType == GeometryType()
+
+ assert df2.count() == 2
+ assert df2.columns == ["name", "name1", "name2", "geometry"]
+ assert df2.schema["geometry"].dataType == GeometryType()
+
+ @pytest.mark.skipif(
+ not pyspark.__version__.startswith("3.5"),
+ reason="It's only working with Spark 3.5",
+ )
+ def test_multiple_geometry_columns(self):
+ gdf = gpd.GeoDataFrame(
+ {
+ "name": ["Sedona", "Apache"],
+ "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+ "geometry2": gpd.points_from_xy([0, 1], [0, 1]),
+ }
+ )
+
+ df = create_spatial_dataframe(self.spark, gdf)
+
+ assert df.count() == 2
+ assert df.columns == ["name", "geometry2", "geometry"]
+ assert df.schema["geometry"].dataType == GeometryType()
+ assert df.schema["geometry2"].dataType == GeometryType()
+
+ @pytest.mark.skipif(
+ not pyspark.__version__.startswith("3.5"),
+ reason="It's only working with Spark 3.5",
+ )
+ def test_missing_geometry_column(self):
+ gdf = gpd.GeoDataFrame(
+ {
+ "name": ["Sedona", "Apache"],
+ },
+ )
+
+ with pytest.raises(ValueError):
+ create_spatial_dataframe(self.spark, gdf)