This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 57146118c0 [GH-2127] Add Shapely and WKT geometry support to STAC reader (#2128)
57146118c0 is described below

commit 57146118c0c49912fae3ebbcd9d27c98c8b2acb2
Author: Feng Zhang <[email protected]>
AuthorDate: Sun Jul 20 23:25:55 2025 -0700

    [GH-2127] Add Shapely and WKT geometry support to STAC reader (#2128)
    
    * [GH-2127] Add Shapely and WKT geometry support to STAC reader
    
      Enhanced the STAC reader to support Shapely geometry objects and WKT
      strings as spatial filters, providing more flexibility than the existing
      bbox-only approach. This improvement goes beyond pystac-client's
      capabilities while maintaining full backward compatibility.
    
      New Features
    
      - New geometry parameter added to all STAC client methods (search,
        get_items, get_dataframe, save_to_geoparquet)
      - Multiple input format support:
        - Shapely geometry objects (Polygon, etc.)
        - WKT strings (e.g., "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))")
        - Lists of mixed geometries and WKT strings (an item matches if it
          intersects any of them)
      - Precedence: when both bbox and geometry are provided, geometry takes
        precedence and bbox is ignored
      - Works together with the existing ids, datetime, and max_items filters
    
      Usage Examples
    
      from shapely.geometry import Polygon
    
      # Using WKT string
      items = client.search(geometry="POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))")
    
      # Using Shapely polygon
      polygon = Polygon([(0, 0), (1, 0), (1, 1), (0, 1), (0, 0)])
      items = client.search(geometry=polygon)
    
      # Using list of geometries
      geometries = [polygon, "POLYGON((2 2, 3 2, 3 3, 2 3, 2 2))"]
      items = client.search(geometry=geometries)
    
      # Combined with other filters
      items = client.search(
          geometry=polygon,
          datetime=["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]
      )
    
    * Update python/sedona/spark/stac/collection_client.py
    
    Co-authored-by: Copilot <[email protected]>
    
    * revert incorrect refactoring
    
    ---------
    
    Co-authored-by: Copilot <[email protected]>
---
 python/sedona/spark/stac/client.py            |  23 ++++-
 python/sedona/spark/stac/collection_client.py |  76 +++++++++++++---
 python/tests/stac/test_collection_client.py   | 125 ++++++++++++++++++++++++++
 3 files changed, 208 insertions(+), 16 deletions(-)

diff --git a/python/sedona/spark/stac/client.py 
b/python/sedona/spark/stac/client.py
index f4af0eeb1b..ac954ea17f 100644
--- a/python/sedona/spark/stac/client.py
+++ b/python/sedona/spark/stac/client.py
@@ -14,12 +14,13 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-from typing import Union, Optional, Iterator
+from typing import Union, Optional, Iterator, List
 
 from sedona.spark.stac.collection_client import CollectionClient
 
 import datetime as python_datetime
 from pystac import Item as PyStacItem
+from shapely.geometry.base import BaseGeometry
 
 from pyspark.sql import DataFrame
 
@@ -77,6 +78,9 @@ class Client:
         *ids: Union[str, list],
         collection_id: Optional[str] = None,
         bbox: Optional[list] = None,
+        geometry: Optional[
+            Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]
+        ] = None,
         datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
         max_items: Optional[int] = None,
         return_dataframe: bool = True,
@@ -95,6 +99,11 @@ class Client:
           Each bounding box is represented as a list of four float values: 
[min_lon, min_lat, max_lon, max_lat].
           Example: [[-180.0, -90.0, 180.0, 90.0]]  # This bounding box covers 
the entire world.
 
+        - geometry (Optional[Union[str, BaseGeometry, List[Union[str, 
BaseGeometry]]]]): Shapely geometry object(s) or WKT string(s) for spatial 
filtering.
+          Can be a single geometry, WKT string, or a list of geometries/WKT 
strings.
+          If both bbox and geometry are provided, geometry takes precedence.
+          Example: Polygon(...) or "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))" or 
[Polygon(...), Polygon(...)]
+
         - datetime (Optional[Union[str, python_datetime.datetime, list]]): A 
single datetime, RFC 3339-compliant timestamp,
           or a list of date-time ranges for filtering the items. The datetime 
can be specified in various formats:
           - "YYYY" expands to ["YYYY-01-01T00:00:00Z", "YYYY-12-31T23:59:59Z"]
@@ -119,9 +128,17 @@ class Client:
             client = self.get_collection_from_catalog()
         if return_dataframe:
             return client.get_dataframe(
-                *ids, bbox=bbox, datetime=datetime, max_items=max_items
+                *ids,
+                bbox=bbox,
+                geometry=geometry,
+                datetime=datetime,
+                max_items=max_items,
             )
         else:
             return client.get_items(
-                *ids, bbox=bbox, datetime=datetime, max_items=max_items
+                *ids,
+                bbox=bbox,
+                geometry=geometry,
+                datetime=datetime,
+                max_items=max_items,
             )
diff --git a/python/sedona/spark/stac/collection_client.py 
b/python/sedona/spark/stac/collection_client.py
index 0d04c06575..bc020cfdf8 100644
--- a/python/sedona/spark/stac/collection_client.py
+++ b/python/sedona/spark/stac/collection_client.py
@@ -16,13 +16,14 @@
 # under the License.
 
 import logging
-from typing import Iterator, Union
+from typing import Iterator, Union, List
 from typing import Optional
 
 import datetime as python_datetime
 from pyspark.sql import DataFrame, SparkSession
 from pyspark.sql.types import dt
 from pystac import Item as PyStacItem
+from shapely.geometry.base import BaseGeometry
 
 
 def get_collection_url(url: str, collection_id: Optional[str] = None) -> str:
@@ -101,7 +102,7 @@ class CollectionClient:
 
     @staticmethod
     def _apply_spatial_temporal_filters(
-        df: DataFrame, bbox=None, datetime=None
+        df: DataFrame, bbox=None, geometry=None, datetime=None
     ) -> DataFrame:
         """
         This function applies spatial and temporal filters to a Spark 
DataFrame.
@@ -111,6 +112,8 @@ class CollectionClient:
         - bbox (Optional[list]): A list of bounding boxes for filtering the 
items.
           Each bounding box is represented as a list of four float values: 
[min_lon, min_lat, max_lon, max_lat].
           Example: [[-180.0, -90.0, 180.0, 90.0]]  # This bounding box covers 
the entire world.
+        - geometry (Optional[list]): A list of geometry objects (Shapely or 
WKT) for spatial filtering.
+          If both bbox and geometry are provided, geometry takes precedence.
         - datetime (Optional[list]): A list of date-time ranges for filtering 
the items.
           Each date-time range is represented as a list of two strings in ISO 
8601 format: [start_datetime, end_datetime].
           Example: [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]]  # This 
interval covers the entire year of 2020.
@@ -119,16 +122,35 @@ class CollectionClient:
         - DataFrame: The filtered Spark DataFrame.
 
         The function constructs SQL conditions for spatial and temporal 
filters and applies them to the DataFrame.
-        If bbox is provided, it constructs spatial conditions using 
st_intersects and ST_GeomFromText.
+        If geometry is provided, it takes precedence over bbox for spatial 
filtering.
+        If bbox is provided (and no geometry), it constructs spatial 
conditions using st_intersects and ST_GeomFromText.
         If datetime is provided, it constructs temporal conditions using the 
datetime column.
         The conditions are combined using OR logic.
         """
-        if bbox:
+        # Geometry takes precedence over bbox
+        if geometry:
+            geometry_conditions = []
+            for geom in geometry:
+                if isinstance(geom, str):
+                    # Assume it's WKT
+                    geom_wkt = geom
+                elif hasattr(geom, "wkt"):
+                    # Shapely geometry object
+                    geom_wkt = geom.wkt
+                else:
+                    # Try to convert to string (fallback)
+                    geom_wkt = str(geom)
+                geometry_conditions.append(
+                    f"st_intersects(ST_GeomFromText('{geom_wkt}'), geometry)"
+                )
+            geometry_sql_condition = " OR ".join(geometry_conditions)
+            df = df.filter(geometry_sql_condition)
+        elif bbox:
             bbox_conditions = []
-            for bbox in bbox:
+            for bbox_item in bbox:
                 polygon_wkt = (
-                    f"POLYGON(({bbox[0]} {bbox[1]}, {bbox[2]} {bbox[1]}, "
-                    f"{bbox[2]} {bbox[3]}, {bbox[0]} {bbox[3]}, {bbox[0]} 
{bbox[1]}))"
+                    f"POLYGON(({bbox_item[0]} {bbox_item[1]}, {bbox_item[2]} 
{bbox_item[1]}, "
+                    f"{bbox_item[2]} {bbox_item[3]}, {bbox_item[0]} 
{bbox_item[3]}, {bbox_item[0]} {bbox_item[1]}))"
                 )
                 bbox_conditions.append(
                     f"st_intersects(ST_GeomFromText('{polygon_wkt}'), 
geometry)"
@@ -194,6 +216,9 @@ class CollectionClient:
         self,
         *ids: Union[str, list],
         bbox: Optional[list] = None,
+        geometry: Optional[
+            Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]
+        ] = None,
         datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
         max_items: Optional[int] = None,
     ) -> Iterator[PyStacItem]:
@@ -204,6 +229,9 @@ class CollectionClient:
         optional filters to the data. The filters include:
         - IDs: A list of item IDs to filter the items. If not provided, no ID 
filtering is applied.
         - bbox (Optional[list]): A list of bounding boxes for filtering the 
items.
+        - geometry (Optional[Union[str, BaseGeometry, List[Union[str, 
BaseGeometry]]]]): Shapely geometry object(s) or WKT string(s) for spatial 
filtering.
+          Can be a single geometry, WKT string, or a list of geometries/WKT 
strings.
+          If both bbox and geometry are provided, geometry takes precedence.
         - datetime (Optional[Union[str, python_datetime.datetime, list]]): A 
single datetime, RFC 3339-compliant timestamp,
           or a list of date-time ranges for filtering the items.
         - max_items (Optional[int]): The maximum number of items to return 
from the search, even if there are more matching results.
@@ -217,7 +245,7 @@ class CollectionClient:
           is raised with a message indicating the failure.
         """
         try:
-            df = self.load_items_df(bbox, datetime, ids, max_items)
+            df = self.load_items_df(bbox, geometry, datetime, ids, max_items)
 
             # Collect the filtered rows and convert them to PyStacItem objects
             items = []
@@ -237,6 +265,9 @@ class CollectionClient:
         self,
         *ids: Union[str, list],
         bbox: Optional[list] = None,
+        geometry: Optional[
+            Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]
+        ] = None,
         datetime: Optional[Union[str, python_datetime.datetime, list]] = None,
         max_items: Optional[int] = None,
     ) -> DataFrame:
@@ -251,6 +282,10 @@ class CollectionClient:
         - bbox (Optional[list]): A list of bounding boxes for filtering the 
items.
           Each bounding box is represented as a list of four float values: 
[min_lon, min_lat, max_lon, max_lat].
           Example: [[-180.0, -90.0, 180.0, 90.0]]  # This bounding box covers 
the entire world.
+        - geometry (Optional[Union[str, BaseGeometry, List[Union[str, 
BaseGeometry]]]]): Shapely geometry object(s) or WKT string(s) for spatial 
filtering.
+          Can be a single geometry, WKT string, or a list of geometries/WKT 
strings.
+          If both bbox and geometry are provided, geometry takes precedence.
+          Example: Polygon(...) or "POLYGON((0 0, 1 0, 1 1, 0 1, 0 0))" or 
[Polygon(...), Polygon(...)]
         - datetime (Optional[Union[str, python_datetime.datetime, list]]): A 
single datetime, RFC 3339-compliant timestamp,
           or a list of date-time ranges for filtering the items.
           Example: "2020-01-01T00:00:00Z" or python_datetime.datetime(2020, 1, 
1) or [["2020-01-01T00:00:00Z", "2021-01-01T00:00:00Z"]]
@@ -264,7 +299,7 @@ class CollectionClient:
           is raised with a message indicating the failure.
         """
         try:
-            df = self.load_items_df(bbox, datetime, ids, max_items)
+            df = self.load_items_df(bbox, geometry, datetime, ids, max_items)
 
             return df
         except Exception as e:
@@ -276,6 +311,9 @@ class CollectionClient:
         *ids: Union[str, list],
         output_path: str,
         bbox: Optional[list] = None,
+        geometry: Optional[
+            Union[str, BaseGeometry, List[Union[str, BaseGeometry]]]
+        ] = None,
         datetime: Optional[list] = None,
     ) -> None:
         """
@@ -299,7 +337,9 @@ class CollectionClient:
           DataFrame to Parquet format, a RuntimeError is raised with a message 
indicating the failure.
         """
         try:
-            df = self.get_dataframe(*ids, bbox=bbox, datetime=datetime)
+            df = self.get_dataframe(
+                *ids, bbox=bbox, geometry=geometry, datetime=datetime
+            )
             df_geoparquet = self._convert_assets_schema(df)
             df_geoparquet.write.format("geoparquet").save(output_path)
             logging.info(f"DataFrame successfully saved to {output_path}")
@@ -342,9 +382,15 @@ class CollectionClient:
 
         return df
 
-    def load_items_df(self, bbox, datetime, ids, max_items):
+    def load_items_df(self, bbox, geometry, datetime, ids, max_items):
         # Load the collection data from the specified collection URL
-        if not ids and not bbox and not datetime and max_items is not None:
+        if (
+            not ids
+            and not bbox
+            and not geometry
+            and not datetime
+            and max_items is not None
+        ):
             df = (
                 self.spark.read.format("stac")
                 .option("itemsLimitMax", max_items)
@@ -362,6 +408,10 @@ class CollectionClient:
             # Ensure bbox is a list of lists
             if bbox and isinstance(bbox[0], float):
                 bbox = [bbox]
+            # Handle geometry parameter
+            if geometry:
+                if not isinstance(geometry, list):
+                    geometry = [geometry]
             # Handle datetime parameter
             if datetime:
                 if isinstance(datetime, (str, python_datetime.datetime)):
@@ -371,7 +421,7 @@ class CollectionClient:
                 ):
                     datetime = [list(datetime)]
             # Apply spatial and temporal filters
-            df = self._apply_spatial_temporal_filters(df, bbox, datetime)
+            df = self._apply_spatial_temporal_filters(df, bbox, geometry, 
datetime)
         # Limit the number of items if max_items is specified
         if max_items is not None:
             df = df.limit(max_items)
diff --git a/python/tests/stac/test_collection_client.py 
b/python/tests/stac/test_collection_client.py
index f50a811daa..568389b601 100644
--- a/python/tests/stac/test_collection_client.py
+++ b/python/tests/stac/test_collection_client.py
@@ -188,6 +188,131 @@ class TestStacReader(TestBase):
             df_loaded = 
collection.spark.read.format("geoparquet").load(output_path)
             assert df_loaded.count() == 20, "Loaded GeoParquet file is empty"
 
+    def test_get_items_with_wkt_geometry(self) -> None:
+        """Test that WKT geometry strings are properly handled for spatial 
filtering."""
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+
+        # Test with WKT polygon geometry
+        wkt_polygon = "POLYGON((90 -73, 105 -73, 105 -69, 90 -69, 90 -73))"
+        items_with_wkt = list(collection.get_items(geometry=wkt_polygon))
+
+        # Both should return similar number of items (may not be exactly same 
due to geometry differences)
+        assert items_with_wkt is not None
+        assert len(items_with_wkt) > 0
+
+    def test_get_dataframe_with_shapely_geometry(self) -> None:
+        """Test that Shapely geometry objects are properly handled for spatial 
filtering."""
+        from shapely.geometry import Polygon
+
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+
+        # Test with Shapely polygon geometry
+        shapely_polygon = Polygon(
+            [(90, -73), (105, -73), (105, -69), (90, -69), (90, -73)]
+        )
+        df_with_shapely = collection.get_dataframe(geometry=shapely_polygon)
+
+        # Both should return similar number of items
+        assert df_with_shapely is not None
+        assert df_with_shapely.count() > 0
+
+    def test_get_items_with_geometry_list(self) -> None:
+        """Test that lists of geometry objects are properly handled."""
+        from shapely.geometry import Polygon
+
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+
+        # Test with list of geometries (both WKT and Shapely)
+        wkt_polygon = "POLYGON((90 -73, 105 -73, 105 -69, 90 -69, 90 -73))"
+        shapely_polygon = Polygon(
+            [(-100, -72), (-90, -72), (-90, -62), (-100, -62), (-100, -72)]
+        )
+        geometry_list = [wkt_polygon, shapely_polygon]
+
+        items_with_geom_list = 
list(collection.get_items(geometry=geometry_list))
+
+        # Should return items from both geometries
+        assert items_with_geom_list is not None
+        assert len(items_with_geom_list) > 0
+
+    def test_geometry_takes_precedence_over_bbox(self) -> None:
+        """Test that geometry parameter takes precedence over bbox when both 
are provided."""
+        from shapely.geometry import Polygon
+
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+
+        # Define different spatial extents
+        bbox = [-180.0, -90.0, 180.0, 90.0]  # World bbox
+        small_polygon = Polygon(
+            [(90, -73), (105, -73), (105, -69), (90, -69), (90, -73)]
+        )  # Small area
+
+        # When both are provided, geometry should take precedence
+        items_with_both = list(collection.get_items(bbox=bbox, 
geometry=small_polygon))
+        items_with_geom_only = 
list(collection.get_items(geometry=small_polygon))
+
+        # Results should be identical since geometry takes precedence
+        assert items_with_both is not None
+        assert items_with_geom_only is not None
+        assert len(items_with_both) == len(items_with_geom_only)
+        assert len(items_with_both) > 0
+
+    def test_get_dataframe_with_geometry_and_datetime(self) -> None:
+        """Test that geometry and datetime filters work together."""
+        from shapely.geometry import Polygon
+
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+
+        # Define spatial and temporal filters
+        polygon = Polygon([(90, -73), (105, -73), (105, -69), (90, -69), (90, 
-73)])
+        datetime_range = ["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]
+
+        df_with_both = collection.get_dataframe(
+            geometry=polygon, datetime=datetime_range
+        )
+        df_with_geom_only = collection.get_dataframe(geometry=polygon)
+
+        # Combined filter should return fewer or equal items than 
geometry-only filter
+        assert df_with_both is not None
+        assert df_with_geom_only is not None
+        assert df_with_both.count() <= df_with_geom_only.count()
+
+    def test_save_to_geoparquet_with_geometry(self) -> None:
+        """Test saving to GeoParquet with geometry parameter."""
+        from shapely.geometry import Polygon
+        import tempfile
+        import os
+
+        client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
+        collection = client.get_collection("aster-l1t")
+
+        # Create a temporary directory for the output path and clean it up 
after the test
+        with tempfile.TemporaryDirectory() as tmpdirname:
+            output_path = f"{tmpdirname}/test_geometry_geoparquet_output"
+
+            # Define spatial and temporal extents
+            polygon = Polygon(
+                [(-180, -90), (180, -90), (180, 90), (-180, 90), (-180, -90)]
+            )
+            datetime_range = [["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"]]
+
+            # Call the method to save the DataFrame to GeoParquet
+            collection.save_to_geoparquet(
+                output_path=output_path, geometry=polygon, 
datetime=datetime_range
+            )
+
+            # Check if the file was created
+            assert os.path.exists(output_path), "GeoParquet file was not 
created"
+
+            # Optionally, you can load the file back and check its contents
+            df_loaded = 
collection.spark.read.format("geoparquet").load(output_path)
+            assert df_loaded.count() > 0, "Loaded GeoParquet file is empty"
+
     def test_get_items_with_tuple_datetime(self) -> None:
         """Test that tuples are properly handled as datetime input (same as 
lists)."""
         client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])

Reply via email to