This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new fa13229577 [GH-2079] Store geometries in EWKB format in spark df and enforce it (#2080)
fa13229577 is described below

commit fa1322957746353513489c1085dc335984044967
Author: Peter Nguyen <[email protected]>
AuthorDate: Fri Jul 11 09:44:51 2025 -0700

    [GH-2079] Store geometries in EWKB format in spark df and enforce it (#2080)
    
    * Store geometries in EWKB format in spark df and enforce it
    
    * Use wkb.dumps instead of to_wkb for version compatibility
    
    * print shapely version
    
    * Use crs instead of shapely.get_srid
    
    * Convert back to Sedona geometry objects in to_parquet
    
    * Change to **kwargs
    
    Co-authored-by: Copilot <[email protected]>
    
    ---------
    
    Co-authored-by: Copilot <[email protected]>
---
 python/sedona/geopandas/geoseries.py               | 242 ++++++++++++++-------
 .../tests/geopandas/test_match_geopandas_series.py |  10 +-
 2 files changed, 170 insertions(+), 82 deletions(-)
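
    For context, the conversion performed by the new try_geom_to_ewkb helper
    boils down to the round-trip below (a minimal sketch, not part of the patch;
    it assumes shapely and pyproj are installed, and the Point/CRS values are
    illustrative):

        # Derive the SRID from a CRS and embed it in the WKB payload (EWKB),
        # mirroring what try_geom_to_ewkb does in geoseries.py.
        import shapely.wkb
        from pyproj import CRS
        from shapely.geometry import Point

        srid = CRS.from_user_input("epsg:4326").to_epsg()  # -> 4326

        point = Point(1.0, 2.0)
        ewkb = shapely.wkb.dumps(point, srid=srid)  # bytes with the SRID embedded

        # Loading the payload back yields an equivalent geometry.
        assert shapely.wkb.loads(ewkb).equals(point)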

diff --git a/python/sedona/geopandas/geoseries.py b/python/sedona/geopandas/geoseries.py
index 6160438a60..31cee8f7a0 100644
--- a/python/sedona/geopandas/geoseries.py
+++ b/python/sedona/geopandas/geoseries.py
@@ -117,6 +117,23 @@ class GeoSeries(GeoFrame, pspd.Series):
         self._anchor: GeoDataFrame
         self._col_label: Label
 
+        def try_geom_to_ewkb(x) -> bytes:
+            if isinstance(x, BaseGeometry):
+                kwargs = {}
+                if crs:
+                    from pyproj import CRS
+
+                    srid = CRS.from_user_input(crs)
+                    kwargs["srid"] = srid.to_epsg()
+
+                return shapely.wkb.dumps(x, **kwargs)
+            elif isinstance(x, bytearray):
+                return bytes(x)
+            elif x is None or isinstance(x, bytes):
+                return x
+            else:
+                raise TypeError(f"expected geometry or bytes, got {type(x)}: {x}")
+
         if isinstance(
             data, (GeoDataFrame, GeoSeries, PandasOnSparkSeries, PandasOnSparkDataFrame)
         ):
@@ -142,17 +159,10 @@ class GeoSeries(GeoFrame, pspd.Series):
 
             pd_data = data.to_pandas()
 
-            # If has shapely geometries, convert to wkb since pandas-on-pyspark can't understand shapely geometries
-            if (
-                isinstance(pd_data, pd.Series)
-                and any(isinstance(x, BaseGeometry) for x in pd_data)
-            ) or (
-                isinstance(pd_data, pd.DataFrame)
-                and any(isinstance(x, BaseGeometry) for x in pd_data.values.ravel())
-            ):
-                pd_data = pd_data.apply(
-                    lambda geom: geom.wkb if geom is not None else None
-                )
+            try:
+                pd_data = pd_data.apply(try_geom_to_ewkb)
+            except Exception as e:
+                raise TypeError(f"Non-geometry column passed to GeoSeries: {e}")
 
             super().__init__(
                 data=pd_data,
@@ -162,8 +172,6 @@ class GeoSeries(GeoFrame, pspd.Series):
                 copy=copy,
                 fastpath=fastpath,
             )
-
-            self._anchor = data
         else:
             if isinstance(data, pd.Series):
                 assert index is None
@@ -181,13 +189,23 @@ class GeoSeries(GeoFrame, pspd.Series):
                     copy=copy,
                     fastpath=fastpath,
                 )
-            gs = gpd.GeoSeries(s)
-            pdf = pd.Series(
-                gs.apply(lambda geom: geom.wkb if geom is not None else None)
-            )
+
+            try:
+                pdf = s.apply(try_geom_to_ewkb)
+            except Exception as e:
+                raise TypeError(f"Non-geometry column passed to GeoSeries: {e}")
+
             # initialize the parent class pyspark Series with the pandas Series
             super().__init__(data=pdf)
 
+        # manually set it to binary type
+        col = next(
+            field.name
+            for field in self._internal.spark_frame.schema.fields
+            if field.name not in (NATURAL_ORDER_COLUMN_NAME, SPARK_DEFAULT_INDEX_NAME)
+        )
+        self._internal.spark_frame.schema[col].dataType = BinaryType()
+
         if crs:
             self.set_crs(crs, inplace=True)
 
@@ -225,7 +243,9 @@ class GeoSeries(GeoFrame, pspd.Series):
         """
         from pyproj import CRS
 
-        tmp_df = self._process_geometry_column("ST_SRID", rename="crs")
+        tmp_df = self._process_geometry_column(
+            "ST_SRID", rename="crs", returns_geom=False
+        )
         srid = tmp_df.take([0])[0]
         # Sedona returns 0 if doesn't exist
         return CRS.from_user_input(srid) if srid != 0 and not pd.isna(srid) else None
@@ -354,18 +374,22 @@ class GeoSeries(GeoFrame, pspd.Series):
 
         # 0 indicates no srid in sedona
         new_epsg = crs.to_epsg() if crs else 0
+        col = self.get_first_geometry_column()
+
+        select = f"ST_SetSRID(`{col}`, {new_epsg})"
+
         # Keep the same column name instead of renaming it
-        result = self._process_geometry_column("ST_SetSRID", rename="", srid=new_epsg)
+        result = self._query_geometry_column(select, col, rename="")
 
         if inplace:
-            self._update_anchor(result._to_spark_pandas_df())
+            self._update_anchor(_to_spark_pandas_df(result))
             return None
 
         return result
 
     def _process_geometry_column(
-        self, operation: str, rename: str, *args, **kwargs
-    ) -> "GeoSeries":
+        self, operation: str, rename: str, returns_geom: bool = True, *args, **kwargs
+    ) -> Union["GeoSeries", pspd.Series]:
         """
         Helper method to process a single geometry column with a specified operation.
         This method wraps the _query_geometry_column method for simpler convenient use.
@@ -405,7 +429,9 @@ class GeoSeries(GeoFrame, pspd.Series):
 
         sql_expr = f"{operation}(`{first_col}`{params})"
 
-        return self._query_geometry_column(sql_expr, first_col, rename)
+        return self._query_geometry_column(
+            sql_expr, first_col, rename, returns_geom=returns_geom
+        )
 
     def _query_geometry_column(
         self,
@@ -413,7 +439,8 @@ class GeoSeries(GeoFrame, pspd.Series):
         cols: Union[List[str], str],
         rename: str,
         df: pyspark.sql.DataFrame = None,
-    ) -> "GeoSeries":
+        returns_geom: bool = True,
+    ) -> Union["GeoSeries", pspd.Series]:
         """
         Helper method to query a single geometry column with a specified operation.
 
@@ -427,6 +454,8 @@ class GeoSeries(GeoFrame, pspd.Series):
             The name of the resulting column.
         df : pyspark.sql.DataFrame
             The dataframe to query. If not provided, the internal dataframe will be used.
+        returns_geom : bool, default True
+            If True, the geometry column will be converted back to EWKB format.
 
         Returns
         -------
@@ -436,23 +465,36 @@ class GeoSeries(GeoFrame, pspd.Series):
         if not cols:
             raise ValueError("No valid geometry column found.")
 
-        if isinstance(cols, str):
-            cols = [cols]
-
         df = self._internal.spark_frame if df is None else df
 
-        for col in cols:
+        if isinstance(cols, str):
+            col = cols
             data_type = df.schema[col].dataType
+            if isinstance(data_type, BinaryType):
+                query = query.replace(f"`{cols}`", f"ST_GeomFromWKB(`{cols}`)")
+
+            # Convert back to EWKB format if the return type is a geometry
+            if returns_geom:
+                query = f"ST_AsEWKB({query})"
 
             rename = col if not rename else rename
 
-            if isinstance(data_type, BinaryType):
-                # the backticks here are important so we don't match strings that happen to be the same as the column name
-                query = query.replace(f"`{col}`", f"ST_GeomFromWKB(`{col}`)")
+            query = f"{query} as `{rename}`"
+
+        elif isinstance(cols, list):
+            for col in cols:
+                data_type = df.schema[col].dataType
+
+                if isinstance(data_type, BinaryType):
+                    # the backticks here are important so we don't match strings that happen to be the same as the column name
+                    query = query.replace(f"`{col}`", f"ST_GeomFromWKB(`{col}`)")
+
+            # must have rename for multiple columns since we don't know which name to default to
+            assert rename
 
-        sql_expr = f"{query} as `{rename}`"
+            query = f"{query} as `{rename}`"
 
-        sdf = df.selectExpr(sql_expr)
+        sdf = df.selectExpr(query)
         internal = InternalFrame(
             spark_frame=sdf,
             index_spark_columns=None,
@@ -461,7 +503,9 @@ class GeoSeries(GeoFrame, pspd.Series):
             data_fields=[self._internal.data_fields[0]],
             column_label_names=self._internal.column_label_names,
         )
-        return _to_geo_series(first_series(PandasOnSparkDataFrame(internal)))
+        ps_series = first_series(PandasOnSparkDataFrame(internal))
+
+        return GeoSeries(ps_series) if returns_geom else ps_series
 
     @property
     def dtypes(self) -> Union[gpd.GeoSeries, pd.Series, Dtype]:
@@ -490,7 +534,10 @@ class GeoSeries(GeoFrame, pspd.Series):
         pd_series = self._to_internal_pandas()
         try:
             return gpd.GeoSeries(
-                pd_series.map(lambda wkb: shapely.wkb.loads(bytes(wkb))), crs=self.crs
+                pd_series.map(
+                    lambda wkb: shapely.wkb.loads(bytes(wkb)) if wkb else None
+                ),
+                crs=self.crs,
             )
         except TypeError:
             return gpd.GeoSeries(pd_series, crs=self.crs)
@@ -498,9 +545,6 @@ class GeoSeries(GeoFrame, pspd.Series):
     def to_spark_pandas(self) -> pspd.Series:
         return pspd.Series(self._psdf._to_internal_pandas())
 
-    def _to_spark_pandas_df(self) -> pspd.DataFrame:
-        return pspd.DataFrame(self._psdf._internal)
-
     @property
     def geometry(self) -> "GeoSeries":
         return self
@@ -581,7 +625,9 @@ class GeoSeries(GeoFrame, pspd.Series):
         1    4.0
         dtype: float64
         """
-        return self._process_geometry_column("ST_Area", rename="area").to_spark_pandas()
+        return self._process_geometry_column(
+            "ST_Area", rename="area", returns_geom=False
+        )
 
     @property
     def geom_type(self) -> pspd.Series:
@@ -607,8 +653,8 @@ class GeoSeries(GeoFrame, pspd.Series):
         dtype: object
         """
         result = self._process_geometry_column(
-            "GeometryType", rename="geom_type"
-        ).to_spark_pandas()
+            "GeometryType", rename="geom_type", returns_geom=False
+        )
 
         # Sedona returns the string in all caps unlike Geopandas
         sgpd_to_gpg_name_map = {
@@ -664,8 +710,8 @@ class GeoSeries(GeoFrame, pspd.Series):
                 WHEN GeometryType(`{col}`) IN ('GEOMETRYCOLLECTION') THEN ST_Length(`{col}`) + ST_Perimeter(`{col}`)
             END"""
         return self._query_geometry_column(
-            select, col, rename="length"
-        ).to_spark_pandas()
+            select, col, rename="length", returns_geom=False
+        )
 
     @property
     def is_valid(self) -> pspd.Series:
@@ -706,11 +752,10 @@ class GeoSeries(GeoFrame, pspd.Series):
         --------
         GeoSeries.is_valid_reason : reason for invalidity
         """
-        return (
-            self._process_geometry_column("ST_IsValid", rename="is_valid")
-            .to_spark_pandas()
-            .astype("bool")
+        result = self._process_geometry_column(
+            "ST_IsValid", rename="is_valid", returns_geom=False
         )
+        return to_bool(result)
 
     def is_valid_reason(self) -> pspd.Series:
         """Returns a ``Series`` of strings with the reason for invalidity of
@@ -754,8 +799,8 @@ class GeoSeries(GeoFrame, pspd.Series):
         GeoSeries.make_valid : fix invalid geometries
         """
         return self._process_geometry_column(
-            "ST_IsValidReason", rename="is_valid_reason"
-        ).to_spark_pandas()
+            "ST_IsValidReason", rename="is_valid_reason", returns_geom=False
+        )
 
     @property
     def is_empty(self) -> pspd.Series:
@@ -786,11 +831,10 @@ class GeoSeries(GeoFrame, pspd.Series):
         --------
         GeoSeries.isna : detect missing values
         """
-        return (
-            self._process_geometry_column("ST_IsEmpty", rename="is_empty")
-            .to_spark_pandas()
-            .astype("bool")
+        result = self._process_geometry_column(
+            "ST_IsEmpty", rename="is_empty", returns_geom=False
         )
+        return to_bool(result)
 
     def count_coordinates(self):
         # Implementation of the abstract method
@@ -831,11 +875,10 @@ class GeoSeries(GeoFrame, pspd.Series):
         1     True
         dtype: bool
         """
-        return (
-            self._process_geometry_column("ST_IsSimple", rename="is_simple")
-            .to_spark_pandas()
-            .astype("bool")
+        result = self._process_geometry_column(
+            "ST_IsSimple", rename="is_simple", returns_geom=False
         )
+        return to_bool(result)
 
     @property
     def is_ring(self):
@@ -883,8 +926,8 @@ class GeoSeries(GeoFrame, pspd.Series):
         dtype: bool
         """
         return self._process_geometry_column(
-            "ST_HasZ", rename="has_z"
-        ).to_spark_pandas()
+            "ST_HasZ", rename="has_z", returns_geom=False
+        )
 
     def get_precision(self):
         # Implementation of the abstract method
@@ -1182,13 +1225,22 @@ class GeoSeries(GeoFrame, pspd.Series):
         GeoSeries.touches
         GeoSeries.intersection
         """
-        return (
-            self._row_wise_operation(
-                "ST_Intersects(`L`, `R`)", other, align, rename="intersects"
-            )
-            .to_spark_pandas()
-            .astype("bool")
+
+        select = "ST_Intersects(`L`, `R`)"
+
+        # The ps.Series.fillna() call in to_bool doesn't work for the output of
+        # intersects here for some reason, so we manually handle the nulls here.
+        select = f"""
+            CASE
+                WHEN `L` IS NULL OR `R` IS NULL THEN FALSE
+                ELSE {select}
+            END
+        """
+
+        result = self._row_wise_operation(
+            select, other, align, rename="intersects", returns_geom=False
         )
+        return to_bool(result)
 
     def intersection(
         self, other: Union["GeoSeries", BaseGeometry], align: Union[bool, None] = None
@@ -1302,6 +1354,7 @@ class GeoSeries(GeoFrame, pspd.Series):
         other: Union["GeoSeries", BaseGeometry],
         align: Union[bool, None],
         rename: str,
+        returns_geom: bool = True,
     ):
         """
         Helper function to perform a row-wise operation on two GeoSeries.
@@ -1334,6 +1387,7 @@ class GeoSeries(GeoFrame, pspd.Series):
             cols=["L", "R"],
             rename=rename,
             df=joined_df,
+            returns_geom=returns_geom,
         )
 
     def intersection_all(self):
@@ -1395,8 +1449,19 @@ class GeoSeries(GeoFrame, pspd.Series):
         - kwargs: Any
             Additional arguments to pass to the Sedona DataFrame output function.
         """
+        col = self.get_first_geometry_column()
+
+        # Convert WKB to Sedona geometry objects
+        # Specify returns_geom=False to avoid turning it back into EWKB
+        result = self._query_geometry_column(
+            f"`{col}`",
+            cols=col,
+            rename="wkb",
+            returns_geom=False,
+        )
+
         # Use the Spark DataFrame's write method to write to GeoParquet format
-        self._internal.spark_frame.write.format("geoparquet").save(path, **kwargs)
+        result._internal.spark_frame.write.format("geoparquet").save(path, **kwargs)
 
     def sjoin(
         self,
@@ -1476,7 +1541,7 @@ class GeoSeries(GeoFrame, pspd.Series):
         GeoSeries.z
 
         """
-        return self._process_geometry_column("ST_X", rename="x").to_spark_pandas()
+        return self._process_geometry_column("ST_X", rename="x", returns_geom=False)
 
     @property
     def y(self) -> pspd.Series:
@@ -1506,7 +1571,7 @@ class GeoSeries(GeoFrame, pspd.Series):
         GeoSeries.m
 
         """
-        return self._process_geometry_column("ST_Y", rename="y").to_spark_pandas()
+        return self._process_geometry_column("ST_Y", rename="y", returns_geom=False)
 
     @property
     def z(self) -> pspd.Series:
@@ -1536,7 +1601,7 @@ class GeoSeries(GeoFrame, pspd.Series):
         GeoSeries.m
 
         """
-        return self._process_geometry_column("ST_Z", rename="z").to_spark_pandas()
+        return self._process_geometry_column("ST_Z", rename="z", returns_geom=False)
 
     @property
     def m(self) -> pspd.Series:
@@ -1881,11 +1946,10 @@ class GeoSeries(GeoFrame, pspd.Series):
         """
         col = self.get_first_geometry_column()
         select = f"`{col}` IS NULL"
-        return (
-            self._query_geometry_column(select, col, rename="isna")
-            .to_spark_pandas()
-            .astype("bool")
+        result = self._query_geometry_column(
+            select, col, rename="isna", returns_geom=False
         )
+        return to_bool(result)
 
     def isnull(self) -> pspd.Series:
         """Alias for `isna` method. See `isna` for more detail."""
@@ -1927,11 +1991,11 @@ class GeoSeries(GeoFrame, pspd.Series):
         """
         col = self.get_first_geometry_column()
         select = f"`{col}` IS NOT NULL"
-        return (
-            self._query_geometry_column(select, col, rename="notna")
-            .to_spark_pandas()
-            .astype("bool")
+
+        result = self._query_geometry_column(
+            select, col, rename="notna", returns_geom=False
         )
+        return to_bool(result)
 
     def notnull(self) -> pspd.Series:
         """Alias for `notna` method. See `notna` for more detail."""
@@ -2273,7 +2337,7 @@ class GeoSeries(GeoFrame, pspd.Series):
     # # Utils
     # -----------------------------------------------------------------------------
 
-    def get_first_geometry_column(self) -> Union[str, None]:
+    def get_first_geometry_column(self) -> str:
         first_binary_or_geometry_col = next(
             (
                 field.name
@@ -2283,7 +2347,12 @@ class GeoSeries(GeoFrame, pspd.Series):
             ),
             None,
         )
-        return first_binary_or_geometry_col
+        if first_binary_or_geometry_col:
+            return first_binary_or_geometry_col
+
+        raise ValueError(
+            "get_first_geometry_column: No geometry column found in the GeoSeries."
+        )
 
 
 # -----------------------------------------------------------------------------
@@ -2291,6 +2360,21 @@ class GeoSeries(GeoFrame, pspd.Series):
 # -----------------------------------------------------------------------------
 
 
+def _to_spark_pandas_df(ps_series: pspd.Series) -> pspd.DataFrame:
+    return pspd.DataFrame(ps_series._psdf._internal)
+
+
+def to_bool(ps_series: pspd.Series, default: bool = False) -> pspd.Series:
+    """
+    Cast a ps.Series to bool type if it's not one, converting None values to the default value.
+    """
+    if ps_series.dtype.name != "bool":
+        # fill None values with the default value
+        ps_series.fillna(default, inplace=True)
+
+    return ps_series
+
+
 def _to_geo_series(df: PandasOnSparkSeries) -> GeoSeries:
     """
     Get the first Series from the DataFrame.
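
    To make the _query_geometry_column rework above concrete, the selectExpr
    strings it builds for a BinaryType column look roughly as follows (a sketch
    only; the column name "geometry" is hypothetical, not taken from the patch):

        # Scalar-returning operation (returns_geom=False): the stored EWKB is
        # only unwrapped with ST_GeomFromWKB before the operation is applied.
        col = "geometry"
        area_expr = f"ST_Area(ST_GeomFromWKB(`{col}`)) as `area`"

        # Geometry-returning operation (returns_geom=True): the result is wrapped
        # in ST_AsEWKB so the dataframe keeps storing geometries as EWKB bytes.
        set_srid_expr = f"ST_AsEWKB(ST_SetSRID(ST_GeomFromWKB(`{col}`), 4326)) as `{col}`"

        print(area_expr)
        print(set_srid_expr)
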
diff --git a/python/tests/geopandas/test_match_geopandas_series.py b/python/tests/geopandas/test_match_geopandas_series.py
index e63d086c83..f19c04e83a 100644
--- a/python/tests/geopandas/test_match_geopandas_series.py
+++ b/python/tests/geopandas/test_match_geopandas_series.py
@@ -130,6 +130,10 @@ class TestMatchGeopandasSeries(TestBase):
             GeoSeries([0, 1, 2], crs="epsg:4326")
         with pytest.raises(TypeError):
             GeoSeries(["a", "b", "c"])
+        with pytest.raises(TypeError):
+            GeoSeries(pd.Series([0, 1, 2]), crs="epsg:4326")
+        with pytest.raises(TypeError):
+            GeoSeries(ps.Series([0, 1, 2]))
 
     def test_to_geopandas(self):
         for _, geom in self.geoms:
@@ -608,9 +612,9 @@ class TestMatchGeopandasSeries(TestBase):
                 self.check_sgpd_equals_gpd(sgpd_result, gpd_result)
 
                 if len(g1) == len(g2):
-                    sgpd_result = GeoSeries(g1).intersects(GeoSeries(g2), align=False)
-                    gpd_result = gpd_series1.intersects(gpd_series2, align=False)
-                    self.check_pd_series_equal(sgpd_result, gpd_result)
+                    sgpd_result = GeoSeries(g1).intersection(GeoSeries(g2), align=False)
+                    gpd_result = gpd_series1.intersection(gpd_series2, align=False)
+                    self.check_sgpd_equals_gpd(sgpd_result, gpd_result)
 
     def test_intersection_all(self):
         pass
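
For reference, a hedged usage sketch (not part of the patch) of the behaviour
the updated tests exercise; it assumes an active Sedona-enabled Spark session
and that GeoSeries is importable from sedona.geopandas, following the package
layout in the diff:

    import pandas as pd
    from shapely.geometry import Point

    from sedona.geopandas import GeoSeries

    # Geometries are accepted and stored internally as EWKB with the SRID set,
    # so the CRS can later be recovered via ST_SRID.
    gs = GeoSeries([Point(0, 0), Point(1, 1)], crs="epsg:4326")
    print(gs.crs)  # pyproj CRS corresponding to EPSG:4326

    # Non-geometry inputs are now rejected with a TypeError.
    try:
        GeoSeries(pd.Series([0, 1, 2]), crs="epsg:4326")
    except TypeError as err:
        print(f"rejected: {err}")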
