This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 57146118c0 [GH-2127] Add Shapely and WKT geometry support to STAC reader (#2128)
57146118c0 is described below
commit 57146118c0c49912fae3ebbcd9d27c98c8b2acb2
Author: Feng Zhang <[email protected]>
AuthorDate: Sun Jul 20 23:25:55 2025 -0700
[GH-2127] Add Shapely and WKT geometry support to STAC reader (#2128)
* [GH-2127] Add Shapely and WKT geometry support to STAC reader
Enhanced the STAC reader to support Shapely geometry objects and WKT
strings as spatial filters, providing more flexibility than
the existing bbox-only approach. This improvement goes beyond
pystac-client's capabilities while maintaining full backward
compatibility.
New Features
- New geometry parameter added to all STAC client methods (search,
get_items, get_dataframe, save_to_geoparquet)
- Multiple input format support:
- Shapely geometry objects (Polygon, etc.)
- WKT strings (e.g., "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))")
- Lists of mixed geometries and WKT strings
- Smart precedence: When both bbox and geometry are provided, geometry
takes precedence
- Seamless integration with existing datetime and other filters
Usage Examples
from shapely.geometry import Polygon
# Using WKT string
items = client.search(geometry="POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))")
# Using Shapely polygon
polygon = Polygon([(0, 0), (1, 0), (1, 1), (0, 1), (0, 0)])
items = client.search(geometry=polygon)
# Using list of geometries
geometries = [polygon, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"]
items = client.search(geometry=geometries)
# Combined with other filters
items = client.search(
geometry=polygon,
datetime=["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]
)
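The same geometry parameter also flows through the collection-level helpers (get_items, get_dataframe, save_to_geoparquet). A minimal end-to-end sketch following the patterns used in the tests in this commit; the catalog URL, collection id, and output path are illustrative only:
from shapely.geometry import Polygon
from sedona.spark.stac.client import Client
# Open a STAC catalog and pick a collection (values are examples)
client = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
collection = client.get_collection("aster-l1t")
polygon = Polygon([(90, -73), (105, -73), (105, -69), (90, -69), (90, -73)])
# Spatially filtered Spark DataFrame
df = collection.get_dataframe(geometry=polygon)
# Persist the filtered items as GeoParquet (output path is an example)
collection.save_to_geoparquet(
    output_path="/tmp/aster_l1t_filtered",
    geometry=polygon,
    datetime=[["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"]],
)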
* Update python/sedona/spark/stac/collection_client.py
Co-authored-by: Copilot <[email protected]>
* revert incorrect refactoring
---------
Co-authored-by: Copilot <[email protected]>
---
python/sedona/spark/stac/client.py | 23 ++++-
python/sedona/spark/stac/collection_client.py | 76 +++++++++++++---
python/tests/stac/test_collection_client.py | 125 ++++++++++++++++++++++++++
3 files changed, 208 insertions(+), 16 deletions(-)
diff --git a/python/sedona/spark/stac/client.py b/python/sedona/spark/stac/client.py
index f4af0eeb1b..ac954ea17f 100644
--- a/python/sedona/spark/stac/client.py
+++ b/python/sedona/spark/stac/client.py
@@ -14,12 +14,13 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-from typing import Union, Optional, Iterator
+from typing import Union, Optional, Iterator, List
from sedona.spark.stac.collection_client import CollectionClient
import datetime as python_datetime
from pystac import Item as PyStacItem
+from shapely.geometry.base import BaseGeometry
from pyspark.sql import DataFrame
@@ -77,6 +78,9 @@ class Client:
*ids: Union[str, list],
collection_id: Optional[str] = None,
bbox: Optional[list] = None,
+ geometry: Optional[
+ Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]
+ ] = None,
datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
max_items: Optional[int] = None,
return_dataframe: bool = True,
@@ -95,6 +99,11 @@ class Client:
Each bounding box is represented as a list of four float values: [min_lon, min_lat, max_lon, max_lat].
Example: [[-180.0, -90.0, 180.0, 90.0]] # This bounding box covers the entire world.
+ - geometry (Optional[Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]]): Shapely geometry object(s) or WKT string(s) for spatial filtering.
+ Can be a single geometry, WKT string, or a list of geometries/WKT strings.
+ If both bbox and geometry are provided, geometry takes precedence.
+ Example: Polygon(...) or "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))" or [Polygon(...), Polygon(...)]
+
- datetime (Optional[Union[str, python_datetime.datetime, list]]): A single datetime, RFC 3339-compliant timestamp,
or a list of date-time ranges for filtering the items. The datetime can be specified in various formats:
- "YYYY" expands to ["YYYY-01-01T00:00:00Z", "YYYY-12-31T23:59:59Z"]
@@ -119,9 +128,17 @@ class Client:
client = self.get_collection_from_catalog()
if return_dataframe:
return client.get_dataframe(
- *ids, bbox=bbox, datetime=datetime, max_items=max_items
+ *ids,
+ bbox=bbox,
+ geometry=geometry,
+ datetime=datetime,
+ max_items=max_items,
)
else:
return client.get_items(
- *ids, bbox=bbox, datetime=datetime, max_items=max_items
+ *ids,
+ bbox=bbox,
+ geometry=geometry,
+ datetime=datetime,
+ max_items=max_items,
)
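For reference, the two return modes of Client.search touched in the hunk above can be exercised as follows; a hedged sketch in which the catalog URL and collection id are illustrative:
from shapely.geometry import Polygon
from sedona.spark.stac.client import Client
client = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")  # example URL
polygon = Polygon([(0, 0), (1, 0), (1, 1), (0, 1), (0, 0)])
# Default: a Spark DataFrame of matching items, filtered by the geometry
df = client.search(collection_id="aster-l1t", geometry=polygon)
# return_dataframe=False: an iterator of pystac Items instead
items = list(
    client.search(collection_id="aster-l1t", geometry=polygon, return_dataframe=False)
)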
diff --git a/python/sedona/spark/stac/collection_client.py b/python/sedona/spark/stac/collection_client.py
index 0d04c06575..bc020cfdf8 100644
--- a/python/sedona/spark/stac/collection_client.py
+++ b/python/sedona/spark/stac/collection_client.py
@@ -16,13 +16,14 @@
# under the License.
import logging
-from typing import Iterator, Union
+from typing import Iterator, Union, List
from typing import Optional
import datetime as python_datetime
from pyspark.sql import DataFrame, SparkSession
from pyspark.sql.types import dt
from pystac import Item as PyStacItem
+from shapely.geometry.base import BaseGeometry
def get_collection_url(url: str, collection_id: Optional[str] = None) -> str:
@@ -101,7 +102,7 @@ class CollectionClient:
@staticmethod
def _apply_spatial_temporal_filters(
- df: DataFrame, bbox=None, datetime=None
+ df: DataFrame, bbox=None, geometry=None, datetime=None
) -> DataFrame:
"""
This function applies spatial and temporal filters to a Spark DataFrame.
@@ -111,6 +112,8 @@ class CollectionClient:
- bbox (Optional[list]): A list of bounding boxes for filtering the items.
Each bounding box is represented as a list of four float values: [min_lon, min_lat, max_lon, max_lat].
Example: [[-180.0, -90.0, 180.0, 90.0]] # This bounding box covers the entire world.
+ - geometry (Optional[list]): A list of geometry objects (Shapely or WKT) for spatial filtering.
+ If both bbox and geometry are provided, geometry takes precedence.
- datetime (Optional[list]): A list of date-time ranges for filtering the items.
Each date-time range is represented as a list of two strings in ISO 8601 format: [start_datetime, end_datetime].
Example: [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]] # This interval covers the entire year of 2020.
@@ -119,16 +122,35 @@ class CollectionClient:
- DataFrame: The filtered Spark DataFrame.
The function constructs SQL conditions for spatial and temporal filters and applies them to the DataFrame.
- If bbox is provided, it constructs spatial conditions using st_intersects and ST_GeomFromText.
+ If geometry is provided, it takes precedence over bbox for spatial filtering.
+ If bbox is provided (and no geometry), it constructs spatial conditions using st_intersects and ST_GeomFromText.
If datetime is provided, it constructs temporal conditions using the datetime column.
The conditions are combined using OR logic.
"""
- if bbox:
+ # Geometry takes precedence over bbox
+ if geometry:
+ geometry_conditions = []
+ for geom in geometry:
+ if isinstance(geom, str):
+ # Assume it's WKT
+ geom_wkt = geom
+ elif hasattr(geom, "wkt"):
+ # Shapely geometry object
+ geom_wkt = geom.wkt
+ else:
+ # Try to convert to string (fallback)
+ geom_wkt = str(geom)
+ geometry_conditions.append(
+ f"st_intersects(ST_GeomFromText('{geom_wkt}'), geometry)"
+ )
+ geometry_sql_condition = " OR ".join(geometry_conditions)
+ df = df.filter(geometry_sql_condition)
+ elif bbox:
bbox_conditions = []
- for bbox in bbox:
+ for bbox_item in bbox:
polygon_wkt = (
- f"POLYGON(({bbox[0]} {bbox[1]}, {bbox[2]} {bbox[1]}, "
- f"{bbox[2]} {bbox[3]}, {bbox[0]} {bbox[3]}, {bbox[0]} {bbox[1]}))"
+ f"POLYGON(({bbox_item[0]} {bbox_item[1]}, {bbox_item[2]} {bbox_item[1]}, "
+ f"{bbox_item[2]} {bbox_item[3]}, {bbox_item[0]} {bbox_item[3]}, {bbox_item[0]} {bbox_item[1]}))"
)
bbox_conditions.append(
f"st_intersects(ST_GeomFromText('{polygon_wkt}'), geometry)"
@@ -194,6 +216,9 @@ class CollectionClient:
self,
*ids: Union[str, list],
bbox: Optional[list] = None,
+ geometry: Optional[
+ Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]
+ ] = None,
datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
max_items: Optional[int] = None,
) -> Iterator[PyStacItem]:
@@ -204,6 +229,9 @@ class CollectionClient:
optional filters to the data. The filters include:
- IDs: A list of item IDs to filter the items. If not provided, no ID filtering is applied.
- bbox (Optional[list]): A list of bounding boxes for filtering the items.
+ - geometry (Optional[Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]]): Shapely geometry object(s) or WKT string(s) for spatial filtering.
+ Can be a single geometry, WKT string, or a list of geometries/WKT strings.
+ If both bbox and geometry are provided, geometry takes precedence.
- datetime (Optional[Union[str, python_datetime.datetime, list]]): A single datetime, RFC 3339-compliant timestamp,
or a list of date-time ranges for filtering the items.
- max_items (Optional[int]): The maximum number of items to return from the search, even if there are more matching results.
@@ -217,7 +245,7 @@ class CollectionClient:
is raised with a message indicating the failure.
"""
try:
- df = self.load_items_df(bbox, datetime, ids, max_items)
+ df = self.load_items_df(bbox, geometry, datetime, ids, max_items)
# Collect the filtered rows and convert them to PyStacItem objects
items = []
@@ -237,6 +265,9 @@ class CollectionClient:
self,
*ids: Union[str, list],
bbox: Optional[list] = None,
+ geometry: Optional[
+ Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]
+ ] = None,
datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
max_items: Optional[int] = None,
) -> DataFrame:
@@ -251,6 +282,10 @@ class CollectionClient:
- bbox (Optional[list]): A list of bounding boxes for filtering the items.
Each bounding box is represented as a list of four float values: [min_lon, min_lat, max_lon, max_lat].
Example: [[-180.0, -90.0, 180.0, 90.0]] # This bounding box covers the entire world.
+ - geometry (Optional[Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]]): Shapely geometry object(s) or WKT string(s) for spatial filtering.
+ Can be a single geometry, WKT string, or a list of geometries/WKT strings.
+ If both bbox and geometry are provided, geometry takes precedence.
+ Example: Polygon(...) or "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))" or [Polygon(...), Polygon(...)]
- datetime (Optional[Union[str, python_datetime.datetime, list]]): A single datetime, RFC 3339-compliant timestamp,
or a list of date-time ranges for filtering the items.
Example: "2020-01-01T00:00:00Z" or python_datetime.datetime(2020, 1, 1) or [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]]
@@ -264,7 +299,7 @@ class CollectionClient:
is raised with a message indicating the failure.
"""
try:
- df = self.load_items_df(bbox, datetime, ids, max_items)
+ df = self.load_items_df(bbox, geometry, datetime, ids, max_items)
return df
except Exception as e:
@@ -276,6 +311,9 @@ class CollectionClient:
*ids: Union[str, list],
output_path: str,
bbox: Optional[list] = None,
+ geometry: Optional[
+ Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]
+ ] = None,
datetime: Optional[list] = None,
) -> None:
"""
@@ -299,7 +337,9 @@ class CollectionClient:
DataFrame to Parquet format, a RuntimeError is raised with a message indicating the failure.
"""
try:
- df = self.get_dataframe(*ids, bbox=bbox, datetime=datetime)
+ df = self.get_dataframe(
+ *ids, bbox=bbox, geometry=geometry, datetime=datetime
+ )
df_geoparquet = self._convert_assets_schema(df)
df_geoparquet.write.format("geoparquet").save(output_path)
logging.info(f"DataFrame successfully saved to {output_path}")
@@ -342,9 +382,15 @@ class CollectionClient:
return df
- def load_items_df(self, bbox, datetime, ids, max_items):
+ def load_items_df(self, bbox, geometry, datetime, ids, max_items):
# Load the collection data from the specified collection URL
- if not ids and not bbox and not datetime and max_items is not None:
+ if (
+ not ids
+ and not bbox
+ and not geometry
+ and not datetime
+ and max_items is not None
+ ):
df = (
self.spark.read.format("stac")
.option("itemsLimitMax", max_items)
@@ -362,6 +408,10 @@ class CollectionClient:
# Ensure bbox is a list of lists
if bbox and isinstance(bbox[0], float):
bbox = [bbox]
+ # Handle geometry parameter
+ if geometry:
+ if not isinstance(geometry, list):
+ geometry = [geometry]
# Handle datetime parameter
if datetime:
if isinstance(datetime, (str, python_datetime.datetime)):
@@ -371,7 +421,7 @@ class CollectionClient:
):
datetime = [list(datetime)]
# Apply spatial and temporal filters
- df = self._apply_spatial_temporal_filters(df, bbox, datetime)
+ df = self._apply_spatial_temporal_filters(df, bbox, geometry, datetime)
# Limit the number of items if max_items is specified
if max_items is not None:
df = df.limit(max_items)
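In short, the geometry handling added above normalizes each input to WKT and OR-combines st_intersects predicates. A simplified standalone sketch of that idea (not the committed code verbatim; the helper name is made up for illustration):
from shapely.geometry.base import BaseGeometry
def geometry_to_sql_condition(geometry) -> str:
    # Accept a single WKT string / Shapely geometry, or a list of them
    if not isinstance(geometry, list):
        geometry = [geometry]
    conditions = []
    for geom in geometry:
        # Shapely geometries expose .wkt; plain strings are assumed to already be WKT
        wkt = geom.wkt if isinstance(geom, BaseGeometry) else str(geom)
        conditions.append(f"st_intersects(ST_GeomFromText('{wkt}'), geometry)")
    # An item passes the filter if it intersects any of the geometries
    return " OR ".join(conditions)
# geometry_to_sql_condition("POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))")
# -> "st_intersects(ST_GeomFromText('POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))'), geometry)"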
diff --git a/python/tests/stac/test_collection_client.py b/python/tests/stac/test_collection_client.py
index f50a811daa..568389b601 100644
--- a/python/tests/stac/test_collection_client.py
+++ b/python/tests/stac/test_collection_client.py
@@ -188,6 +188,131 @@ class TestStacReader(TestBase):
df_loaded = collection.spark.read.format("geoparquet").load(output_path)
assert df_loaded.count() == 20, "Loaded GeoParquet file is empty"
+ def test_get_items_with_wkt_geometry(self) -> None:
+ """Test that WKT geometry strings are properly handled for spatial filtering."""
+ client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+ collection = client.get_collection("aster-l1t")
+
+ # Test with WKT polygon geometry
+ wkt_polygon = "POLYGON((90 -73, 105 -73, 105 -69, 90 -69, 90 -73))"
+ items_with_wkt = list(collection.get_items(geometry=wkt_polygon))
+
+ # The WKT-filtered search should return at least one item
+ assert items_with_wkt is not None
+ assert len(items_with_wkt) > 0
+
+ def test_get_dataframe_with_shapely_geometry(self) -> None:
+ """Test that Shapely geometry objects are properly handled for spatial filtering."""
+ from shapely.geometry import Polygon
+
+ client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+ collection = client.get_collection("aster-l1t")
+
+ # Test with Shapely polygon geometry
+ shapely_polygon = Polygon(
+ [(90, -73), (105, -73), (105, -69), (90, -69), (90, -73)]
+ )
+ df_with_shapely = collection.get_dataframe(geometry=shapely_polygon)
+
+ # The Shapely-filtered DataFrame should contain at least one item
+ assert df_with_shapely is not None
+ assert df_with_shapely.count() > 0
+
+ def test_get_items_with_geometry_list(self) -> None:
+ """Test that lists of geometry objects are properly handled."""
+ from shapely.geometry import Polygon
+
+ client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+ collection = client.get_collection("aster-l1t")
+
+ # Test with list of geometries (both WKT and Shapely)
+ wkt_polygon = "POLYGON((90 -73, 105 -73, 105 -69, 90 -69, 90 -73))"
+ shapely_polygon = Polygon(
+ [(-100, -72), (-90, -72), (-90, -62), (-100, -62), (-100, -72)]
+ )
+ geometry_list = [wkt_polygon, shapely_polygon]
+
+ items_with_geom_list = list(collection.get_items(geometry=geometry_list))
+
+ # Should return items from both geometries
+ assert items_with_geom_list is not None
+ assert len(items_with_geom_list) > 0
+
+ def test_geometry_takes_precedence_over_bbox(self) -> None:
+ """Test that geometry parameter takes precedence over bbox when both are provided."""
+ from shapely.geometry import Polygon
+
+ client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+ collection = client.get_collection("aster-l1t")
+
+ # Define different spatial extents
+ bbox = [-180.0, -90.0, 180.0, 90.0] # World bbox
+ small_polygon = Polygon(
+ [(90, -73), (105, -73), (105, -69), (90, -69), (90, -73)]
+ ) # Small area
+
+ # When both are provided, geometry should take precedence
+ items_with_both = list(collection.get_items(bbox=bbox, geometry=small_polygon))
+ items_with_geom_only = list(collection.get_items(geometry=small_polygon))
+
+ # Results should be identical since geometry takes precedence
+ assert items_with_both is not None
+ assert items_with_geom_only is not None
+ assert len(items_with_both) == len(items_with_geom_only)
+ assert len(items_with_both) > 0
+
+ def test_get_dataframe_with_geometry_and_datetime(self) -> None:
+ """Test that geometry and datetime filters work together."""
+ from shapely.geometry import Polygon
+
+ client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+ collection = client.get_collection("aster-l1t")
+
+ # Define spatial and temporal filters
+ polygon = Polygon([(90, -73), (105, -73), (105, -69), (90, -69), (90, -73)])
+ datetime_range = ["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]
+
+ df_with_both = collection.get_dataframe(
+ geometry=polygon, datetime=datetime_range
+ )
+ df_with_geom_only = collection.get_dataframe(geometry=polygon)
+
+ # Combined filter should return fewer or equal items than geometry-only filter
+ assert df_with_both is not None
+ assert df_with_geom_only is not None
+ assert df_with_both.count() <= df_with_geom_only.count()
+
+ def test_save_to_geoparquet_with_geometry(self) -> None:
+ """Test saving to GeoParquet with geometry parameter."""
+ from shapely.geometry import Polygon
+ import tempfile
+ import os
+
+ client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+ collection = client.get_collection("aster-l1t")
+
+ # Create a temporary directory for the output path and clean it up after the test
+ with tempfile.TemporaryDirectory() as tmpdirname:
+ output_path = f"{tmpdirname}/test_geometry_geoparquet_output"
+
+ # Define spatial and temporal extents
+ polygon = Polygon(
+ [(-180, -90), (180, -90), (180, 90), (-180, 90), (-180, -90)]
+ )
+ datetime_range = [["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"]]
+
+ # Call the method to save the DataFrame to GeoParquet
+ collection.save_to_geoparquet(
+ output_path=output_path, geometry=polygon, datetime=datetime_range
+ )
+
+ # Check if the file was created
+ assert os.path.exists(output_path), "GeoParquet file was not created"
+
+ # Optionally, you can load the file back and check its contents
+ df_loaded = collection.spark.read.format("geoparquet").load(output_path)
+ assert df_loaded.count() > 0, "Loaded GeoParquet file is empty"
+
def test_get_items_with_tuple_datetime(self) -> None:
"""Test that tuples are properly handled as datetime input (same as lists)."""
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])