This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new a34c3f83a5 [GH-2129] Geopandas: Refactor to use Spark Dataframe API instead of Spark SQL (#2131)
a34c3f83a5 is described below
commit a34c3f83a56763c443fa30f76d084e327d68c244
Author: Peter Nguyen <[email protected]>
AuthorDate: Tue Jul 22 12:52:24 2025 -0700
[GH-2129] Geopandas: Refactor to use Spark Dataframe API instead of Spark SQL (#2131)
---
python/sedona/geopandas/geodataframe.py | 45 +-
python/sedona/geopandas/geoseries.py | 684 +++++++++++----------
python/tests/geopandas/test_geodataframe.py | 34 +-
python/tests/geopandas/test_geoseries.py | 13 +
.../tests/geopandas/test_match_geopandas_series.py | 12 -
5 files changed, 415 insertions(+), 373 deletions(-)
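For context before the diff: the refactor replaces string-built Spark SQL evaluated via selectExpr with typed Column expressions from the DataFrame API. A minimal sketch of the before/after pattern, assuming a Spark DataFrame `df` with a geometry column named "geometry" (illustrative names, not the exact Sedona internals):

    from pyspark.sql import functions as F
    from sedona.spark.sql import st_functions as stf

    # Before: build a SQL string and evaluate it with selectExpr
    sdf = df.selectExpr("ST_Area(`geometry`) as area")

    # After: compose a typed Column expression with the DataFrame API
    sdf = df.select(stf.ST_Area(F.col("geometry")).alias("area"))

The Column form avoids string interpolation and backtick-quoting bugs, and Spark can validate the expression when it is built rather than when a string is parsed.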
diff --git a/python/sedona/geopandas/geodataframe.py b/python/sedona/geopandas/geodataframe.py
index f4da65207d..60282a7c7f 100644
--- a/python/sedona/geopandas/geodataframe.py
+++ b/python/sedona/geopandas/geodataframe.py
@@ -39,6 +39,7 @@ from pandas.api.extensions import register_extension_dtype
from geopandas.geodataframe import crs_mismatch_error
from geopandas.array import GeometryDtype
from shapely.geometry.base import BaseGeometry
+from pyspark.pandas.internal import SPARK_DEFAULT_INDEX_NAME, NATURAL_ORDER_COLUMN_NAME
register_extension_dtype(GeometryDtype)
@@ -344,7 +345,7 @@ class GeoDataFrame(GeoFrame, pspd.DataFrame):
# Here we are getting a ps.Series with the same underlying anchor (ps.Dataframe).
# This is important so we don't unnecessarily try to perform operations on different dataframes
- ps_series = pspd.DataFrame.__getitem__(self, column_name)
+ ps_series: pspd.Series = pspd.DataFrame.__getitem__(self, column_name)
try:
result = sgpd.GeoSeries(ps_series)
@@ -444,9 +445,12 @@ class GeoDataFrame(GeoFrame, pspd.DataFrame):
assert index is None
assert dtype is None
assert not copy
- df = data
+ # Need to convert GeoDataFrame to pd.DataFrame for below cast to work
+ pd_df = (
+ pd.DataFrame(data) if isinstance(data, gpd.GeoDataFrame) else data
+ )
else:
- df = pd.DataFrame(
+ pd_df = pd.DataFrame(
data=data,
index=index,
dtype=dtype,
@@ -454,7 +458,8 @@ class GeoDataFrame(GeoFrame, pspd.DataFrame):
)
# Spark complains if it's left as a geometry type
- pd_df = df.astype(object)
+ geom_type_cols = pd_df.select_dtypes(include=["geometry"]).columns
+ pd_df[geom_type_cols] = pd_df[geom_type_cols].astype(object)
# initialize the parent class pyspark Dataframe with the pandas Dataframe
super().__init__(
@@ -630,13 +635,14 @@ class GeoDataFrame(GeoFrame, pspd.DataFrame):
"`set_geometry` when this is the case."
)
warnings.warn(msg, category=FutureWarning, stacklevel=2)
- if isinstance(col, (pspd.Series, pd.Series)) and col.name is not None:
- geo_column_name = col.name
-
- level = col
-
- if not isinstance(level, pspd.Series):
- level = pspd.Series(level)
+ if isinstance(col, (pspd.Series, pd.Series)):
+ if col.name is not None:
+ geo_column_name = col.name
+ level = col
+ else:
+ level = col.rename(geo_column_name)
+ else:
+ level = pspd.Series(col, name=geo_column_name)
elif hasattr(col, "ndim") and col.ndim > 1:
raise ValueError("Must pass array with one dimension only.")
else: # should be a colname
@@ -746,17 +752,18 @@ class GeoDataFrame(GeoFrame, pspd.DataFrame):
if col in self.columns:
raise ValueError(f"Column named {col} already exists")
else:
+ mapper = {col: col for col in list(self.columns)}
+ mapper[geometry_col] = col
+
if inplace:
- self.rename(columns={geometry_col: col}, inplace=inplace)
- self.set_geometry(col, inplace=inplace)
+ self.rename(columns=mapper, inplace=True, errors="raise")
+ self.set_geometry(col, inplace=True)
return None
- # The same .rename().set_geometry() logic errors for this case, so we do it manually instead
- ps_series = self._psser_for((geometry_col,)).rename(col)
- sdf = self.copy()
- sdf[col] = ps_series
- sdf = sdf.set_geometry(col)
- return sdf
+ df = self.copy()
+ df.rename(columns=mapper, inplace=True, errors="raise")
+ df.set_geometry(col, inplace=True)
+ return df
# ============================================================================
# PROPERTIES AND ATTRIBUTES
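For context on the selective cast in the hunk above: select_dtypes(include=["geometry"]) matches GeoPandas' registered GeometryDtype, so only geometry columns are cast to object before the frame is handed to Spark, while other columns keep their dtypes. A standalone sketch with made-up data:

    import pandas as pd
    import geopandas as gpd
    from shapely.geometry import Point

    pd_df = pd.DataFrame(
        {
            "id": [1, 2],
            "geometry": gpd.GeoSeries([Point(0, 0), Point(1, 1)]),
        }
    )
    geom_type_cols = pd_df.select_dtypes(include=["geometry"]).columns
    # Only the geometry columns become object; "id" stays int64
    pd_df[geom_type_cols] = pd_df[geom_type_cols].astype(object)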
diff --git a/python/sedona/geopandas/geoseries.py b/python/sedona/geopandas/geoseries.py
index 6e33c9e40e..434c0ba315 100644
--- a/python/sedona/geopandas/geoseries.py
+++ b/python/sedona/geopandas/geoseries.py
@@ -31,6 +31,14 @@ from pyspark.pandas.utils import scol_for, log_advice
from pyspark.sql.types import BinaryType, NullType
from sedona.spark.sql.types import GeometryType
+from sedona.spark.sql import st_aggregates as sta
+from sedona.spark.sql import st_constructors as stc
+from sedona.spark.sql import st_functions as stf
+from sedona.spark.sql import st_predicates as stp
+
+from pyspark.sql import Column as PySparkColumn
+from pyspark.sql import functions as F
+
import shapely
from shapely.geometry.base import BaseGeometry
@@ -42,6 +50,7 @@ from sedona.geopandas.sindex import SpatialIndex
from pyspark.pandas.internal import (
SPARK_DEFAULT_INDEX_NAME, # __index_level_0__
NATURAL_ORDER_COLUMN_NAME,
+ SPARK_DEFAULT_SERIES_NAME, # '0'
)
@@ -434,19 +443,13 @@ class GeoSeries(GeoFrame, pspd.Series):
super().__init__(data=pd_series)
# Ensure we're storing geometry types
- col = next(
- field.name
- for field in self._internal.spark_frame.schema.fields
- if field.name not in (NATURAL_ORDER_COLUMN_NAME, SPARK_DEFAULT_INDEX_NAME)
- )
- datatype = self._internal.spark_frame.schema[col].dataType
- # Empty lists input will lead to NullType(), so we convert to GeometryType()
- if datatype == NullType():
- self._internal.spark_frame.schema[col].dataType = GeometryType()
- elif datatype != GeometryType():
+ if (
+ self.spark.data_type != GeometryType()
+ and self.spark.data_type != NullType()
+ ):
raise TypeError(
"Non geometry data passed to GeoSeries constructor, "
- f"received data of dtype '{datatype.typeName()}'"
+ f"received data of dtype '{self.spark.data_type.typeName()}'"
)
if crs:
@@ -493,8 +496,10 @@ class GeoSeries(GeoFrame, pspd.Series):
if len(self) == 0:
return None
- tmp_series: pspd.Series = self._process_geometry_column(
- "ST_SRID", rename="crs", returns_geom=False
+ spark_col = stf.ST_SRID(self.spark.column)
+ tmp_series = self._query_geometry_column(
+ spark_col,
+ returns_geom=False,
)
# All geometries should have the same srid
@@ -629,15 +634,12 @@ class GeoSeries(GeoFrame, pspd.Series):
# 0 indicates no srid in sedona
new_epsg = crs.to_epsg() if crs else 0
- col = self.get_first_geometry_column()
-
- select = f"ST_SetSRID(`{col}`, {new_epsg})"
- # Keep the same column name instead of renaming it
- result = self._query_geometry_column(select, col, rename="")
+ spark_col = stf.ST_SetSRID(self.spark.column, new_epsg)
+ result = self._query_geometry_column(spark_col)
if inplace:
- self._update_anchor(_to_spark_pandas_df(result))
+ self._update_inplace(result)
return None
return result
@@ -646,63 +648,9 @@ class GeoSeries(GeoFrame, pspd.Series):
# INTERNAL HELPER METHODS
# ============================================================================
- def _process_geometry_column(
- self,
- operation: str,
- rename: str,
- returns_geom: bool = True,
- is_aggr: bool = False,
- *args,
- **kwargs,
- ) -> Union["GeoSeries", pspd.Series]:
- """
- Helper method to process a single geometry column with a specified operation.
- This method wraps the _query_geometry_column method for simpler convenient use.
-
- Parameters
- ----------
- operation : str
- The spatial operation to apply (e.g., 'ST_Area', 'ST_Buffer').
- rename : str
- The name of the resulting column. If empty, the old column name is maintained.
- args : tuple
- Positional arguments for the operation.
- kwargs : dict
- Keyword arguments for the operation.
-
- Returns
- -------
- GeoSeries
- A GeoSeries with the operation applied to the geometry column.
- """
- # Find the first column with BinaryType or GeometryType
- first_col = self.get_first_geometry_column() # TODO: fixme
-
- # Handle both positional and keyword arguments
- all_args = list(args)
- for k, v in kwargs.items():
- all_args.append(v)
-
- # Join all arguments as comma-separated values
- params = ""
- if all_args:
- params_list = [
- str(arg) if isinstance(arg, (int, float)) else repr(arg)
- for arg in all_args
- ]
- params = f", {', '.join(params_list)}"
-
- sql_expr = f"{operation}(`{first_col}`{params})"
-
- return self._query_geometry_column(
- sql_expr, first_col, rename, returns_geom=returns_geom, is_aggr=is_aggr
- )
-
def _query_geometry_column(
self,
- query: str,
- cols: Union[List[str], str],
- rename: str,
+ spark_col: PySparkColumn,
df: pyspark.sql.DataFrame = None,
returns_geom: bool = True,
is_aggr: bool = False,
@@ -712,12 +660,8 @@ class GeoSeries(GeoFrame, pspd.Series):
Parameters
----------
- query : str
+ spark_col : PySparkColumn
The query to apply to the geometry column.
- cols : List[str] or str
- The names of the columns to query.
- rename : str
- The name of the resulting column.
df : pyspark.sql.DataFrame
The dataframe to query. If not provided, the internal dataframe will be used.
returns_geom : bool, default True
@@ -730,46 +674,29 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries
A GeoSeries with the operation applied to the geometry column.
"""
- if not cols:
- raise ValueError("No valid geometry column found.")
df = self._internal.spark_frame if df is None else df
- if isinstance(cols, str):
- col = cols
- data_type = df.schema[col].dataType
- if isinstance(data_type, BinaryType):
- query = query.replace(f"`{cols}`", f"ST_GeomFromWKB(`{cols}`)")
-
- rename = col if not rename else rename
+ rename = self.name if self.name else SPARK_DEFAULT_SERIES_NAME
- elif isinstance(cols, list):
- for col in cols:
- data_type = df.schema[col].dataType
+ col_expr = spark_col.alias(rename)
- if isinstance(data_type, BinaryType):
- # the backticks here are important so we don't match strings that happen to be the same as the column name
- query = query.replace(f"`{col}`", f"ST_GeomFromWKB(`{col}`)")
-
- # must have rename for multiple columns since we don't know which name to default to
- assert rename
-
- query = f"{query} as `{rename}`"
-
- exprs = [query]
+ exprs = [col_expr]
index_spark_columns = []
index_fields = []
if not is_aggr:
# We always select NATURAL_ORDER_COLUMN_NAME, to avoid having to regenerate it in the result
# We always select SPARK_DEFAULT_INDEX_NAME, to retain series index info
- exprs.append(SPARK_DEFAULT_INDEX_NAME)
- exprs.append(NATURAL_ORDER_COLUMN_NAME)
+
+ exprs.append(scol_for(df, SPARK_DEFAULT_INDEX_NAME))
+ exprs.append(scol_for(df, NATURAL_ORDER_COLUMN_NAME))
+
index_spark_columns = [scol_for(df, SPARK_DEFAULT_INDEX_NAME)]
index_fields = [self._internal.index_fields[0]]
# else if is_aggr, we don't select the index columns
- sdf = df.selectExpr(*exprs)
+ sdf = df.select(*exprs)
internal = self._internal.copy(
spark_frame=sdf,
@@ -781,7 +708,12 @@ class GeoSeries(GeoFrame, pspd.Series):
)
ps_series = first_series(PandasOnSparkDataFrame(internal))
- return GeoSeries(ps_series) if returns_geom else ps_series
+ # Convert spark series default name to pandas series default name (None) if needed
+ series_name = None if rename == SPARK_DEFAULT_SERIES_NAME else rename
+ ps_series = ps_series.rename(series_name)
+
+ result = GeoSeries(ps_series) if returns_geom else ps_series
+ return result
# ============================================================================
# CONVERSION AND SERIALIZATION METHODS
@@ -841,7 +773,7 @@ class GeoSeries(GeoFrame, pspd.Series):
>>> index.size
2
"""
- geometry_column = self.get_first_geometry_column()
+ geometry_column = _get_series_col_name(self)
if geometry_column is None:
raise ValueError("No geometry column found in GeoSeries")
return SpatialIndex(self._internal.spark_frame, column_name=geometry_column)
@@ -896,8 +828,12 @@ class GeoSeries(GeoFrame, pspd.Series):
1 4.0
dtype: float64
"""
- return self._process_geometry_column(
- "ST_Area", rename="area", returns_geom=False
+
+ spark_col = stf.ST_Area(self.spark.column)
+
+ return self._query_geometry_column(
+ spark_col,
+ returns_geom=False,
)
@property
@@ -923,8 +859,10 @@ class GeoSeries(GeoFrame, pspd.Series):
1 POINT
dtype: object
"""
- result = self._process_geometry_column(
- "GeometryType", rename="geom_type", returns_geom=False
+ spark_col = stf.GeometryType(self.spark.column)
+ result = self._query_geometry_column(
+ spark_col,
+ returns_geom=False,
)
# Sedona returns the string in all caps unlike Geopandas
@@ -974,16 +912,30 @@ class GeoSeries(GeoFrame, pspd.Series):
3 4.828427
dtype: float64
"""
- col = self.get_first_geometry_column()
- select = f"""
- CASE
- WHEN GeometryType(`{col}`) IN ('LINESTRING', 'MULTILINESTRING') THEN ST_Length(`{col}`)
- WHEN GeometryType(`{col}`) IN ('POLYGON', 'MULTIPOLYGON') THEN ST_Perimeter(`{col}`)
- WHEN GeometryType(`{col}`) IN ('POINT', 'MULTIPOINT') THEN 0.0
- WHEN GeometryType(`{col}`) IN ('GEOMETRYCOLLECTION') THEN ST_Length(`{col}`) + ST_Perimeter(`{col}`)
- END"""
+
+ spark_expr = (
+ F.when(
+ stf.GeometryType(self.spark.column).isin(
+ ["LINESTRING", "MULTILINESTRING"]
+ ),
+ stf.ST_Length(self.spark.column),
+ )
+ .when(
+ stf.GeometryType(self.spark.column).isin(["POLYGON",
"MULTIPOLYGON"]),
+ stf.ST_Perimeter(self.spark.column),
+ )
+ .when(
+ stf.GeometryType(self.spark.column).isin(["POINT",
"MULTIPOINT"]),
+ 0.0,
+ )
+ .when(
+ stf.GeometryType(self.spark.column).isin(["GEOMETRYCOLLECTION"]),
+ stf.ST_Length(self.spark.column) + stf.ST_Perimeter(self.spark.column),
+ )
+ )
return self._query_geometry_column(
- select, col, rename="length", returns_geom=False
+ spark_expr,
+ returns_geom=False,
)
@property
@@ -1025,8 +977,11 @@ class GeoSeries(GeoFrame, pspd.Series):
--------
GeoSeries.is_valid_reason : reason for invalidity
"""
- result = self._process_geometry_column(
- "ST_IsValid", rename="is_valid", returns_geom=False
+
+ spark_col = stf.ST_IsValid(self.spark.column)
+ result = self._query_geometry_column(
+ spark_col,
+ returns_geom=False,
)
return to_bool(result)
@@ -1071,8 +1026,10 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.is_valid : detect invalid geometries
GeoSeries.make_valid : fix invalid geometries
"""
- return self._process_geometry_column(
- "ST_IsValidReason", rename="is_valid_reason", returns_geom=False
+ spark_col = stf.ST_IsValidReason(self.spark.column)
+ return self._query_geometry_column(
+ spark_col,
+ returns_geom=False,
)
@property
@@ -1104,8 +1061,10 @@ class GeoSeries(GeoFrame, pspd.Series):
--------
GeoSeries.isna : detect missing values
"""
- result = self._process_geometry_column(
- "ST_IsEmpty", rename="is_empty", returns_geom=False
+ spark_expr = stf.ST_IsEmpty(self.spark.column)
+ result = self._query_geometry_column(
+ spark_expr,
+ returns_geom=False,
)
return to_bool(result)
@@ -1244,13 +1203,16 @@ class GeoSeries(GeoFrame, pspd.Series):
"Array-like distance for dwithin not implemented yet."
)
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stp.ST_DWithin(F.col("L"), F.col("R"), F.lit(distance))
return self._row_wise_operation(
- f"ST_DWithin(`L`, `R`, {distance})",
- other,
- align,
- rename="dwithin",
+ spark_expr,
+ other_series,
+ align=align,
returns_geom=False,
- default_val="FALSE",
+ default_val=False,
)
def difference(self, other, align=None) -> "GeoSeries":
@@ -1354,11 +1316,15 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.union
GeoSeries.intersection
"""
+
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stf.ST_Difference(F.col("L"), F.col("R"))
return self._row_wise_operation(
- "ST_Difference(`L`, `R`)",
- other,
- align,
- rename="difference",
+ spark_expr,
+ other_series,
+ align=align,
returns_geom=True,
)
@@ -1389,8 +1355,10 @@ class GeoSeries(GeoFrame, pspd.Series):
1 True
dtype: bool
"""
- result = self._process_geometry_column(
- "ST_IsSimple", rename="is_simple", returns_geom=False
+ spark_expr = stf.ST_IsSimple(self.spark.column)
+ result = self._query_geometry_column(
+ spark_expr,
+ returns_geom=False,
)
return to_bool(result)
@@ -1453,8 +1421,10 @@ class GeoSeries(GeoFrame, pspd.Series):
1 True
dtype: bool
"""
- return self._process_geometry_column(
- "ST_HasZ", rename="has_z", returns_geom=False
+ spark_expr = stf.ST_HasZ(self.spark.column)
+ return self._query_geometry_column(
+ spark_expr,
+ returns_geom=False,
)
def get_precision(self):
@@ -1536,22 +1506,25 @@ class GeoSeries(GeoFrame, pspd.Series):
"""
# Sedona errors on negative indexes, so we use a case statement to handle it ourselves
- select = """
- ST_GeometryN(
- `L`,
- CASE
- WHEN ST_NumGeometries(`L`) + `R` < 0 THEN NULL
- WHEN `R` < 0 THEN ST_NumGeometries(`L`) + `R`
- ELSE `R`
- END
+ spark_expr = stf.ST_GeometryN(
+ F.col("L"),
+ F.when(
+ stf.ST_NumGeometries(F.col("L")) + F.col("R") < 0,
+ None,
+ )
+ .when(F.col("R") < 0, stf.ST_NumGeometries(F.col("L")) +
F.col("R"))
+ .otherwise(F.col("R")),
)
- """
+
+ other, _ = self._make_series_of_val(index)
+
+ # align = False either way
+ align = False
return self._row_wise_operation(
- select,
- index,
- align=False,
- rename="get_geometry",
+ spark_expr,
+ other,
+ align=align,
returns_geom=True,
default_val=None,
)
@@ -1590,15 +1563,15 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.exterior : outer boundary (without interior rings)
"""
- col = self.get_first_geometry_column()
# Geopandas and shapely return NULL for GeometryCollections, so we handle it separately
# https://shapely.readthedocs.io/en/stable/reference/shapely.boundary.html
- select = f"""
- CASE
- WHEN GeometryType(`{col}`) IN ('GEOMETRYCOLLECTION') THEN NULL
- ELSE ST_Boundary(`{col}`)
- END"""
- return self._query_geometry_column(select, col, rename="boundary")
+ spark_expr = F.when(
+ stf.GeometryType(self.spark.column).isin(["GEOMETRYCOLLECTION"]),
+ None,
+ ).otherwise(stf.ST_Boundary(self.spark.column))
+ return self._query_geometry_column(
+ spark_expr,
+ )
@property
def centroid(self) -> "GeoSeries":
@@ -1635,7 +1608,11 @@ class GeoSeries(GeoFrame, pspd.Series):
--------
GeoSeries.representative_point : point guaranteed to be within each geometry
"""
- return self._process_geometry_column("ST_Centroid", rename="centroid")
+ spark_expr = stf.ST_Centroid(self.spark.column)
+ return self._query_geometry_column(
+ spark_expr,
+ returns_geom=True,
+ )
def concave_hull(self, ratio=0.0, allow_holes=False):
# Implementation of the abstract method
@@ -1698,7 +1675,11 @@ class GeoSeries(GeoFrame, pspd.Series):
--------
GeoSeries.convex_hull : convex hull geometry
"""
- return self._process_geometry_column("ST_Envelope", rename="envelope")
+ spark_expr = stf.ST_Envelope(self.spark.column)
+ return self._query_geometry_column(
+ spark_expr,
+ returns_geom=True,
+ )
def minimum_rotated_rectangle(self):
# Implementation of the abstract method
@@ -1815,9 +1796,11 @@ class GeoSeries(GeoFrame, pspd.Series):
"Sedona only supports the 'structure' method for make_valid"
)
- col = self.get_first_geometry_column()
- select = f"ST_MakeValid(`{col}`, {keep_collapsed})"
- return self._query_geometry_column(select, col, rename="make_valid")
+ spark_expr = stf.ST_MakeValid(self.spark.column, keep_collapsed)
+ return self._query_geometry_column(
+ spark_expr,
+ returns_geom=True,
+ )
def reverse(self):
# Implementation of the abstract method
@@ -1897,10 +1880,9 @@ class GeoSeries(GeoFrame, pspd.Series):
return GeometryCollection()
- # returns_geom needs to be False here so we don't convert back to EWKB format.
- tmp = self._process_geometry_column(
- "ST_Union_Aggr", rename="union_all", is_aggr=True,
returns_geom=False
- )
+ spark_expr = sta.ST_Union_Aggr(self.spark.column)
+ tmp = self._query_geometry_column(spark_expr, returns_geom=False, is_aggr=True)
+
ps_series = tmp.take([0])
geom = ps_series.iloc[0]
return geom
@@ -2012,16 +1994,21 @@ class GeoSeries(GeoFrame, pspd.Series):
"""
# Sedona does not support GeometryCollection (errors), so we return NULL for now to avoid error
- select = """
- CASE
- WHEN GeometryType(`L`) == 'GEOMETRYCOLLECTION' OR GeometryType(`R`) == 'GEOMETRYCOLLECTION' THEN NULL
- ELSE ST_Crosses(`L`, `R`)
- END
- """
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+ spark_expr = F.when(
+ (stf.GeometryType(F.col("L")) == "GEOMETRYCOLLECTION")
+ | (stf.GeometryType(F.col("R")) == "GEOMETRYCOLLECTION"),
+ None,
+ ).otherwise(stp.ST_Crosses(F.col("L"), F.col("R")))
result = self._row_wise_operation(
- select, other, align, rename="crosses", default_val="FALSE"
+ spark_expr,
+ other_series,
+ align,
+ default_val=False,
)
+
return to_bool(result)
def disjoint(self, other, align=None):
@@ -2133,12 +2120,15 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.intersection
"""
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stp.ST_Intersects(F.col("L"), F.col("R"))
result = self._row_wise_operation(
- "ST_Intersects(`L`, `R`)",
- other,
+ spark_expr,
+ other_series,
align,
- rename="intersects",
- default_val="FALSE",
+ default_val=False,
)
return to_bool(result)
@@ -2234,12 +2224,15 @@ class GeoSeries(GeoFrame, pspd.Series):
# Note: We cannot efficiently match geopandas behavior because Sedona's ST_Overlaps returns True for equal geometries
# ST_Overlaps(`L`, `R`) AND ST_Equals(`L`, `R`) does not work because ST_Equals errors on invalid geometries
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stp.ST_Overlaps(F.col("L"), F.col("R"))
result = self._row_wise_operation(
- "ST_Overlaps(`L`, `R`)",
- other,
+ spark_expr,
+ other_series,
align,
- rename="overlaps",
- default_val="FALSE",
+ default_val=False,
)
return to_bool(result)
@@ -2347,12 +2340,15 @@ class GeoSeries(GeoFrame, pspd.Series):
"""
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stp.ST_Touches(F.col("L"), F.col("R"))
result = self._row_wise_operation(
- "ST_Touches(`L`, `R`)",
- other,
+ spark_expr,
+ other_series,
align,
- rename="touches",
- default_val="FALSE",
+ default_val=False,
)
return to_bool(result)
@@ -2462,12 +2458,16 @@ class GeoSeries(GeoFrame, pspd.Series):
--------
GeoSeries.contains
"""
+
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stp.ST_Within(F.col("L"), F.col("R"))
result = self._row_wise_operation(
- "ST_Within(`L`, `R`)",
- other,
+ spark_expr,
+ other_series,
align,
- rename="within",
- default_val="FALSE",
+ default_val=False,
)
return to_bool(result)
@@ -2578,12 +2578,16 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.covered_by
GeoSeries.overlaps
"""
+
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stp.ST_Covers(F.col("L"), F.col("R"))
result = self._row_wise_operation(
- "ST_Covers(`L`, `R`)",
- other,
+ spark_expr,
+ other_series,
align,
- rename="covers",
- default_val="FALSE",
+ default_val=False,
)
return to_bool(result)
@@ -2694,12 +2698,16 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.covers
GeoSeries.overlaps
"""
+
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stp.ST_CoveredBy(F.col("L"), F.col("R"))
result = self._row_wise_operation(
- "ST_CoveredBy(`L`, `R`)",
- other,
+ spark_expr,
+ other_series,
align,
- rename="covered_by",
- default_val="FALSE",
+ default_val=False,
)
return to_bool(result)
@@ -2790,8 +2798,15 @@ class GeoSeries(GeoFrame, pspd.Series):
dtype: float64
"""
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stf.ST_Distance(F.col("L"), F.col("R"))
result = self._row_wise_operation(
- "ST_Distance(`L`, `R`)", other, align, rename="distance",
default_val="NULL"
+ spark_expr,
+ other_series,
+ align,
+ default_val=None,
)
return result
@@ -2897,28 +2912,41 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.symmetric_difference
GeoSeries.union
"""
- return self._row_wise_operation(
- "ST_Intersection(`L`, `R`)",
- other,
+
+ other_series, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_expr = stf.ST_Intersection(F.col("L"), F.col("R"))
+ result = self._row_wise_operation(
+ spark_expr,
+ other_series,
align,
- rename="intersection",
returns_geom=True,
- default_val="NULL",
+ default_val=None,
)
+ return result
def _row_wise_operation(
self,
- select: str,
- other: Any,
+ spark_col: PySparkColumn,
+ other: pspd.Series,
align: Union[bool, None],
- rename: str,
returns_geom: bool = False,
- default_val: Union[str, None] = None,
+ default_val: Any = None,
):
"""
Helper function to perform a row-wise operation on two GeoSeries.
The self column and other column are aliased to `L` and `R`, respectively.
+ align : bool or None (default None)
+ If True, automatically aligns GeoSeries based on their indices. None defaults to True.
+ If False, the order of elements is preserved.
+ Note: align should also be set to False when 'other' is a geoseries created from a single object
+ (e.g. GeoSeries([Point(0, 0)] * len(self))), so that we align based on natural ordering
+ in case the index is not the default range index from 0.
+ Alternatively, we could create 'other' using the same index as self,
+ but that would require index=self.index.to_pandas() which is less scalable.
+
default_val : str or None (default "FALSE")
The value to use if either L or R is null. If None, nulls are not handled.
"""
@@ -2929,27 +2957,11 @@ class GeoSeries(GeoFrame, pspd.Series):
NATURAL_ORDER_COLUMN_NAME if align is False else SPARK_DEFAULT_INDEX_NAME
)
- if not isinstance(other, pspd.Series):
- # generator instead of a in-memory list
- data = [other for _ in range(len(self))]
-
- # e.g int, Geom, etc
- other = (
- GeoSeries(data)
- if isinstance(other, BaseGeometry)
- else pspd.Series(data)
- )
-
- # To make sure the result is the same length, we set natural column as the index
- # in case the index is not the default range index from 0.
- # Alternatively, we could create 'other' using the same index as self,
- # but that would require index=self.index.to_pandas() which is less scalable.
- index_col = NATURAL_ORDER_COLUMN_NAME
-
# This code assumes there is only one index (SPARK_DEFAULT_INDEX_NAME)
# and would need to be updated if Sedona later supports multi-index
+
df = self._internal.spark_frame.select(
- col(self.get_first_geometry_column()).alias("L"),
+ self.spark.column.alias("L"),
# For the left side:
# - We always select NATURAL_ORDER_COLUMN_NAME, to avoid having to regenerate it in the result
# - We always select SPARK_DEFAULT_INDEX_NAME, to retain series index info
@@ -2957,27 +2969,31 @@ class GeoSeries(GeoFrame, pspd.Series):
col(SPARK_DEFAULT_INDEX_NAME),
)
other_df = other._internal.spark_frame.select(
- col(_get_first_column_name(other)).alias("R"),
+ other.spark.column.alias("R"),
# for the right side, we only need the column that we are joining on
col(index_col),
)
+
joined_df = df.join(other_df, on=index_col, how="outer")
if default_val is not None:
# ps.Series.fillna() doesn't always work for the output for some reason
# so we manually handle the nulls here.
- select = f"""
+ spark_col = F.when(
+ F.col("L").isNull() | F.col("R").isNull(),
+ default_val,
+ ).otherwise(spark_col)
+ # The above is equivalent to the following:
+ f"""
CASE
WHEN `L` IS NULL OR `R` IS NULL THEN {default_val}
- ELSE {select}
+ ELSE {spark_col}
END
"""
return self._query_geometry_column(
- select,
- cols=["L", "R"],
- rename=rename,
- df=joined_df,
+ spark_col,
+ joined_df,
returns_geom=returns_geom,
)
@@ -3099,12 +3115,17 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.contains_properly
GeoSeries.within
"""
+
+ other, extended = self._make_series_of_val(other)
+ align = False if extended else align
+
+ spark_col = stp.ST_Contains(F.col("L"), F.col("R"))
result = self._row_wise_operation(
- "ST_Contains(`L`, `R`)",
+ spark_col,
other,
align,
- rename="contains",
- default_val="FALSE",
+ returns_geom=False,
+ default_val=False,
)
return to_bool(result)
@@ -3150,8 +3171,10 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries
A GeoSeries of buffered geometries.
"""
- return self._process_geometry_column(
- "ST_Buffer", rename="buffer", distance=distance
+ spark_col = stf.ST_Buffer(self.spark.column, distance)
+ return self._query_geometry_column(
+ spark_col,
+ returns_geom=True,
)
def to_parquet(self, path, **kwargs):
@@ -3164,16 +3187,8 @@ class GeoSeries(GeoFrame, pspd.Series):
- kwargs: Any
Additional arguments to pass to the Sedona DataFrame output function.
"""
- col = self.get_first_geometry_column()
- # Convert WKB to Sedona geometry objects
- # Specify returns_geom=False to avoid turning it back into EWKB
- result = self._query_geometry_column(
- f"`{col}`",
- cols=col,
- rename="wkb",
- returns_geom=False,
- )
+ result = self._query_geometry_column(self.spark.column)
# Use the Spark DataFrame's write method to write to GeoParquet format
result._internal.spark_frame.write.format("geoparquet").save(path, **kwargs)
@@ -3256,7 +3271,11 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.z
"""
- return self._process_geometry_column("ST_X", rename="x",
returns_geom=False)
+ spark_col = stf.ST_X(self.spark.column)
+ return self._query_geometry_column(
+ spark_col,
+ returns_geom=False,
+ )
@property
def y(self) -> pspd.Series:
@@ -3286,7 +3305,11 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.m
"""
- return self._process_geometry_column("ST_Y", rename="y",
returns_geom=False)
+ spark_col = stf.ST_Y(self.spark.column)
+ return self._query_geometry_column(
+ spark_col,
+ returns_geom=False,
+ )
@property
def z(self) -> pspd.Series:
@@ -3316,7 +3339,11 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.m
"""
- return self._process_geometry_column("ST_Z", rename="z",
returns_geom=False)
+ spark_col = stf.ST_Z(self.spark.column)
+ return self._query_geometry_column(
+ spark_col,
+ returns_geom=False,
+ )
@property
def m(self) -> pspd.Series:
@@ -3614,7 +3641,7 @@ class GeoSeries(GeoFrame, pspd.Series):
spark_df = data._internal.spark_frame
assert len(schema) == 1
spark_df = spark_df.withColumnRenamed(
- _get_first_column_name(data), schema[0].name
+ _get_series_col_name(data), schema[0].name
)
else:
spark_df = default_session().createDataFrame(data, schema=schema)
@@ -3685,10 +3712,10 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.notna : inverse of isna
GeoSeries.is_empty : detect empty geometries
"""
- col = self.get_first_geometry_column()
- select = f"`{col}` IS NULL"
+ spark_expr = F.isnull(self.spark.column)
result = self._query_geometry_column(
- select, col, rename="isna", returns_geom=False
+ spark_expr,
+ returns_geom=False,
)
return to_bool(result)
@@ -3730,11 +3757,11 @@ class GeoSeries(GeoFrame, pspd.Series):
GeoSeries.isna : inverse of notna
GeoSeries.is_empty : detect empty geometries
"""
- col = self.get_first_geometry_column()
- select = f"`{col}` IS NOT NULL"
-
+ # After Sedona's minimum spark version is 3.5.0, we can use F.isnotnull(self.spark.column) instead
+ spark_expr = ~F.isnull(self.spark.column)
result = self._query_geometry_column(
- select, col, rename="notna", returns_geom=False
+ spark_expr,
+ returns_geom=False,
)
return to_bool(result)
@@ -3828,45 +3855,45 @@ class GeoSeries(GeoFrame, pspd.Series):
"GeoSeries.fillna() with limit is not implemented yet."
)
- col = self.get_first_geometry_column()
+ align = True
+
if pd.isna(value) == True or isinstance(value, BaseGeometry):
if (
value is not None and pd.isna(value) == True
): # ie. value is np.nan or pd.NA:
- value = "NULL"
+ value = None
else:
if value is None:
from shapely.geometry import GeometryCollection
value = GeometryCollection()
- value = f"ST_GeomFromText('{value.wkt}')"
+ other, extended = self._make_series_of_val(value)
+ align = False if extended else align
- select = f"COALESCE(`{col}`, {value})"
- result = self._query_geometry_column(select, col, "")
elif isinstance(value, (GeoSeries, GeometryArray, gpd.GeoSeries)):
if not isinstance(value, GeoSeries):
value = GeoSeries(value)
# Replace all None's with empty geometries (this is a recursive call)
- value = value.fillna(None)
-
- # Coalesce: If the value in L is null, use the corresponding value in R for that row
- select = f"COALESCE(`L`, `R`)"
- result = self._row_wise_operation(
- select,
- value,
- align=None,
- rename="fillna",
- returns_geom=True,
- default_val=None,
- )
+ other = value.fillna(None)
+
else:
raise ValueError(f"Invalid value type: {type(value)}")
+ # Coalesce: If the value in L is null, use the corresponding value in R for that row
+ spark_expr = F.coalesce(F.col("L"), F.col("R"))
+ result = self._row_wise_operation(
+ spark_expr,
+ other,
+ align=align,
+ returns_geom=True,
+ default_val=None,
+ )
+
if inplace:
- self._update_anchor(_to_spark_pandas_df(result))
+ self._update_inplace(result)
return None
return result
@@ -3971,11 +3998,13 @@ class GeoSeries(GeoFrame, pspd.Series):
if old_crs.is_exact_same(crs):
return self
- col = self.get_first_geometry_column()
+ spark_expr = stf.ST_Transform(
+ self.spark.column,
+ F.lit(f"EPSG:{old_crs.to_epsg()}"),
+ F.lit(f"EPSG:{crs.to_epsg()}"),
+ )
return self._query_geometry_column(
- f"ST_Transform(`{col}`, 'EPSG:{old_crs.to_epsg()}',
'EPSG:{crs.to_epsg()}')",
- col,
- "",
+ spark_expr,
)
@property
@@ -4007,26 +4036,16 @@ class GeoSeries(GeoFrame, pspd.Series):
1 POLYGON ((0 0, 1 1, 1 0, 0 0)) 0.0 0.0 1.0 1.0
2 LINESTRING (0 1, 1 2) 0.0 1.0 1.0 2.0
"""
- col = self.get_first_geometry_column()
-
selects = [
- f"ST_XMin(`{col}`) as minx",
- f"ST_YMin(`{col}`) as miny",
- f"ST_XMax(`{col}`) as maxx",
- f"ST_YMax(`{col}`) as maxy",
+ stf.ST_XMin(self.spark.column).alias("minx"),
+ stf.ST_YMin(self.spark.column).alias("miny"),
+ stf.ST_XMax(self.spark.column).alias("maxx"),
+ stf.ST_YMax(self.spark.column).alias("maxy"),
]
df = self._internal.spark_frame
- data_type = df.schema[col].dataType
-
- if isinstance(data_type, BinaryType):
- selects = [
- select.replace(f"`{col}`", f"ST_GeomFromWKB(`{col}`)")
- for select in selects
- ]
-
- sdf = df.selectExpr(*selects)
+ sdf = df.select(*selects)
internal = InternalFrame(
spark_frame=sdf,
index_spark_columns=None,
@@ -4242,18 +4261,12 @@ class GeoSeries(GeoFrame, pspd.Series):
dtype: object
"""
- col = self.get_first_geometry_column()
- select = f"ST_AsBinary(`{col}`)"
+ spark_expr = stf.ST_AsBinary(self.spark.column)
if hex:
- # this is using pyspark's hex function since Sedona doesn't support hex WKB conversion at the moment
- # (it only supports hex EWKB)
- select = f"hex({select})"
-
+ spark_expr = F.hex(spark_expr)
return self._query_geometry_column(
- select,
- cols=col,
- rename="to_wkb",
+ spark_expr,
returns_geom=False,
)
@@ -4293,9 +4306,9 @@ class GeoSeries(GeoFrame, pspd.Series):
--------
GeoSeries.to_wkb
"""
- return self._process_geometry_column(
- "ST_AsText",
- rename="to_wkt",
+ spark_expr = stf.ST_AsText(self.spark.column)
+ return self._query_geometry_column(
+ spark_expr,
returns_geom=False,
)
@@ -4313,22 +4326,28 @@ class GeoSeries(GeoFrame, pspd.Series):
# # Utils
# -----------------------------------------------------------------------------
- def get_first_geometry_column(self) -> str:
- first_binary_or_geometry_col = next(
- (
- field.name
- for field in self._internal.spark_frame.schema.fields
- if isinstance(field.dataType, BinaryType)
- or field.dataType.typeName() == "geometrytype"
- ),
- None,
- )
- if first_binary_or_geometry_col:
- return first_binary_or_geometry_col
+ def _update_inplace(self, result: "GeoSeries"):
+ self.rename(result.name, inplace=True)
+ self._update_anchor(result._anchor)
- raise ValueError(
- "get_first_geometry_column: No geometry column found in the
GeoSeries."
- )
+ def _make_series_of_val(self, value: Any):
+ """
+ A helper method to turn single objects into series (ps.Series or GeoSeries when possible)
+ Returns:
+ tuple[pspd.Series, bool]:
+ - The series of the value
+ - Whether the returned value was a single object extended into a series (useful for the row-wise 'align' parameter)
+ """
+ # materialize the single value into an in-memory list matching self's length
+ if not isinstance(value, pspd.Series):
+ lst = [value for _ in range(len(self))]
+ if isinstance(value, BaseGeometry):
+ return GeoSeries(lst), True
+ else:
+ # e.g. int input
+ return pspd.Series(lst), True
+ else:
+ return value, False
# -----------------------------------------------------------------------------
@@ -4336,25 +4355,8 @@ class GeoSeries(GeoFrame, pspd.Series):
# -----------------------------------------------------------------------------
-def _get_first_column_name(series: pspd.Series) -> str:
- """
- Get the first column name of a Series.
-
- Parameters:
- - series: The input Series.
-
- Returns:
- - str: The first column name of the Series.
- """
- return next(
- field.name
- for field in series._internal.spark_frame.schema.fields
- if field.name not in (SPARK_DEFAULT_INDEX_NAME, NATURAL_ORDER_COLUMN_NAME)
- )
-
-
-def _to_spark_pandas_df(ps_series: pspd.Series) -> pspd.DataFrame:
- return pspd.DataFrame(ps_series._psdf._internal)
+def _get_series_col_name(ps_series: pspd.Series) -> str:
+ return ps_series.name if ps_series.name else SPARK_DEFAULT_SERIES_NAME
def to_bool(ps_series: pspd.Series, default: bool = False) -> pspd.Series:
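For context on the row-wise helpers above: callers now alias the two sides as `L` and `R`, join on an index column, and express the null fallback with F.when(...).otherwise(...) instead of a SQL CASE string. A hedged sketch of that shape, where left_df and right_df stand in for the internal frames and each exposes its geometry column (aliased L/R) plus the shared __index_level_0__ join key:

    from pyspark.sql import functions as F
    from sedona.spark.sql import st_predicates as stp

    joined_df = left_df.join(right_df, on="__index_level_0__", how="outer")
    # If either side is null, fall back to the default value (False here),
    # mirroring the default_val handling in _row_wise_operation
    expr = F.when(
        F.col("L").isNull() | F.col("R").isNull(), F.lit(False)
    ).otherwise(stp.ST_Intersects(F.col("L"), F.col("R")))
    result = joined_df.select(expr.alias("intersects"))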
diff --git a/python/tests/geopandas/test_geodataframe.py b/python/tests/geopandas/test_geodataframe.py
index c35854e047..4857435946 100644
--- a/python/tests/geopandas/test_geodataframe.py
+++ b/python/tests/geopandas/test_geodataframe.py
@@ -19,6 +19,7 @@ import tempfile
from shapely.geometry import (
Point,
+ Polygon,
)
import shapely
@@ -53,7 +54,38 @@ class TestDataframe(TestGeopandasBase):
sgpd_df = GeoDataFrame(obj)
check_geodataframe(sgpd_df)
- # These need to be separate to make sure Sedona's Geometry UDTs have been registered
+ @pytest.mark.parametrize(
+ "obj",
+ [
+ pd.DataFrame(
+ {
+ "non-geom": [1, 2, 3],
+ "geometry": [
Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]) for _ in range(3)
+ ],
+ }
+ ),
+ gpd.GeoDataFrame(
+ {
+ "geom2": [Point(x, x) for x in range(3)],
+ "non-geom": [4, 5, 6],
+ "geometry": [
Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]) for _ in range(3)
+ ],
+ }
+ ),
+ ],
+ )
+ def test_complex_df(self, obj):
+ sgpd_df = GeoDataFrame(obj)
+ name = "geometry"
+ sgpd_df.set_geometry(name, inplace=True)
+ check_geodataframe(sgpd_df)
+ result = sgpd_df.area
+ expected = pd.Series([1.0, 1.0, 1.0], name=name)
+ self.check_pd_series_equal(result, expected)
+
+ # These need to be defined inside the function to ensure Sedona's Geometry UDTs have been registered
def test_constructor_pandas_on_spark(self):
for obj in [
ps.DataFrame([Point(x, x) for x in range(3)]),
diff --git a/python/tests/geopandas/test_geoseries.py b/python/tests/geopandas/test_geoseries.py
index 0c2dc97831..66d6b75d11 100644
--- a/python/tests/geopandas/test_geoseries.py
+++ b/python/tests/geopandas/test_geoseries.py
@@ -20,6 +20,7 @@ import numpy as np
import pytest
import pandas as pd
import geopandas as gpd
+import pyspark.pandas as ps
import sedona.geopandas as sgpd
from sedona.geopandas import GeoSeries
from tests.geopandas.test_geopandas_base import TestGeopandasBase
@@ -65,6 +66,18 @@ class TestGeoSeries(TestGeopandasBase):
s = sgpd.GeoSeries([])
assert s.count() == 0
+ def test_non_geom_fails(self):
+ with pytest.raises(TypeError):
+ GeoSeries([0, 1, 2])
+ with pytest.raises(TypeError):
+ GeoSeries([0, 1, 2], crs="epsg:4326")
+ with pytest.raises(TypeError):
+ GeoSeries(["a", "b", "c"])
+ with pytest.raises(TypeError):
+ GeoSeries(pd.Series([0, 1, 2]), crs="epsg:4326")
+ with pytest.raises(TypeError):
+ GeoSeries(ps.Series([0, 1, 2]))
+
def test_area(self):
result = self.geoseries.area.to_pandas()
expected = pd.Series([0.0, 0.0, 5.23, 5.23])
diff --git a/python/tests/geopandas/test_match_geopandas_series.py b/python/tests/geopandas/test_match_geopandas_series.py
index a23d9db1f4..df4641d54b 100644
--- a/python/tests/geopandas/test_match_geopandas_series.py
+++ b/python/tests/geopandas/test_match_geopandas_series.py
@@ -128,18 +128,6 @@ class TestMatchGeopandasSeries(TestGeopandasBase):
assert isinstance(gpd_series, gpd.GeoSeries)
assert isinstance(gpd_series.geometry, gpd.GeoSeries)
- def test_non_geom_fails(self):
- with pytest.raises(TypeError):
- GeoSeries([0, 1, 2])
- with pytest.raises(TypeError):
- GeoSeries([0, 1, 2], crs="epsg:4326")
- with pytest.raises(TypeError):
- GeoSeries(["a", "b", "c"])
- with pytest.raises(TypeError):
- GeoSeries(pd.Series([0, 1, 2]), crs="epsg:4326")
- with pytest.raises(TypeError):
- GeoSeries(ps.Series([0, 1, 2]))
-
def test_to_geopandas(self):
for _, geom in self.geoms:
sgpd_result = GeoSeries(geom)
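For a quick sanity check of the refactored path end to end, a minimal usage sketch (assumes a Sedona-enabled Spark session, e.g. created via SedonaContext.create):

    from shapely.geometry import Point
    from sedona.geopandas import GeoSeries

    s = GeoSeries([Point(0, 0), Point(1, 1)])
    buffered = s.buffer(1.0)          # ST_Buffer as a Column expression
    print(buffered.area.to_pandas())  # ST_Area, no SQL strings involved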