This is an automated email from the ASF dual-hosted git repository. imbruced pushed a commit to branch SEDONA-714-add-geopandas-to-spark-arrow-conversion in repository https://gitbox.apache.org/repos/asf/sedona.git
commit ebdee33ee138f10e90057926f9c7998f517d78ff
Author: pawelkocinski <[email protected]>
AuthorDate: Sun Feb 23 21:53:47 2025 +0100

    SEDONA-714 Add geopandas to spark arrow conversion.
---
 python/sedona/utils/geoarrow.py                    | 85 +++++++++++++++++++++-
 python/tests/test_base.py                          |  2 +-
 .../test_arrow_conversion_geopandas_to_sedona.py   | 77 ++++++++++++++++++++
 3 files changed, 162 insertions(+), 2 deletions(-)

diff --git a/python/sedona/utils/geoarrow.py b/python/sedona/utils/geoarrow.py
index b8ade8528b..8c730d9a39 100644
--- a/python/sedona/utils/geoarrow.py
+++ b/python/sedona/utils/geoarrow.py
@@ -19,8 +19,18 @@
 # with the ArrowStreamSerializer (instead of the ArrowCollectSerializer)
 
 
-from sedona.sql.types import GeometryType
 from sedona.sql.st_functions import ST_AsEWKB
+from pyspark.sql import SparkSession
+from pyspark.sql import DataFrame
+from pyspark.sql.types import StructType, StructField
+import pyarrow as pa
+
+from sedona.sql.types import GeometryType
+import geopandas as gpd
+from pyspark.sql.pandas.types import (
+    from_arrow_type,
+)
+from pyspark.sql.pandas.serializers import ArrowStreamPandasSerializer
 
 
 def dataframe_to_arrow(df, crs=None):
@@ -186,3 +196,76 @@ def unique_srid_from_ewkb(obj):
     import pyproj
 
     return pyproj.CRS(f"EPSG:{epsg_code}")
+
+
+def infer_schema(gdf: gpd.GeoDataFrame) -> StructType:
+    fields = gdf.dtypes.reset_index().values.tolist()
+    geom_fields = []
+    index = 0
+    for name, dtype in fields:
+        if dtype == "geometry":
+            geom_fields.append((index, name))
+            continue
+
+        index += 1
+
+    if not geom_fields:
+        raise ValueError("No geometry field found in the GeoDataFrame")
+
+    pa_schema = pa.Schema.from_pandas(
+        gdf.drop([name for _, name in geom_fields], axis=1)
+    )
+    spark_schema = []
+
+    for field in pa_schema:
+        field_type = field.type
+        spark_type = from_arrow_type(field_type)
+        spark_schema.append(StructField(field.name, spark_type, True))
+
+    for index, geom_field in geom_fields:
+        spark_schema.insert(index, StructField(geom_field, GeometryType(), True))
+
+    return StructType(spark_schema)
+
+
+def create_spatial_dataframe(spark: SparkSession, gdf: gpd.GeoDataFrame) -> DataFrame:
+    from pyspark.sql.pandas.types import (
+        to_arrow_type,
+        _deduplicate_field_names,
+    )
+
+    def reader_func(temp_filename):
+        return spark._jvm.PythonSQLUtils.readArrowStreamFromFile(temp_filename)
+
+    def create_iter_server():
+        return spark._jvm.ArrowIteratorServer()
+
+    schema = infer_schema(gdf)
+    timezone = spark._jconf.sessionLocalTimeZone()
+    step = spark._jconf.arrowMaxRecordsPerBatch()
+    step = step if step > 0 else len(gdf)
+    pdf_slices = (gdf.iloc[start : start + step] for start in range(0, len(gdf), step))
+    spark_types = [_deduplicate_field_names(f.dataType) for f in schema.fields]
+
+    arrow_data = [
+        [
+            (c, to_arrow_type(t) if t is not None else None, t)
+            for (_, c), t in zip(pdf_slice.items(), spark_types)
+        ]
+        for pdf_slice in pdf_slices
+    ]
+
+    safecheck = spark._jconf.arrowSafeTypeConversion()
+    ser = ArrowStreamPandasSerializer(timezone, safecheck)
+    jiter = spark._sc._serialize_to_jvm(
+        arrow_data, ser, reader_func, create_iter_server
+    )
+
+    jsparkSession = spark._jsparkSession
+    jdf = spark._jvm.PythonSQLUtils.toDataFrame(jiter, schema.json(), jsparkSession)
+
+    df = DataFrame(jdf, spark)
+
+    df._schema = schema
+
+    return df
diff --git a/python/tests/test_base.py b/python/tests/test_base.py
index 2769a93cdd..710b7a5564 100644
--- a/python/tests/test_base.py
+++ b/python/tests/test_base.py
@@ -16,7 +16,7 @@
 # under the License.
 import os
 from tempfile import mkdtemp
-from typing import Iterable, Union
+from typing import Iterable
 
 import pyspark
 
diff --git a/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py b/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
new file mode 100644
index 0000000000..e2c8344bc3
--- /dev/null
+++ b/python/tests/utils/test_arrow_conversion_geopandas_to_sedona.py
@@ -0,0 +1,77 @@
+import pytest
+
+from sedona.sql.types import GeometryType
+from sedona.utils.geoarrow import create_spatial_dataframe
+from tests.test_base import TestBase
+import geopandas as gpd
+
+
+class TestGeopandasToSedonaWithArrow(TestBase):
+
+    def test_conversion_dataframe(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df = create_spatial_dataframe(self.spark, gdf)
+
+        assert df.count() == 2
+        assert df.columns == ["name", "geometry"]
+        assert df.schema["geometry"].dataType == GeometryType()
+
+    def test_different_geometry_positions(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+                "name": ["Sedona", "Apache"],
+            }
+        )
+
+        gdf2 = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "name1": ["Sedona", "Apache"],
+                "name2": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df1 = create_spatial_dataframe(self.spark, gdf)
+        df2 = create_spatial_dataframe(self.spark, gdf2)
+
+        assert df1.count() == 2
+        assert df1.columns == ["geometry", "name"]
+        assert df1.schema["geometry"].dataType == GeometryType()
+
+        assert df2.count() == 2
+        assert df2.columns == ["name", "name1", "name2", "geometry"]
+        assert df2.schema["geometry"].dataType == GeometryType()
+
+    def test_multiple_geometry_columns(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+                "geometry": gpd.points_from_xy([0, 1], [0, 1]),
+                "geometry2": gpd.points_from_xy([0, 1], [0, 1]),
+            }
+        )
+
+        df = create_spatial_dataframe(self.spark, gdf)
+
+        assert df.count() == 2
+        assert df.columns == ["name", "geometry2", "geometry"]
+        assert df.schema["geometry"].dataType == GeometryType()
+        assert df.schema["geometry2"].dataType == GeometryType()
+
+    def test_missing_geometry_column(self):
+        gdf = gpd.GeoDataFrame(
+            {
+                "name": ["Sedona", "Apache"],
+            },
+        )
+
+        with pytest.raises(ValueError):
+            create_spatial_dataframe(self.spark, gdf)
