This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 629d47958c [GH-2066] Implement spatial index sindex query() and unit tests (#2067)
629d47958c is described below
commit 629d47958cadcdfa82a41d1825f2bb410de90c37
Author: Feng Zhang <[email protected]>
AuthorDate: Wed Jul 9 11:48:02 2025 -0700
[GH-2066] Implement spatial index sindex query() and unit tests (#2067)
* [GH-2066] Implement GeoSeries sindex and tests
* fix test_spatial_index_with_shapely_array
* fix NumPy comparisons in the test
* do not test internal objects
* address the copilot comments and refactor __init__
* fix spark mode count
* add index building logic
* add query implementation for spatial index
* remove unused code
* fix test
* shapely only supports strtree
* fix results and add log_advice to query()
* switch to use StructuredAdapter
---
python/sedona/geopandas/base.py | 2 +-
python/sedona/geopandas/geodataframe.py | 18 ++-
python/sedona/geopandas/geoindex.py | 28 ----
python/sedona/geopandas/geoseries.py | 29 +++-
python/sedona/geopandas/sindex.py | 225 ++++++++++++++++++++++++++++++++
python/tests/geopandas/test_sindex.py | 145 ++++++++++++++++++++
6 files changed, 410 insertions(+), 37 deletions(-)
diff --git a/python/sedona/geopandas/base.py b/python/sedona/geopandas/base.py
index f99a19cff8..8bc1ad4dd3 100644
--- a/python/sedona/geopandas/base.py
+++ b/python/sedona/geopandas/base.py
@@ -75,7 +75,7 @@ class GeoFrame(metaclass=ABCMeta):
@property
@abstractmethod
- def geoindex(self) -> "GeoIndex":
+ def sindex(self) -> "SpatialIndex":
raise NotImplementedError("This method is not implemented yet.")
@abstractmethod
diff --git a/python/sedona/geopandas/geodataframe.py
b/python/sedona/geopandas/geodataframe.py
index a2f90dff4b..b627d06a7d 100644
--- a/python/sedona/geopandas/geodataframe.py
+++ b/python/sedona/geopandas/geodataframe.py
@@ -28,7 +28,7 @@ from pyspark.pandas.internal import InternalFrame
from sedona.geopandas._typing import Label
from sedona.geopandas.base import GeoFrame
-from sedona.geopandas.geoindex import GeoIndex
+from sedona.geopandas.sindex import SpatialIndex
class GeoDataFrame(GeoFrame, pspd.DataFrame):
@@ -250,9 +250,19 @@ class GeoDataFrame(GeoFrame, pspd.DataFrame):
raise NotImplementedError("This method is not implemented yet.")
@property
- def geoindex(self) -> GeoIndex:
- # Implementation of the abstract method
- raise NotImplementedError("This method is not implemented yet.")
+ def sindex(self) -> SpatialIndex | None:
+ """
+ Returns a spatial index for the GeoDataFrame.
+ The spatial index allows for efficient spatial queries. If the spatial
+ index cannot be created (e.g., no geometry column is present), this
+ property will return None.
+ Returns:
+ - SpatialIndex: The spatial index for the GeoDataFrame.
+ - None: If the spatial index is not supported.
+ """
+ if "geometry" in self.columns:
+ return SpatialIndex(self._internal.spark_frame,
column_name="geometry")
+ return None
def copy(self, deep=False):
"""
diff --git a/python/sedona/geopandas/geoindex.py
b/python/sedona/geopandas/geoindex.py
deleted file mode 100644
index 4dbc04b742..0000000000
--- a/python/sedona/geopandas/geoindex.py
+++ /dev/null
@@ -1,28 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-
-class GeoIndex:
- """
- A placeholder class for GeoIndex.
- """
-
- def __init__(self):
- raise NotImplementedError("This method is not implemented yet.")
-
- def some_method(self):
- raise NotImplementedError("This method is not implemented yet.")
diff --git a/python/sedona/geopandas/geoseries.py
b/python/sedona/geopandas/geoseries.py
index 2e4dc6bb2a..6160438a60 100644
--- a/python/sedona/geopandas/geoseries.py
+++ b/python/sedona/geopandas/geoseries.py
@@ -37,7 +37,7 @@ from shapely.geometry.base import BaseGeometry
from sedona.geopandas._typing import Label
from sedona.geopandas.base import GeoFrame
from sedona.geopandas.geodataframe import GeoDataFrame
-from sedona.geopandas.geoindex import GeoIndex
+from sedona.geopandas.sindex import SpatialIndex
from pyspark.pandas.internal import (
SPARK_DEFAULT_INDEX_NAME, # __index_level_0__
@@ -506,9 +506,30 @@ class GeoSeries(GeoFrame, pspd.Series):
return self
@property
- def geoindex(self) -> "GeoIndex":
- # Implementation of the abstract method
- raise NotImplementedError("This method is not implemented yet.")
+ def sindex(self) -> SpatialIndex:
+ """
+ Returns a spatial index built from the geometries.
+
+ Returns
+ -------
+ SpatialIndex
+ The spatial index for this GeoDataFrame.
+
+ Examples
+ --------
+ >>> from shapely.geometry import Point
+ >>> from sedona.geopandas import GeoDataFrame
+ >>>
+ >>> gdf = GeoDataFrame([{"geometry": Point(1, 1), "value": 1},
+ ... {"geometry": Point(2, 2), "value": 2}])
+ >>> index = gdf.sindex
+ >>> index.size
+ 2
+ """
+ geometry_column = self.get_first_geometry_column()
+ if geometry_column is None:
+ raise ValueError("No geometry column found in GeoSeries")
+ return SpatialIndex(self._internal.spark_frame,
column_name=geometry_column)
def copy(self, deep=False):
"""
diff --git a/python/sedona/geopandas/sindex.py
b/python/sedona/geopandas/sindex.py
new file mode 100644
index 0000000000..426b2d12ab
--- /dev/null
+++ b/python/sedona/geopandas/sindex.py
@@ -0,0 +1,225 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+from pyspark.pandas.utils import log_advice
+from pyspark.sql import DataFrame as PySparkDataFrame
+
+from sedona.spark import StructuredAdapter
+from sedona.spark.core.enums import IndexType
+
+
+class SpatialIndex:
+ """
+ A wrapper around Sedona's spatial index functionality.
+ """
+
+ def __init__(self, geometry, index_type="strtree", column_name=None):
+ """
+ Initialize the SpatialIndex with geometry data.
+
+ Parameters
+ ----------
+ geometry : np.array of Shapely geometries, PySparkDataFrame column, or
PySparkDataFrame
+ index_type : str, default "strtree"
+ The type of spatial index to use.
+ column_name : str, optional
+ The column name to extract geometry from if `geometry` is a
PySparkDataFrame.
+ """
+
+ if isinstance(geometry, np.ndarray):
+ self.geometry = geometry
+ self.index_type = index_type
+ self._dataframe = None
+ self._is_spark = False
+ # Build local index for numpy array
+ self._build_local_index()
+ elif isinstance(geometry, PySparkDataFrame):
+ if column_name is None:
+ raise ValueError(
+ "column_name must be specified when geometry is a
PySparkDataFrame"
+ )
+ self.geometry = geometry[column_name]
+ self.index_type = index_type
+ self._dataframe = geometry
+ self._is_spark = True
+ # Build distributed spatial index
+ self._build_spark_index(column_name)
+ else:
+ raise TypeError(
+ "Invalid type for `geometry`. Expected np.array or
PySparkDataFrame."
+ )
+
+ def query(self, geometry, predicate=None, sort=False):
+ """
+ Query the spatial index for geometries that intersect the given
geometry.
+
+ Parameters
+ ----------
+ geometry : Shapely geometry
+ The geometry to query against the spatial index.
+ predicate : str, optional
+ Spatial predicate to filter results (e.g., 'intersects',
'contains').
+ sort : bool, optional, default False
+ Whether to sort the results.
+
+ Returns
+ -------
+ list
+ List of indices of matching geometries.
+ """
+ log_advice(
+ "`query` returns local list of indices of matching geometries onto
driver's memory. "
+ "It should only be used if the resulting collection is expected to
be small."
+ )
+
+ if self.is_empty:
+ return []
+
+ if self._is_spark:
+ # For Spark-based spatial index
+ from sedona.spark.core.spatialOperator import RangeQuery
+
+ # Execute the spatial range query
+ if predicate == "contains":
+ result_rdd = RangeQuery.SpatialRangeQuery(
+ self._indexed_rdd, geometry, True, True
+ )
+ else: # Default to intersects
+ result_rdd = RangeQuery.SpatialRangeQuery(
+ self._indexed_rdd, geometry, False, True
+ )
+
+ results = result_rdd.collect()
+ return results
+ else:
+ # For local spatial index based on Shapely STRtree
+ if predicate == "contains":
+ # STRtree doesn't directly support contains predicate
+ # We need to filter results after querying
+ candidate_indices = self._index.query(geometry)
+ results = [
+ i for i in candidate_indices if
geometry.contains(self.geometry[i])
+ ]
+ else:
+ # Default is intersects
+ results = self._index.query(geometry)
+
+ if sort and results:
+ # Sort by distance to the query geometry if requested
+ results = sorted(
+ results, key=lambda i: self.geometry[i].distance(geometry)
+ )
+
+ return results
+
+ def nearest(self, geometry, k=1, return_distance=False):
+ """
+ Find the nearest geometry in the spatial index.
+
+ Parameters
+ ----------
+ geometry : Shapely geometry
+ The geometry to find the nearest neighbor for.
+ k : int, optional, default 1
+ Number of nearest neighbors to find.
+ return_distance : bool, optional, default False
+ Whether to return distances along with indices.
+
+ Returns
+ -------
+ list or tuple
+ List of indices of nearest geometries, optionally with distances.
+ """
+ # Placeholder for KNN query using Sedona
+ raise NotImplementedError("This method is not implemented yet.")
+
+ def intersection(self, bounds):
+ """
+ Find geometries that intersect the given bounding box.
+
+ Parameters
+ ----------
+ bounds : tuple
+ Bounding box as (min_x, min_y, max_x, max_y).
+
+ Returns
+ -------
+ list
+ List of indices of matching geometries.
+ """
+ raise NotImplementedError("This method is not implemented yet.")
+
+ @property
+ def size(self):
+ """
+ Get the size of the spatial index.
+
+ Returns
+ -------
+ int
+ Number of geometries in the index.
+ """
+ if self._is_spark:
+ return self._dataframe.count()
+ return len(self.geometry)
+
+ @property
+ def is_empty(self):
+ """
+ Check if the spatial index is empty.
+
+ Returns
+ -------
+ bool
+ True if the index is empty, False otherwise.
+ """
+ return self.size == 0
+
+ def _build_spark_index(self, column_name):
+ """
+ Build a distributed spatial index on the geometry column of the
DataFrame.
+
+ This uses Sedona's built-in indexing functionality.
+ """
+
+ # Convert index_type string to Sedona IndexType enum
+ index_type_map = {"strtree": IndexType.RTREE, "quadtree":
IndexType.QUADTREE}
+ sedona_index_type = index_type_map.get(self.index_type.lower(),
IndexType.RTREE)
+
+ # Create a SpatialRDD from the DataFrame
+ spatial_rdd = StructuredAdapter.toSpatialRdd(self._dataframe,
column_name)
+
+ # Build spatial index
+ spatial_rdd.buildIndex(sedona_index_type, False)
+
+ # Store the indexed RDD
+ self._indexed_rdd = spatial_rdd
+
+ def _build_local_index(self):
+ """
+ Build a local spatial index for numpy array of geometries.
+ """
+ from shapely.strtree import STRtree
+
+ if len(self.geometry) > 0:
+ if self.index_type.lower() == "strtree":
+ self._index = STRtree(self.geometry)
+ else:
+ raise ValueError(
+ f"Unsupported index type: {self.index_type}. Only
'strtree' is supported for local indexing."
+ )
diff --git a/python/tests/geopandas/test_sindex.py
b/python/tests/geopandas/test_sindex.py
new file mode 100644
index 0000000000..a6e76b65d0
--- /dev/null
+++ b/python/tests/geopandas/test_sindex.py
@@ -0,0 +1,145 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import numpy as np
+from pyspark.sql.functions import expr
+from shapely.geometry import Point, Polygon, LineString
+
+from tests.test_base import TestBase
+from sedona.geopandas import GeoSeries
+from sedona.geopandas.sindex import SpatialIndex
+
+
+class TestSpatialIndex(TestBase):
+ """Tests for the spatial index functionality in GeoSeries."""
+
+ def setup_method(self):
+ """Set up test data."""
+ # Create a GeoSeries with point geometries
+ self.points = GeoSeries(
+ [Point(0, 0), Point(1, 1), Point(2, 2), Point(3, 3), Point(4, 4)]
+ )
+
+ # Create a GeoSeries with polygon geometries
+ self.polygons = GeoSeries(
+ [
+ Polygon([(0, 0), (1, 0), (1, 1), (0, 1)]),
+ Polygon([(1, 1), (2, 1), (2, 2), (1, 2)]),
+ Polygon([(2, 2), (3, 2), (3, 3), (2, 3)]),
+ Polygon([(3, 3), (4, 3), (4, 4), (3, 4)]),
+ Polygon([(4, 4), (5, 4), (5, 5), (4, 5)]),
+ ]
+ )
+
+ # Create a GeoSeries with line geometries
+ self.lines = GeoSeries(
+ [
+ LineString([(0, 0), (1, 1)]),
+ LineString([(1, 1), (2, 2)]),
+ LineString([(2, 2), (3, 3)]),
+ LineString([(3, 3), (4, 4)]),
+ LineString([(4, 4), (5, 5)]),
+ ]
+ )
+
+ def test_sindex_property_exists(self):
+ """Test that the sindex property exists on GeoSeries."""
+ assert hasattr(self.points, "sindex")
+ assert hasattr(self.polygons, "sindex")
+ assert hasattr(self.lines, "sindex")
+
+ def test_query_with_point(self):
+ """Test querying the spatial index with a point geometry."""
+ # Create a list of Shapely geometries - squares around points (0,0),
(1,1), etc.
+ geometries = [
+ Polygon(
+ [
+ (i - 0.5, j - 0.5),
+ (i + 0.5, j - 0.5),
+ (i + 0.5, j + 0.5),
+ (i - 0.5, j + 0.5),
+ ]
+ )
+ for i, j in [(0, 0), (1, 1), (2, 2), (3, 3), (4, 4)]
+ ]
+
+ # Create a spatial index from the geometries
+ geom_array = np.array(geometries, dtype=object)
+ sindex = SpatialIndex(geom_array)
+
+ # Test query with a point that should intersect with one polygon
+ query_point = Point(2.2, 2.2)
+ result_indices = sindex.query(query_point)
+ assert len(result_indices) == 1
+
+ # Test query with a point that intersects no polygons
+ empty_point = Point(10, 10)
+ empty_results = sindex.query(empty_point)
+ assert len(empty_results) == 0
+
+ def test_query_with_spark_dataframe(self):
+ """Test querying the spatial index with a Spark DataFrame."""
+ # Create a spatial DataFrame with polygons
+ polygons_data = [
+ (1, "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))"),
+ (2, "POLYGON((1 1, 2 1, 2 2, 1 2, 1 1))"),
+ (3, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"),
+ (4, "POLYGON((3 3, 4 3, 4 4, 3 4, 3 3))"),
+ (5, "POLYGON((4 4, 5 4, 5 5, 4 5, 4 4))"),
+ ]
+
+ df = self.spark.createDataFrame(polygons_data, ["id", "wkt"])
+ spatial_df = df.withColumn("geometry", expr("ST_GeomFromWKT(wkt)"))
+
+ # Create a SpatialIndex from the DataFrame
+ sindex = SpatialIndex(spatial_df, index_type="strtree",
column_name="geometry")
+
+ # Test query with a point that should intersect with one polygon
+ from shapely.geometry import Point
+
+ query_point = Point(2.2, 2.2)
+
+ # Execute query
+ result_indices = sindex.query(query_point, "contains")
+
+ # Verify results - should find at least one result (polygon containing
the point)
+ assert len(result_indices) > 0
+
+ # Test query with a polygon that should intersect multiple polygons
+ from shapely.geometry import box
+
+ query_box = box(1.5, 1.5, 3.5, 3.5)
+
+ # Execute query
+ box_results = sindex.query(query_box, predicate="contains")
+
+ # Verify results - should find multiple polygons
+ assert len(box_results) > 1
+
+ # Test with contains predicate
+ # The query box fully contains polygon at index 2 (POLYGON((2 2, 3 2,
3 3, 2 3, 2 2)))
+ contains_results = sindex.query(query_box, predicate="contains")
+
+ # Verify contains results
+ assert len(contains_results) >= 1
+
+ # Test with a point outside any polygon
+ outside_point = Point(10, 10)
+ outside_results = sindex.query(outside_point)
+
+ # Verify no results for point outside
+ assert len(outside_results) == 0