This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new ebd6f67c50 [SEDONA-714] Add geopandas to spark arrow conversion. (#1825)
ebd6f67c50 is described below

commit ebd6f67c5029554055db6b81ffba5d3ecaad3b62
Author: Paweł Tokaj <[email protected]>
AuthorDate: Wed Feb 26 01:31:59 2025 +0100

    [SEDONA-714] Add geopandas to spark arrow conversion. (#1825)
    
    * SEDONA-714 Add geopandas to spark arrow conversion.
    
    * SEDONA-714 Add geopandas to spark arrow conversion.
    
    * SEDONA-714 Add geopandas to spark arrow conversion.
    
    * SEDONA-714 Add geopandas to spark arrow conversion.
    
    * SEDONA-714 Add geopandas to spark arrow conversion.
    
    * Update python/sedona/utils/geoarrow.py
    
    Co-authored-by: Dewey Dunnington <[email protected]>
    
    * SEDONA-714 Add geopandas to spark arrow conversion.
    
    * SEDONA-714 Add docs.
    
    * SEDONA-714 Add docs.
    
    ---------
    
    Co-authored-by: Dewey Dunnington <[email protected]>
---
 Makefile                                           |   4 +
 docker/docs/Dockerfile                             |   8 ++
 docs/tutorial/geopandas-shapely.md                 |  19 +++
 python/sedona/utils/geoarrow.py                    | 139 ++++++++++++++++++++-
 .../test_arrow_conversion_geopandas_to_sedona.py   |  94 ++++++++++++++
 5 files changed, 263 insertions(+), 1 deletion(-)

diff --git a/Makefile b/Makefile
index 274707936a..4034de4d7f 100644
--- a/Makefile
+++ b/Makefile
@@ -65,3 +65,7 @@ clean:
        rm -rf __pycache__
        rm -rf .mypy_cache
        rm -rf .pytest_cache
+
+run-docs:
+       docker build -f docker/docs/Dockerfile -t mkdocs-sedona .
+       docker run --rm -it -p 8000:8000 -v ${PWD}:/docs mkdocs-sedona
diff --git a/docker/docs/Dockerfile b/docker/docs/Dockerfile
new file mode 100644
index 0000000000..a97ae8d226
--- /dev/null
+++ b/docker/docs/Dockerfile
@@ -0,0 +1,8 @@
+FROM squidfunk/mkdocs-material:9.6
+
+RUN apk update
+RUN apk add gcc musl-dev linux-headers
+RUN pip install mkdocs-macros-plugin \
+    mkdocs-git-revision-date-localized-plugin \
+    mkdocs-jupyter \
+    mike
diff --git a/docs/tutorial/geopandas-shapely.md b/docs/tutorial/geopandas-shapely.md
index bc286229b3..c3fcbf2dc3 100644
--- a/docs/tutorial/geopandas-shapely.md
+++ b/docs/tutorial/geopandas-shapely.md
@@ -67,6 +67,25 @@ This query will show the following outputs:
 
 ```
 
+To leverage Arrow optimization and speed up the conversion, you can use `create_spatial_dataframe`,
+which takes a SparkSession and a GeoDataFrame as parameters and returns a Sedona DataFrame.
+
+```python
+def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame
+```
+
+- spark: SparkSession
+- gdf: gpd.GeoDataFrame
+- return: DataFrame
+
+Example:
+
+```python
+from sedona.utils.geoarrow import create_spatial_dataframe
+
+create_spatial_dataframe(spark, gdf)
+```
+
 ### From Sedona DataFrame to GeoPandas
 
 Reading data with Spark and converting to GeoPandas
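
(Editor's note: a minimal end-to-end sketch of the conversion documented above, assuming a Sedona-enabled SparkSession already bound to `spark`; this snippet is illustrative only and is not part of the committed docs.)

```python
import geopandas as gpd

from sedona.utils.geoarrow import create_spatial_dataframe

# A small GeoDataFrame with one geometry column, mirroring the tutorial example.
gdf = gpd.GeoDataFrame(
    {
        "name": ["Sedona", "Apache"],
        "geometry": gpd.points_from_xy([0, 1], [0, 1]),
    }
)

# Convert via Arrow; the returned Spark DataFrame carries a GeometryType column.
df = create_spatial_dataframe(spark, gdf)
df.printSchema()
df.show()
```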
diff --git a/python/sedona/utils/geoarrow.py b/python/sedona/utils/geoarrow.py
index b8ade8528b..353b4ff7f8 100644
--- a/python/sedona/utils/geoarrow.py
+++ b/python/sedona/utils/geoarrow.py
@@ -14,13 +14,24 @@
 #  KIND, either express or implied.  See the License for the
 #  specific language governing permissions and limitations
 #  under the License.
+import itertools
+from typing import List, Callable
 
 # We may be able to achieve streaming rather than complete materialization by using
 # with the ArrowStreamSerializer (instead of the ArrowCollectSerializer)
 
 
-from sedona.sql.types import GeometryType
 from sedona.sql.st_functions import ST_AsEWKB
+from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame
+from pyspark.sql.types import StructType, StructField, DataType, ArrayType, MapType
+
+from sedona.sql.types import GeometryType
+import geopandas as gpd
+from pyspark.sql.pandas.types import (
+    from_arrow_type,
+)
+from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
 
 
 def dataframe_to_arrow(df, crs=None):
@@ -186,3 +197,129 @@ def unique_srid_from_ewkb(obj):
     import pyproj
 
     return pyproj.CRS(f"EPSG:{epsg_code}")
+
+
+def _dedup_names(names: List[str]) -> List[str]:
+    if len(set(names)) == len(names):
+        return names
+    else:
+
+        def _gen_dedup(_name: str) -> Callable[[], str]:
+            _i = itertools.count()
+            return lambda: f"{_name}_{next(_i)}"
+
+        def _gen_identity(_name: str) -> Callable[[], str]:
+            return lambda: _name
+
+        gen_new_name = {
+            name: _gen_dedup(name) if len(list(group)) > 1 else _gen_identity(name)
+            for name, group in itertools.groupby(sorted(names))
+        }
+        return [gen_new_name[name]() for name in names]
+
+
+# Backport from Spark 4.0
+# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/types.py#L1385
+def _deduplicate_field_names(dt: DataType) -> DataType:
+    if isinstance(dt, StructType):
+        dedup_field_names = _dedup_names(dt.names)
+
+        return StructType(
+            [
+                StructField(
+                    dedup_field_names[i],
+                    _deduplicate_field_names(field.dataType),
+                    nullable=field.nullable,
+                )
+                for i, field in enumerate(dt.fields)
+            ]
+        )
+    elif isinstance(dt, ArrayType):
+        return ArrayType(
+            _deduplicate_field_names(dt.elementType), containsNull=dt.containsNull
+        )
+    elif isinstance(dt, MapType):
+        return MapType(
+            _deduplicate_field_names(dt.keyType),
+            _deduplicate_field_names(dt.valueType),
+            valueContainsNull=dt.valueContainsNull,
+        )
+    else:
+        return dt
+
+
+def infer_schema(gdf: gpd.GeoDataFrame) -> StructType:
+    import pyarrow as pa
+
+    fields = gdf.dtypes.reset_index().values.tolist()
+    geom_fields = []
+    index = 0
+    for name, dtype in fields:
+        if dtype == "geometry":
+            geom_fields.append((index, name))
+            continue
+
+        index += 1
+
+    if not geom_fields:
+        raise ValueError("No geometry field found in the GeoDataFrame")
+
+    pa_schema = pa.Schema.from_pandas(
+        gdf.drop([name for _, name in geom_fields], axis=1)
+    )
+
+    spark_schema = []
+
+    for field in pa_schema:
+        field_type = field.type
+        spark_type = from_arrow_type(field_type)
+        spark_schema.append(StructField(field.name, spark_type, True))
+
+    for index, geom_field in geom_fields:
+        spark_schema.insert(index, StructField(geom_field, GeometryType(), True))
+
+    return StructType(spark_schema)
+
+
+# Modified backport from Spark 4.0
+# https://github.com/apache/spark/blob/3515b207c41d78194d11933cd04bddc21f8418dd/python/pyspark/sql/pandas/conversion.py#L632
+def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame:
+    from pyspark.sql.pandas.types import (
+        to_arrow_type,
+    )
+
+    def reader_func(temp_filename):
+        return spark._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)
+
+    def create_iter_server():
+        return spark._jvm.ArrowIteratorServer()
+
+    schema = infer_schema(gdf)
+    timezone = spark._jconf.sessionLocalTimeZone()
+    step = spark._jconf.arrowMaxRecordsPerBatch()
+    step = step if step > 0 else len(gdf)
+    pdf_slices = (gdf.iloc[start : start + step] for start in range(0, len(gdf), step))
+    spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
+
+    arrow_data = [
+        [
+            (c, to_arrow_type(t) if t is not None else None, t)
+            for (_, c), t in zip(pdf_slice.items(), spark_types)
+        ]
+        for pdf_slice in pdf_slices
+    ]
+
+    safecheck = spark._jconf.arrowSafeTypeConversion()
+    ser = ArrowStreamPandasSerializer(timezone, safecheck)
+    jiter = spark._sc._serialize_to_jvm(
+        arrow_data, ser, reader_func, create_iter_server
+    )
+
+    jsparkSession = spark._jsparkSession
+    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
+
+    df = DataFrame(jdf, spark)
+
+    df._schema = schema
+
+    return df
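
(Editor's note: a quick illustrative check of the `infer_schema` helper added above; it is not part of the commit and simply mirrors the committed tests. No active SparkSession is assumed, since `infer_schema` only inspects the GeoDataFrame.)

```python
import geopandas as gpd

from sedona.sql.types import GeometryType
from sedona.utils.geoarrow import infer_schema

# Geometry column listed first, so it should keep position 0 in the Spark schema.
gdf = gpd.GeoDataFrame(
    {
        "geometry": gpd.points_from_xy([0, 1], [0, 1]),
        "name": ["Sedona", "Apache"],
    }
)

schema = infer_schema(gdf)

# Geometry columns map to GeometryType() and stay at their original positions;
# the remaining columns are inferred through pyarrow.
assert schema.fieldNames() == ["geometry", "name"]
assert schema["geometry"].dataType == GeometryType()
```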
diff --git a/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py b/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
new file mode 100644
index 0000000000..46410acbfa
--- /dev/null
+++ b/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
@@ -0,0 +1,94 @@
+import pytest
+
+from sedona.sql.types import GeometryType
+from sedona.utils.geoarrow import create_spatial_dataframe
+from tests.test_base import TestBase
+import geopandas as gpd
+import pyspark
+
+
+class TestGeopandasToSedonaWithArrow(TestBase):
+
+    @pytest.mark.skipif(
+        not pyspark.__version__.startswith("3.5"),
+        reason="It's only working with Spark 3.5",
+    )
+    def test_conversion_dataframe(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df = create_spatial_dataframe(self.spark, gdf)
+
+        assert df.count() == 2
+        assert df.columns == ["name", "geometry"]
+        assert df.schema["geometry"].dataType == GeometryType()
+
+    @pytest.mark.skipif(
+        not pyspark.__version__.startswith("3.5"),
+        reason="It's only working with Spark 3.5",
+    )
+    def test_different_geometry_positions(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+                "name": ["Sedona", "Apache"],
+            }
+        )
+
+        gdf2 = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "name1": ["Sedona", "Apache"],
+                "name2": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df1 = create_spatial_dataframe(self.spark, gdf)
+        df2 = create_spatial_dataframe(self.spark, gdf2)
+
+        assert df1.count() == 2
+        assert df1.columns == ["geometry", "name"]
+        assert df1.schema["geometry"].dataType == GeometryType()
+
+        assert df2.count() == 2
+        assert df2.columns == ["name", "name1", "name2", "geometry"]
+        assert df2.schema["geometry"].dataType == GeometryType()
+
+    @pytest.mark.skipif(
+        not pyspark.__version__.startswith("3.5"),
+        reason="It's only working with Spark 3.5",
+    )
+    def test_multiple_geometry_columns(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+                "geometry2": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df = create_spatial_dataframe(self.spark, gdf)
+
+        assert df.count() == 2
+        assert df.columns == ["name", "geometry2", "geometry"]
+        assert df.schema["geometry"].dataType == GeometryType()
+        assert df.schema["geometry2"].dataType == GeometryType()
+
+    @pytest.mark.skipif(
+        not pyspark.__version__.startswith("3.5"),
+        reason="It's only working with Spark 3.5",
+    )
+    def test_missing_geometry_column(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+            },
+        )
+
+        with pytest.raises(ValueError):
+            create_spatial_dataframe(self.spark, gdf)
