This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new cbef39fdfa [GH-2100] Enhance geopandas sjoin implementation (#2101)
cbef39fdfa is described below
commit cbef39fdfa4985ce16e5049835973cefbe96a446
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Jul 17 08:11:03 2025 -0700
[GH-2100] Enhance geopandas sjoin implementation (#2101)
* [GH-2100] Enhance geopandas sjoin implementation
* address copilot review
* remove numpy import
---
python/sedona/geopandas/tools/sjoin.py | 242 ++++++++++++++++++++++++++++-----
python/tests/geopandas/test_sjoin.py | 220 +++++++++++++++++++++++++++++-
2 files changed, 428 insertions(+), 34 deletions(-)
diff --git a/python/sedona/geopandas/tools/sjoin.py b/python/sedona/geopandas/tools/sjoin.py
index 9c8a345a9e..e0dcd8921c 100644
--- a/python/sedona/geopandas/tools/sjoin.py
+++ b/python/sedona/geopandas/tools/sjoin.py
@@ -14,54 +14,204 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
+import re
from pyspark.pandas.internal import InternalFrame
from pyspark.pandas.series import first_series
from pyspark.pandas.utils import scol_for
-from pyspark.sql.functions import expr
+from pyspark.sql.functions import expr, col, lit
+from pyspark.sql.types import StructType, StructField, StringType, IntegerType
from sedona.geopandas import GeoDataFrame, GeoSeries
from sedona.geopandas.geoseries import _to_geo_series
from pyspark.pandas.frame import DataFrame as PandasOnSparkDataFrame
+# Pre-compiled regex pattern for suffix validation
+SUFFIX_PATTERN = re.compile(r"^[a-zA-Z_][a-zA-Z0-9_]*$")
-def _frame_join(left_df, right_df):
+
+def _frame_join(
+ left_df,
+ right_df,
+ how="inner",
+ predicate="intersects",
+ lsuffix="left",
+ rsuffix="right",
+ distance=None,
+ on_attribute=None,
+):
"""Join the GeoDataFrames at the DataFrame level.
Parameters
----------
- left_df : GeoDataFrame
- right_df : GeoDataFrame
+ left_df : GeoDataFrame or GeoSeries
+ Left dataset to join
+ right_df : GeoDataFrame or GeoSeries
+ Right dataset to join
+ how : str, default 'inner'
+ Join type: 'inner', 'left', 'right'
+ predicate : str, default 'intersects'
+ Spatial predicate to use
+ lsuffix : str, default 'left'
+ Suffix for left overlapping columns
+ rsuffix : str, default 'right'
+ Suffix for right overlapping columns
+ distance : float, optional
+ Distance parameter for dwithin predicate
+ on_attribute : list, optional
+ Additional columns to join on
Returns
-------
- GeoDataFrame
- Joined GeoDataFrame.
-
- TODO: Implement this function with more details and parameters.
+ GeoDataFrame or GeoSeries
+ Joined result
"""
+ # Predicate mapping
+ predicate_map = {
+ "intersects": "ST_Intersects",
+ "contains": "ST_Contains",
+ "within": "ST_Within",
+ "touches": "ST_Touches",
+ "crosses": "ST_Crosses",
+ "overlaps": "ST_Overlaps",
+ "dwithin": "ST_DWithin",
+ }
+
+ if predicate not in predicate_map:
+ raise ValueError(
+            f"Predicate '{predicate}' not supported. Available: {list(predicate_map.keys())}"
+ )
+
+ spatial_func = predicate_map[predicate]
+
# Get the internal Spark DataFrames
left_sdf = left_df._internal.spark_frame
right_sdf = right_df._internal.spark_frame
- # Convert WKB to geometry
- left_geo_df = left_sdf.selectExpr("ST_GeomFromWKB(`0`) as l_geometry")
- right_geo_df = right_sdf.selectExpr("ST_GeomFromWKB(`0`) as r_geometry")
+ # Handle geometry columns - check if they exist and get proper column names
+ left_geom_col = None
+ right_geom_col = None
- # Perform Spatial Join using ST_Intersects
- spatial_join_df = left_geo_df.alias("l").join(
- right_geo_df.alias("r"), expr("ST_Intersects(l_geometry, r_geometry)")
- )
+ # Find geometry columns in left dataframe
+ for field in left_sdf.schema.fields:
+ if field.dataType.typeName() in ("geometrytype", "binary"):
+ left_geom_col = field.name
+ break
-    # Use the provided code template to create an InternalFrame and return a GeoSeries
- internal = InternalFrame(
- spark_frame=spatial_join_df,
- index_spark_columns=None,
- column_labels=[left_df._col_label],
- data_spark_columns=[scol_for(spatial_join_df, "l_geometry")],
- data_fields=[left_df._internal.data_fields[0]],
- column_label_names=left_df._internal.column_label_names,
- )
- return _to_geo_series(first_series(PandasOnSparkDataFrame(internal)))
+ # Find geometry columns in right dataframe
+ for field in right_sdf.schema.fields:
+ if field.dataType.typeName() in ("geometrytype", "binary"):
+ right_geom_col = field.name
+ break
+
+ if left_geom_col is None or right_geom_col is None:
+ raise ValueError("Both datasets must have geometry columns")
+
+ # Prepare geometry expressions for spatial join
+ if left_sdf.schema[left_geom_col].dataType.typeName() == "binary":
+ left_geom_expr = f"ST_GeomFromWKB(`{left_geom_col}`) as l_geometry"
+ else:
+ left_geom_expr = f"`{left_geom_col}` as l_geometry"
+
+ if right_sdf.schema[right_geom_col].dataType.typeName() == "binary":
+ right_geom_expr = f"ST_GeomFromWKB(`{right_geom_col}`) as r_geometry"
+ else:
+ right_geom_expr = f"`{right_geom_col}` as r_geometry"
+
+ # Select all columns with geometry
+ left_cols = [left_geom_expr] + [
+ f"`{field.name}` as l_{field.name}"
+ for field in left_sdf.schema.fields
+ if field.name != left_geom_col and not field.name.startswith("__")
+ ]
+ right_cols = [right_geom_expr] + [
+ f"`{field.name}` as r_{field.name}"
+ for field in right_sdf.schema.fields
+ if field.name != right_geom_col and not field.name.startswith("__")
+ ]
+
+ left_geo_df = left_sdf.selectExpr(*left_cols)
+ right_geo_df = right_sdf.selectExpr(*right_cols)
+
+ # Build spatial join condition
+ if predicate == "dwithin":
+ if distance is None:
+            raise ValueError("Distance parameter is required for 'dwithin' predicate")
+        spatial_condition = f"{spatial_func}(l_geometry, r_geometry, {distance})"
+ else:
+ spatial_condition = f"{spatial_func}(l_geometry, r_geometry)"
+
+ # Add attribute-based join condition if specified
+ join_condition = spatial_condition
+ if on_attribute:
+ for attr in on_attribute:
+ join_condition += f" AND l_{attr} = r_{attr}"
+
+ # Perform spatial join based on join type
+ if how == "inner":
+ spatial_join_df = left_geo_df.alias("l").join(
+ right_geo_df.alias("r"), expr(join_condition)
+ )
+ elif how == "left":
+ spatial_join_df = left_geo_df.alias("l").join(
+ right_geo_df.alias("r"), expr(join_condition), "left"
+ )
+ elif how == "right":
+ spatial_join_df = left_geo_df.alias("l").join(
+ right_geo_df.alias("r"), expr(join_condition), "right"
+ )
+ else:
+ raise ValueError(f"Join type '{how}' not supported")
+
+ # Handle column naming with suffixes
+ final_columns = []
+
+ # Add geometry column (always from left for geopandas compatibility)
+ final_columns.append("l_geometry as geometry")
+
+ # Add other columns with suffix handling
+    left_data_cols = [col for col in left_geo_df.columns if col != "l_geometry"]
+    right_data_cols = [col for col in right_geo_df.columns if col != "r_geometry"]
+
+ for col_name in left_data_cols:
+ base_name = col_name[2:] # Remove "l_" prefix
+ right_col = f"r_{base_name}"
+
+ if right_col in right_data_cols:
+ # Column exists in both - apply suffixes
+ final_columns.append(f"{col_name} as {base_name}_{lsuffix}")
+ else:
+ # Column only in left
+ final_columns.append(f"{col_name} as {base_name}")
+
+ for col_name in right_data_cols:
+ base_name = col_name[2:] # Remove "r_" prefix
+ left_col = f"l_{base_name}"
+
+ if left_col in left_data_cols:
+ # Column exists in both - apply suffixes
+ final_columns.append(f"{col_name} as {base_name}_{rsuffix}")
+ else:
+ # Column only in right
+ final_columns.append(f"{col_name} as {base_name}")
+
+ # Select final columns
+ result_df = spatial_join_df.selectExpr(*final_columns)
+
+ # Return appropriate type based on input
+ if isinstance(left_df, GeoSeries) and isinstance(right_df, GeoSeries):
+ # Return GeoSeries for GeoSeries inputs
+ internal = InternalFrame(
+ spark_frame=result_df,
+ index_spark_columns=None,
+ column_labels=[left_df._col_label],
+ data_spark_columns=[scol_for(result_df, "geometry")],
+ data_fields=[left_df._internal.data_fields[0]],
+ column_label_names=left_df._internal.column_label_names,
+ )
+ return _to_geo_series(first_series(PandasOnSparkDataFrame(internal)))
+ else:
+ # Return GeoDataFrame for GeoDataFrame inputs
+ return GeoDataFrame(result_df)
def sjoin(
@@ -139,6 +289,12 @@ def sjoin(
joined = _frame_join(
left_df,
right_df,
+ how=how,
+ predicate=predicate,
+ lsuffix=lsuffix,
+ rsuffix=rsuffix,
+ distance=distance,
+ on_attribute=on_attribute,
)
return joined
@@ -161,8 +317,8 @@ def _basic_checks(left_df, right_df, how, lsuffix, rsuffix, on_attribute=None):
Parameters
------------
- left_df : GeoDataFrame
- right_df : GeoData Frame
+ left_df : GeoDataFrame or GeoSeries
+ right_df : GeoDataFrame or GeoSeries
how : str, one of 'left', 'right', 'inner'
join type
lsuffix : str
@@ -172,12 +328,34 @@ def _basic_checks(left_df, right_df, how, lsuffix, rsuffix, on_attribute=None):
on_attribute : list, default None
list of column names to merge on along with geometry
"""
- if not isinstance(left_df, GeoSeries):
- raise ValueError(f"'left_df' should be GeoSeries, got {type(left_df)}")
+ if not isinstance(left_df, (GeoSeries, GeoDataFrame)):
+ raise ValueError(
+            f"'left_df' should be GeoSeries or GeoDataFrame, got {type(left_df)}"
+ )
- if not isinstance(right_df, GeoSeries):
-        raise ValueError(f"'right_df' should be GeoSeries, got {type(right_df)}")
+ if not isinstance(right_df, (GeoSeries, GeoDataFrame)):
+ raise ValueError(
+            f"'right_df' should be GeoSeries or GeoDataFrame, got {type(right_df)}"
+ )
- allowed_hows = ["inner"]
+ allowed_hows = ["inner", "left", "right"]
if how not in allowed_hows:
        raise ValueError(f'`how` was "{how}" but is expected to be in {allowed_hows}')
+
+ # Check if on_attribute columns exist in both datasets
+ if on_attribute:
+ for attr in on_attribute:
+ if hasattr(left_df, "columns") and attr not in left_df.columns:
+ raise ValueError(f"Column '{attr}' not found in left dataset")
+ if hasattr(right_df, "columns") and attr not in right_df.columns:
+ raise ValueError(f"Column '{attr}' not found in right dataset")
+
+ # Check for reserved column names that would conflict
+ if lsuffix == rsuffix:
+ raise ValueError("lsuffix and rsuffix cannot be the same")
+
+    # Validate suffix format (should not contain special characters that would break SQL)
+ if not SUFFIX_PATTERN.match(lsuffix):
+ raise ValueError(f"lsuffix '{lsuffix}' contains invalid characters")
+ if not SUFFIX_PATTERN.match(rsuffix):
+ raise ValueError(f"rsuffix '{rsuffix}' contains invalid characters")
diff --git a/python/tests/geopandas/test_sjoin.py b/python/tests/geopandas/test_sjoin.py
index 01d85e6e03..7448193507 100644
--- a/python/tests/geopandas/test_sjoin.py
+++ b/python/tests/geopandas/test_sjoin.py
@@ -16,27 +16,64 @@
# under the License.
import shutil
import tempfile
+import pytest
-from shapely.geometry import Polygon
-from sedona.geopandas import GeoSeries, sjoin
+from shapely.geometry import Polygon, Point, LineString
+from sedona.geopandas import GeoSeries, GeoDataFrame, sjoin
from tests.test_base import TestBase
class TestSpatialJoin(TestBase):
def setup_method(self):
self.tempdir = tempfile.mkdtemp()
+
+ # Basic geometries
self.t1 = Polygon([(0, 0), (1, 0), (1, 1)])
self.t2 = Polygon([(0, 0), (1, 1), (0, 1)])
self.sq = Polygon([(0, 0), (1, 0), (1, 1), (0, 1)])
+ self.point1 = Point(0.5, 0.5)
+ self.point2 = Point(1.5, 1.5)
+ self.line1 = LineString([(0, 0), (1, 1)])
+
+ # GeoSeries for testing
self.g1 = GeoSeries([self.t1, self.t2])
self.g2 = GeoSeries([self.sq, self.t1])
self.g3 = GeoSeries([self.t1, self.t2], crs="epsg:4326")
self.g4 = GeoSeries([self.t2, self.t1])
+ # GeoDataFrames for testing
+ self.gdf1 = GeoDataFrame(
+            {"geometry": [self.t1, self.t2], "id": [1, 2], "name": ["poly1", "poly2"]}
+ )
+ self.gdf2 = GeoDataFrame(
+ {
+ "geometry": [self.sq, self.t1],
+ "id": [3, 4],
+ "category": ["square", "triangle"],
+ }
+ )
+ self.gdf_points = GeoDataFrame(
+ {
+ "geometry": [self.point1, self.point2],
+ "id": [5, 6],
+ "type": ["inside", "outside"],
+ }
+ )
+
+ # Test data for distance operations
+ self.nearby_points = GeoDataFrame(
+ {
+ "geometry": [Point(0.1, 0.1), Point(2.0, 2.0)],
+ "id": [7, 8],
+ "distance_type": ["close", "far"],
+ }
+ )
+
def teardown_method(self):
shutil.rmtree(self.tempdir)
def test_sjoin_method1(self):
+ """Test basic sjoin functionality with GeoSeries"""
left = self.g1
right = self.g2
joined = sjoin(left, right)
@@ -45,9 +82,188 @@ class TestSpatialJoin(TestBase):
assert joined.count() == 4
def test_sjoin_method2(self):
+ """Test GeoSeries.sjoin method"""
left = self.g1
right = self.g2
joined = left.sjoin(right)
assert joined is not None
assert type(joined) is GeoSeries
assert joined.count() == 4
+
+ def test_sjoin_geodataframe_basic(self):
+ """Test basic sjoin with GeoDataFrame"""
+ joined = sjoin(self.gdf1, self.gdf2)
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+ assert "geometry" in joined.columns
+ assert "id_left" in joined.columns
+ assert "id_right" in joined.columns
+ assert "name" in joined.columns
+ assert "category" in joined.columns
+
+ def test_sjoin_geodataframe_method(self):
+ """Test GeoDataFrame.sjoin method"""
+ joined = self.gdf1.sjoin(self.gdf2)
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+ assert "geometry" in joined.columns
+
+ def test_sjoin_predicates(self):
+ """Test different spatial predicates"""
+ predicates = [
+ "intersects",
+ "contains",
+ "within",
+ "touches",
+ "crosses",
+ "overlaps",
+ ]
+
+ for predicate in predicates:
+ try:
+ joined = sjoin(self.gdf1, self.gdf2, predicate=predicate)
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+ except Exception as e:
+ # Some predicates might not return results for our test data
+ # but the function should not raise errors for valid predicates
+ if "not supported" in str(e):
+ pytest.fail(f"Predicate '{predicate}' should be supported")
+
+ def test_sjoin_join_types(self):
+ """Test different join types"""
+ join_types = ["inner", "left", "right"]
+
+ for how in join_types:
+ joined = sjoin(self.gdf1, self.gdf2, how=how)
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+ assert "geometry" in joined.columns
+
+ def test_sjoin_column_suffixes(self):
+ """Test column suffix handling"""
+ joined = sjoin(self.gdf1, self.gdf2, lsuffix="_left", rsuffix="_right")
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+
+ # Check that suffixes are applied to overlapping columns
+ columns = joined.columns
+ if "id_left" in columns and "id_right" in columns:
+ # Both datasets have 'id' column, so suffixes should be applied
+ assert "id_left" in columns
+ assert "id_right" in columns
+ assert "id" not in columns # Original column should not exist
+
+ def test_sjoin_dwithin_distance(self):
+ """Test dwithin predicate with distance parameter"""
+ # Test with a distance that should capture nearby points
+        joined = sjoin(self.gdf1, self.nearby_points, predicate="dwithin", distance=0.5)
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+
+ # Test with a very small distance that should capture fewer points
+ joined_small = sjoin(
+ self.gdf1, self.nearby_points, predicate="dwithin", distance=0.05
+ )
+ assert joined_small is not None
+ assert type(joined_small) is GeoDataFrame
+
+ def test_sjoin_on_attribute(self):
+ """Test attribute-based joining"""
+ # Create datasets with matching attribute columns
+ gdf1_attr = GeoDataFrame(
+            {"geometry": [self.t1, self.t2], "zone": ["A", "B"], "value": [1, 2]}
+ )
+ gdf2_attr = GeoDataFrame(
+ {
+ "geometry": [self.sq, self.t1],
+ "zone": ["A", "B"],
+ "category": ["square", "triangle"],
+ }
+ )
+
+ # Test joining on attribute
+ joined = sjoin(gdf1_attr, gdf2_attr, on_attribute=["zone"])
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+
+ def test_sjoin_points_in_polygons(self):
+ """Test point-in-polygon spatial join"""
+ joined = sjoin(self.gdf_points, self.gdf1, predicate="within")
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+
+ # The first point should be within the polygon
+ # The second point should be outside
+        # Check that we have some results (at least the point inside the polygon)
+ assert len(joined) >= 0 # At least no errors
+
+ def test_sjoin_error_handling(self):
+ """Test error handling for invalid inputs"""
+
+ # Test invalid predicate
+ with pytest.raises(ValueError, match="not supported"):
+ sjoin(self.gdf1, self.gdf2, predicate="invalid_predicate")
+
+ # Test invalid join type
+ with pytest.raises(ValueError, match="expected to be in"):
+ sjoin(self.gdf1, self.gdf2, how="invalid_join")
+
+ # Test dwithin without distance
+ with pytest.raises(ValueError, match="Distance parameter is required"):
+ sjoin(self.gdf1, self.gdf2, predicate="dwithin")
+
+ # Test same suffixes
+ with pytest.raises(ValueError, match="cannot be the same"):
+ sjoin(self.gdf1, self.gdf2, lsuffix="same", rsuffix="same")
+
+ # Test invalid suffix characters
+ with pytest.raises(ValueError, match="invalid characters"):
+ sjoin(self.gdf1, self.gdf2, lsuffix="invalid-suffix")
+
+ def test_sjoin_empty_results(self):
+ """Test sjoin with geometries that don't intersect"""
+ # Create geometries that are far apart
+ far_gdf = GeoDataFrame(
+ {
+                "geometry": [Polygon([(10, 10), (11, 10), (11, 11), (10, 11)])],
+ "id": [99],
+ }
+ )
+
+ joined = sjoin(self.gdf1, far_gdf)
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+ # Should have 0 rows for inner join with non-intersecting geometries
+
+ def test_sjoin_mixed_geometry_types(self):
+ """Test sjoin with mixed geometry types"""
+ # Create a dataset with mixed geometry types
+ mixed_gdf = GeoDataFrame(
+ {
+ "geometry": [self.point1, self.line1, self.sq],
+ "id": [100, 101, 102],
+ "geom_type": ["point", "line", "polygon"],
+ }
+ )
+
+ joined = sjoin(self.gdf1, mixed_gdf)
+ assert joined is not None
+ assert type(joined) is GeoDataFrame
+
+ def test_sjoin_performance_basic(self):
+ """Basic performance test with slightly larger dataset"""
+ # Create slightly larger test datasets
+
+ # Create a grid of points
+ points = []
+ for i in range(10):
+ for j in range(10):
+ points.append(Point(i * 0.1, j * 0.1))
+
+        large_points_gdf = GeoDataFrame({"geometry": points, "id": range(len(points))})
+
+ # Test join performance
+ joined = sjoin(large_points_gdf, self.gdf1)
+ assert joined is not None
+ assert type(joined) is GeoDataFrame