This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch prepare-1.7.2
in repository https://gitbox.apache.org/repos/asf/sedona.git
commit 8bee92d784e457116c3342d8256bf9d2d04b9a98 Author: Feng Zhang <[email protected]> AuthorDate: Mon Mar 17 14:30:18 2025 -0700 [SEDONA-704] Optimize STAC reader and fix few issues (#1861) * [SEDONA-704] Optimize STAC reader and fix few issues * fix formating issue * fix compiling error * another fix --- docs/tutorial/files/stac-sedona-spark.md | 4 +- python/sedona/stac/collection_client.py | 93 ++++++-------- python/tests/stac/test_client.py | 5 +- python/tests/stac/test_collection_client.py | 6 +- .../spark/sql/sedona_sql/io/stac/StacBatch.scala | 36 +++++- .../sql/sedona_sql/io/stac/StacDataSource.scala | 17 ++- .../sedona_sql/io/stac/StacPartitionReader.scala | 8 +- .../spark/sql/sedona_sql/io/stac/StacScan.scala | 26 +++- .../sql/sedona_sql/io/stac/StacScanBuilder.scala | 10 +- .../spark/sql/sedona_sql/io/stac/StacTable.scala | 9 +- .../spark/sql/sedona_sql/io/stac/StacUtils.scala | 134 ++++++++++++++------- .../SpatialTemporalFilterPushDownForStacScan.scala | 17 ++- .../sql/sedona_sql/io/stac/StacBatchTest.scala | 81 +++++++++++-- .../sedona_sql/io/stac/StacDataSourceTest.scala | 26 ++-- .../io/stac/StacPartitionReaderTest.scala | 5 + .../sql/sedona_sql/io/stac/StacUtilsTest.scala | 128 +++++++++++++++++++- 16 files changed, 450 insertions(+), 155 deletions(-) diff --git a/docs/tutorial/files/stac-sedona-spark.md b/docs/tutorial/files/stac-sedona-spark.md index 062e6c5f55..aff0e29f0f 100644 --- a/docs/tutorial/files/stac-sedona-spark.md +++ b/docs/tutorial/files/stac-sedona-spark.md @@ -116,7 +116,7 @@ The STAC data source supports predicate pushdown for spatial and temporal filter ### Spatial Filter Pushdown -Spatial filter pushdown allows the data source to apply spatial predicates (e.g., st_contains, st_intersects) directly at the data source level, reducing the amount of data transferred and processed. +Spatial filter pushdown allows the data source to apply spatial predicates (e.g., st_intersects) directly at the data source level, reducing the amount of data transferred and processed. ### Temporal Filter Pushdown @@ -147,7 +147,7 @@ In this example, the data source will push down the temporal filter to the under ```sql SELECT id, geometry FROM STAC_TABLE - WHERE st_contains(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry) + WHERE st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry) ``` In this example, the data source will push down the spatial filter to the underlying data source. diff --git a/python/sedona/stac/collection_client.py b/python/sedona/stac/collection_client.py index b1cae6df39..7670c6291d 100644 --- a/python/sedona/stac/collection_client.py +++ b/python/sedona/stac/collection_client.py @@ -119,7 +119,7 @@ class CollectionClient: - DataFrame: The filtered Spark DataFrame. The function constructs SQL conditions for spatial and temporal filters and applies them to the DataFrame. - If bbox is provided, it constructs spatial conditions using st_contains and ST_GeomFromText. + If bbox is provided, it constructs spatial conditions using st_intersects and ST_GeomFromText. If datetime is provided, it constructs temporal conditions using the datetime column. The conditions are combined using OR logic. 
""" @@ -131,7 +131,7 @@ class CollectionClient: f"{bbox[2]} {bbox[3]}, {bbox[0]} {bbox[3]}, {bbox[0]} {bbox[1]}))" ) bbox_conditions.append( - f"st_contains(ST_GeomFromText('{polygon_wkt}'), geometry)" + f"st_intersects(ST_GeomFromText('{polygon_wkt}'), geometry)" ) bbox_sql_condition = " OR ".join(bbox_conditions) df = df.filter(bbox_sql_condition) @@ -217,34 +217,7 @@ class CollectionClient: is raised with a message indicating the failure. """ try: - # Load the collection data from the specified collection URL - df = self.spark.read.format("stac").load(self.collection_url) - - # Apply ID filters if provided - if ids: - if isinstance(ids, tuple): - ids = list(ids) - if isinstance(ids, str): - ids = [ids] - df = df.filter(df.id.isin(ids)) - - # Ensure bbox is a list of lists - if bbox and isinstance(bbox[0], float): - bbox = [bbox] - - # Handle datetime parameter - if datetime: - if isinstance(datetime, (str, python_datetime.datetime)): - datetime = [self._expand_date(str(datetime))] - elif isinstance(datetime, list) and isinstance(datetime[0], str): - datetime = [datetime] - - # Apply spatial and temporal filters - df = self._apply_spatial_temporal_filters(df, bbox, datetime) - - # Limit the number of items if max_items is specified - if max_items is not None: - df = df.limit(max_items) + df = self.load_items_df(bbox, datetime, ids, max_items) # Collect the filtered rows and convert them to PyStacItem objects items = [] @@ -291,32 +264,7 @@ class CollectionClient: is raised with a message indicating the failure. """ try: - df = self.spark.read.format("stac").load(self.collection_url) - - # Apply ID filters if provided - if ids: - if isinstance(ids, tuple): - ids = list(ids) - if isinstance(ids, str): - ids = [ids] - df = df.filter(df.id.isin(ids)) - - # Ensure bbox is a list of lists - if bbox and isinstance(bbox[0], float): - bbox = [bbox] - - # Handle datetime parameter - if datetime: - if isinstance(datetime, (str, python_datetime.datetime)): - datetime = [[str(datetime), str(datetime)]] - elif isinstance(datetime, list) and isinstance(datetime[0], str): - datetime = [datetime] - - df = self._apply_spatial_temporal_filters(df, bbox, datetime) - - # Limit the number of items if max_items is specified - if max_items is not None: - df = df.limit(max_items) + df = self.load_items_df(bbox, datetime, ids, max_items) return df except Exception as e: @@ -394,5 +342,38 @@ class CollectionClient: return df + def load_items_df(self, bbox, datetime, ids, max_items): + # Load the collection data from the specified collection URL + if not ids and not bbox and not datetime and max_items is not None: + df = ( + self.spark.read.format("stac") + .option("itemsLimitMax", max_items) + .load(self.collection_url) + ) + else: + df = self.spark.read.format("stac").load(self.collection_url) + # Apply ID filters if provided + if ids: + if isinstance(ids, tuple): + ids = list(ids) + if isinstance(ids, str): + ids = [ids] + df = df.filter(df.id.isin(ids)) + # Ensure bbox is a list of lists + if bbox and isinstance(bbox[0], float): + bbox = [bbox] + # Handle datetime parameter + if datetime: + if isinstance(datetime, (str, python_datetime.datetime)): + datetime = [self._expand_date(str(datetime))] + elif isinstance(datetime, list) and isinstance(datetime[0], str): + datetime = [datetime] + # Apply spatial and temporal filters + df = self._apply_spatial_temporal_filters(df, bbox, datetime) + # Limit the number of items if max_items is specified + if max_items is not None: + df = df.limit(max_items) + return df + 
def __str__(self): return f"<CollectionClient id={self.collection_id}>" diff --git a/python/tests/stac/test_client.py b/python/tests/stac/test_client.py index 4f9919e1c5..b8b2beed8a 100644 --- a/python/tests/stac/test_client.py +++ b/python/tests/stac/test_client.py @@ -36,7 +36,7 @@ class TestStacClient(TestBase): return_dataframe=False, ) assert items is not None - assert len(list(items)) == 2 + assert len(list(items)) > 0 def test_search_with_ids(self) -> None: client = Client.open(STAC_URLS["PLANETARY-COMPUTER"]) @@ -82,7 +82,7 @@ class TestStacClient(TestBase): return_dataframe=False, ) assert items is not None - assert len(list(items)) == 4 + assert len(list(items)) > 0 def test_search_with_bbox_and_non_overlapping_intervals(self) -> None: client = Client.open(STAC_URLS["PLANETARY-COMPUTER"]) @@ -144,6 +144,7 @@ class TestStacClient(TestBase): datetime=["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"], ) assert df is not None + assert df.count() == 20 assert isinstance(df, DataFrame) def test_search_with_catalog_url(self) -> None: diff --git a/python/tests/stac/test_collection_client.py b/python/tests/stac/test_collection_client.py index 24226f86ca..aa9a15f4eb 100644 --- a/python/tests/stac/test_collection_client.py +++ b/python/tests/stac/test_collection_client.py @@ -71,7 +71,7 @@ class TestStacReader(TestBase): bbox = [[-100.0, -72.0, 105.0, -69.0]] items = list(collection.get_items(bbox=bbox)) assert items is not None - assert len(items) == 2 + assert len(items) > 0 def test_get_items_with_temporal_extent(self) -> None: client = Client.open(STAC_URLS["PLANETARY-COMPUTER"]) @@ -88,7 +88,7 @@ class TestStacReader(TestBase): datetime = [["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]] items = list(collection.get_items(bbox=bbox, datetime=datetime)) assert items is not None - assert len(items) == 4 + assert len(items) > 0 def test_get_items_with_multiple_bboxes_and_interval(self) -> None: client = Client.open(STAC_URLS["PLANETARY-COMPUTER"]) @@ -111,7 +111,7 @@ class TestStacReader(TestBase): datetime = [["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]] items = list(collection.get_items(bbox=bbox, datetime=datetime)) assert items is not None - assert len(items) == 4 + assert len(items) > 0 def test_get_items_with_ids(self) -> None: client = Client.open(STAC_URLS["PLANETARY-COMPUTER"]) diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala index 5994c463fd..4ef53756d2 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala @@ -19,16 +19,19 @@ package org.apache.spark.sql.sedona_sql.io.stac import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory} import org.apache.spark.sql.execution.datasource.stac.TemporalFilter import org.apache.spark.sql.execution.datasources.parquet.{GeoParquetSpatialFilter, GeometryFieldMetaData} import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.getNumPartitions import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration import java.time.LocalDateTime import java.time.format.DateTimeFormatterBuilder import java.time.temporal.ChronoField import scala.jdk.CollectionConverters._ +import scala.util.Random import 
scala.util.control.Breaks.breakable /** @@ -40,12 +43,14 @@ import scala.util.control.Breaks.breakable * which are necessary for batch data processing. */ case class StacBatch( + broadcastConf: Broadcast[SerializableConfiguration], stacCollectionUrl: String, stacCollectionJson: String, schema: StructType, opts: Map[String, String], spatialFilter: Option[GeoParquetSpatialFilter], - temporalFilter: Option[TemporalFilter]) + temporalFilter: Option[TemporalFilter], + limitFilter: Option[Int]) extends Batch { private val defaultItemsLimitPerRequest: Int = { @@ -82,10 +87,16 @@ case class StacBatch( // Initialize the itemLinks array val itemLinks = scala.collection.mutable.ArrayBuffer[String]() - // Start the recursive collection of item links - val itemsLimitMax = opts.getOrElse("itemsLimitMax", "-1").toInt + // Get the maximum number of items to process + val itemsLimitMax = limitFilter match { + case Some(limit) if limit >= 0 => limit + case _ => opts.getOrElse("itemsLimitMax", "-1").toInt + } val checkItemsLimitMax = itemsLimitMax > 0 + + // Start the recursive collection of item links setItemMaxLeft(itemsLimitMax) + collectItemLinks(stacCollectionBasePath, stacCollectionJson, itemLinks, checkItemsLimitMax) // Handle when the number of items is less than 1 @@ -109,8 +120,9 @@ case class StacBatch( // Determine how many items to put in each partition val partitionSize = Math.ceil(itemLinks.length.toDouble / numPartitions).toInt - // Group the item links into partitions - itemLinks + // Group the item links into partitions, but randomize first for better load balancing + Random + .shuffle(itemLinks) .grouped(partitionSize) .zipWithIndex .map { case (items, index) => @@ -229,7 +241,7 @@ case class StacBatch( } else if (rel == "items" && href.startsWith("http")) { // iterate through the items and check if the limit is reached (if needed) if (iterateItemsWithLimit( - itemUrl + "?limit=" + defaultItemsLimitPerRequest, + getItemLink(itemUrl, defaultItemsLimitPerRequest, spatialFilter, temporalFilter), needCountNextItems)) return } } @@ -256,6 +268,17 @@ case class StacBatch( } } + /** Adds an item link to the list of item links. */ + def getItemLink( + itemUrl: String, + defaultItemsLimitPerRequest: Int, + spatialFilter: Option[GeoParquetSpatialFilter], + temporalFilter: Option[TemporalFilter]): String = { + val baseUrl = itemUrl + "?limit=" + defaultItemsLimitPerRequest + val urlWithFilters = StacUtils.addFiltersToUrl(baseUrl, spatialFilter, temporalFilter) + urlWithFilters + } + /** * Filters a collection based on the provided spatial and temporal filters. 
* @@ -361,6 +384,7 @@ case class StacBatch( override def createReaderFactory(): PartitionReaderFactory = { (partition: InputPartition) => { new StacPartitionReader( + broadcastConf, partition.asInstanceOf[StacPartition], schema, opts, diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala index dc2dc3e6dd..154adf8eca 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala @@ -18,15 +18,16 @@ */ package org.apache.spark.sql.sedona_sql.io.stac -import StacUtils.{inferStacSchema, updatePropertiesPromotedSchema} import org.apache.spark.sql.SparkSession import org.apache.spark.sql.connector.catalog.{Table, TableProvider} import org.apache.spark.sql.connector.expressions.Transform import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT import org.apache.spark.sql.sedona_sql.io.geojson.GeoJSONUtils +import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{inferStacSchema, updatePropertiesPromotedSchema} import org.apache.spark.sql.sources.DataSourceRegister import org.apache.spark.sql.types.StructType import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration import java.util import java.util.concurrent.ConcurrentHashMap @@ -100,6 +101,7 @@ class StacDataSource() extends TableProvider with DataSourceRegister { partitioning: Array[Transform], properties: util.Map[String, String]): Table = { val opts = new CaseInsensitiveStringMap(properties) + val sparkSession = SparkSession.active val optsMap: Map[String, String] = opts.asCaseSensitiveMap().asScala.toMap ++ Map( "sessionLocalTimeZone" -> SparkSession.active.sessionState.conf.sessionLocalTimeZone, @@ -107,17 +109,20 @@ class StacDataSource() extends TableProvider with DataSourceRegister { "defaultParallelism" -> SparkSession.active.sparkContext.defaultParallelism.toString, "maxPartitionItemFiles" -> SparkSession.active.conf .get("spark.sedona.stac.load.maxPartitionItemFiles", "0"), - "numPartitions" -> SparkSession.active.conf + "numPartitions" -> sparkSession.conf .get("spark.sedona.stac.load.numPartitions", "-1"), "itemsLimitMax" -> opts .asCaseSensitiveMap() .asScala .toMap - .get("itemsLimitMax") - .filter(_.toInt > 0) - .getOrElse(SparkSession.active.conf.get("spark.sedona.stac.load.itemsLimitMax", "-1"))) + .getOrElse( + "itemsLimitMax", + sparkSession.conf.get("spark.sedona.stac.load.itemsLimitMax", "-1"))) val stacCollectionJsonString = StacUtils.loadStacCollectionToJson(optsMap) + val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(opts.asScala.toMap) + val broadcastedConf = + sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf)) - new StacTable(stacCollectionJson = stacCollectionJsonString, opts = optsMap) + new StacTable(stacCollectionJson = stacCollectionJsonString, opts = optsMap, broadcastedConf) } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala index a545eb232f..b72ac526ca 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala @@ -20,6 +20,7 @@ package 
org.apache.spark.sql.sedona_sql.io.stac import com.fasterxml.jackson.databind.ObjectMapper import org.apache.hadoop.conf.Configuration +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.catalyst.InternalRow import org.apache.spark.sql.catalyst.json.JSONOptionsInRead import org.apache.spark.sql.connector.read.PartitionReader @@ -28,14 +29,16 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile import org.apache.spark.sql.execution.datasources.json.JsonDataSource import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter import org.apache.spark.sql.sedona_sql.io.geojson.{GeoJSONUtils, SparkCompatUtil} -import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{buildOutDbRasterFields, promotePropertiesToTop} +import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.promotePropertiesToTop import org.apache.spark.sql.types.{StringType, StructType} +import org.apache.spark.util.SerializableConfiguration import java.io.{File, PrintWriter} import java.lang.reflect.Constructor import scala.io.Source class StacPartitionReader( + broadcast: Broadcast[SerializableConfiguration], partition: StacPartition, schema: StructType, opts: Map[String, String], @@ -120,8 +123,7 @@ class StacPartitionReader( rows.map(row => { val geometryConvertedRow = GeoJSONUtils.convertGeoJsonToGeometry(row, alteredSchema) - val rasterAddedRow = buildOutDbRasterFields(geometryConvertedRow, alteredSchema) - val propertiesPromotedRow = promotePropertiesToTop(rasterAddedRow, alteredSchema) + val propertiesPromotedRow = promotePropertiesToTop(geometryConvertedRow, alteredSchema) propertiesPromotedRow }) } else { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScan.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScan.scala index 2edf082912..9ccc89fba8 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScan.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScan.scala @@ -18,14 +18,19 @@ */ package org.apache.spark.sql.sedona_sql.io.stac +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.connector.read.{Batch, Scan} import org.apache.spark.sql.execution.datasource.stac.TemporalFilter import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter import org.apache.spark.sql.internal.connector.SupportsMetadata import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{getFullCollectionUrl, inferStacSchema} import org.apache.spark.sql.types.StructType +import org.apache.spark.util.SerializableConfiguration -class StacScan(stacCollectionJson: String, opts: Map[String, String]) +class StacScan( + stacCollectionJson: String, + opts: Map[String, String], + broadcastConf: Broadcast[SerializableConfiguration]) extends Scan with SupportsMetadata { @@ -35,6 +40,8 @@ class StacScan(stacCollectionJson: String, opts: Map[String, String]) // The temporal filter to be pushed down to the data source var temporalFilter: Option[TemporalFilter] = None + var limit: Option[Int] = None + /** * Returns the schema of the data to be read. 
* @@ -62,12 +69,14 @@ class StacScan(stacCollectionJson: String, opts: Map[String, String]) override def toBatch: Batch = { val stacCollectionUrl = getFullCollectionUrl(opts) StacBatch( + broadcastConf, stacCollectionUrl, stacCollectionJson, readSchema(), opts, spatialFilter, - temporalFilter) + temporalFilter, + limit) } /** @@ -90,6 +99,16 @@ class StacScan(stacCollectionJson: String, opts: Map[String, String]) temporalFilter = Some(combineTemporalFilter) } + /** + * Sets the limit on the number of items to be read. + * + * @param n + * The limit on the number of items to be read. + */ + def setLimit(n: Int) = { + limit = Some(n) + } + /** * Returns metadata about the data to be read. * @@ -101,7 +120,8 @@ class StacScan(stacCollectionJson: String, opts: Map[String, String]) override def getMetaData(): Map[String, String] = { Map( "PushedSpatialFilters" -> spatialFilter.map(_.toString).getOrElse("None"), - "PushedTemporalFilters" -> temporalFilter.map(_.toString).getOrElse("None")) + "PushedTemporalFilters" -> temporalFilter.map(_.toString).getOrElse("None"), + "Pushedimit" -> limit.map(_.toString).getOrElse("None")) } /** diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScanBuilder.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScanBuilder.scala index ebaab87dda..0fcbde9b4d 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScanBuilder.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScanBuilder.scala @@ -18,7 +18,9 @@ */ package org.apache.spark.sql.sedona_sql.io.stac +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} +import org.apache.spark.util.SerializableConfiguration /** * The `StacScanBuilder` class represents the builder for creating a `Scan` instance in the @@ -28,7 +30,11 @@ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder} * bridge between Spark's data source API and the specific implementation of the STAC data read * operation. */ -class StacScanBuilder(stacCollectionJson: String, opts: Map[String, String]) extends ScanBuilder { +class StacScanBuilder( + stacCollectionJson: String, + opts: Map[String, String], + broadcastConf: Broadcast[SerializableConfiguration]) + extends ScanBuilder { /** * Builds and returns a `Scan` instance. The `Scan` defines the schema and batch reading methods @@ -37,5 +43,5 @@ class StacScanBuilder(stacCollectionJson: String, opts: Map[String, String]) ext * @return * A `Scan` instance that defines how to read STAC data. 
*/ - override def build(): Scan = new StacScan(stacCollectionJson, opts) + override def build(): Scan = new StacScan(stacCollectionJson, opts, broadcastConf) } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala index bd536f6de6..ca5f32663c 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala @@ -18,6 +18,7 @@ */ package org.apache.spark.sql.sedona_sql.io.stac +import org.apache.spark.broadcast.Broadcast import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability} import org.apache.spark.sql.connector.read.ScanBuilder import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT @@ -25,6 +26,7 @@ import org.apache.spark.sql.sedona_sql.io.geojson.GeoJSONUtils import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{inferStacSchema, updatePropertiesPromotedSchema} import org.apache.spark.sql.types._ import org.apache.spark.sql.util.CaseInsensitiveStringMap +import org.apache.spark.util.SerializableConfiguration import java.util.concurrent.ConcurrentHashMap @@ -38,7 +40,10 @@ import java.util.concurrent.ConcurrentHashMap * @constructor * Creates a new instance of the `StacTable` class. */ -class StacTable(stacCollectionJson: String, opts: Map[String, String]) +class StacTable( + stacCollectionJson: String, + opts: Map[String, String], + broadcastConf: Broadcast[SerializableConfiguration]) extends Table with SupportsRead { @@ -84,7 +89,7 @@ class StacTable(stacCollectionJson: String, opts: Map[String, String]) * A new instance of ScanBuilder. */ override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder = - new StacScanBuilder(stacCollectionJson, opts) + new StacScanBuilder(stacCollectionJson, opts, broadcastConf) } object StacTable { diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala index 508d6986b5..95c86c24c0 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala @@ -22,9 +22,13 @@ import com.fasterxml.jackson.databind.ObjectMapper import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.InternalRow -import org.apache.spark.sql.catalyst.util.ArrayBasedMapData -import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType} +import org.apache.spark.sql.execution.datasource.stac.TemporalFilter +import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter +import org.apache.spark.sql.types.{StructField, StructType} +import org.locationtech.jts.geom.Envelope +import java.time.LocalDateTime +import java.time.format.DateTimeFormatter import scala.io.Source object StacUtils { @@ -133,9 +137,6 @@ object StacUtils { } } - /** - * Promote the properties field to the top level of the row. 
- */ def promotePropertiesToTop(row: InternalRow, schema: StructType): InternalRow = { val propertiesIndex = schema.fieldIndex("properties") val propertiesStruct = schema("properties").dataType.asInstanceOf[StructType] @@ -167,47 +168,6 @@ object StacUtils { StructType(newFields) } - /** - * Builds the output row with the raster field in the assets map. - * - * @param row - * The input row. - * @param schema - * The schema of the input row. - * @return - * The output row with the raster field in the assets map. - */ - def buildOutDbRasterFields(row: InternalRow, schema: StructType): InternalRow = { - val newValues = new Array[Any](schema.fields.length) - - schema.fields.zipWithIndex.foreach { - case (StructField("assets", MapType(StringType, valueType: StructType, _), _, _), index) => - val assetsMap = row.getMap(index) - if (assetsMap != null) { - val updatedAssets = assetsMap - .keyArray() - .array - .zip(assetsMap.valueArray().array) - .map { case (key, value) => - val assetRow = value.asInstanceOf[InternalRow] - if (assetRow != null) { - key -> assetRow - } else { - key -> null - } - } - .toMap - newValues(index) = ArrayBasedMapData(updatedAssets) - } else { - newValues(index) = null - } - case (_, index) => - newValues(index) = row.get(index, schema.fields(index).dataType) - } - - InternalRow.fromSeq(newValues) - } - /** * Returns the number of partitions to use for reading the data. * @@ -241,4 +201,86 @@ object StacUtils { Math.max(1, Math.ceil(itemCount.toDouble / maxSplitFiles).toInt) } } + + /** Returns the temporal filter string based on the temporal filter. */ + def getFilterBBox(filter: GeoParquetSpatialFilter): String = { + def calculateUnionBBox(filter: GeoParquetSpatialFilter): Envelope = { + filter match { + case GeoParquetSpatialFilter.AndFilter(left, right) => + val leftEnvelope = calculateUnionBBox(left) + val rightEnvelope = calculateUnionBBox(right) + leftEnvelope.expandToInclude(rightEnvelope) + leftEnvelope + case GeoParquetSpatialFilter.OrFilter(left, right) => + val leftEnvelope = calculateUnionBBox(left) + val rightEnvelope = calculateUnionBBox(right) + leftEnvelope.expandToInclude(rightEnvelope) + leftEnvelope + case leaf: GeoParquetSpatialFilter.LeafFilter => + leaf.queryWindow.getEnvelopeInternal + } + } + + val unionEnvelope = calculateUnionBBox(filter) + s"bbox=${unionEnvelope.getMinX}%2C${unionEnvelope.getMinY}%2C${unionEnvelope.getMaxX}%2C${unionEnvelope.getMaxY}" + } + + /** Returns the temporal filter string based on the temporal filter. */ + def getFilterTemporal(filter: TemporalFilter): String = { + val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'") + + def formatDateTime(dateTime: LocalDateTime): String = { + if (dateTime == null) ".." 
else dateTime.format(formatter) + } + + def calculateUnionTemporal(filter: TemporalFilter): (LocalDateTime, LocalDateTime) = { + filter match { + case TemporalFilter.AndFilter(left, right) => + val (leftStart, leftEnd) = calculateUnionTemporal(left) + val (rightStart, rightEnd) = calculateUnionTemporal(right) + val start = + if (leftStart == null || (rightStart != null && rightStart.isBefore(leftStart))) + rightStart + else leftStart + val end = + if (leftEnd == null || (rightEnd != null && rightEnd.isAfter(leftEnd))) rightEnd + else leftEnd + (start, end) + case TemporalFilter.OrFilter(left, right) => + val (leftStart, leftEnd) = calculateUnionTemporal(left) + val (rightStart, rightEnd) = calculateUnionTemporal(right) + val start = + if (leftStart == null || (rightStart != null && rightStart.isBefore(leftStart))) + rightStart + else leftStart + val end = + if (leftEnd == null || (rightEnd != null && rightEnd.isAfter(leftEnd))) rightEnd + else leftEnd + (start, end) + case TemporalFilter.LessThanFilter(_, value) => + (null, value) + case TemporalFilter.GreaterThanFilter(_, value) => + (value, null) + case TemporalFilter.EqualFilter(_, value) => + (value, value) + } + } + + val (start, end) = calculateUnionTemporal(filter) + if (end == null) s"datetime=${formatDateTime(start)}/.." + else s"datetime=${formatDateTime(start)}/${formatDateTime(end)}" + } + + /** Adds the spatial and temporal filters to the base URL. */ + def addFiltersToUrl( + baseUrl: String, + spatialFilter: Option[GeoParquetSpatialFilter], + temporalFilter: Option[TemporalFilter]): String = { + val spatialFilterStr = spatialFilter.map(StacUtils.getFilterBBox).getOrElse("") + val temporalFilterStr = temporalFilter.map(StacUtils.getFilterTemporal).getOrElse("") + + val filters = Seq(spatialFilterStr, temporalFilterStr).filter(_.nonEmpty).mkString("&") + val urlWithFilters = if (filters.nonEmpty) s"&$filters" else "" + s"$baseUrl$urlWithFilters" + } } diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialTemporalFilterPushDownForStacScan.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialTemporalFilterPushDownForStacScan.scala index 566d368d69..c32af552b3 100644 --- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialTemporalFilterPushDownForStacScan.scala +++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialTemporalFilterPushDownForStacScan.scala @@ -20,7 +20,7 @@ package org.apache.spark.sql.sedona_sql.optimization import org.apache.spark.sql.SparkSession import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Or, SubqueryExpression} -import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan} +import org.apache.spark.sql.catalyst.plans.logical.{Filter, GlobalLimit, LocalLimit, LogicalPlan} import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath import org.apache.spark.sql.execution.datasource.stac.TemporalFilter import org.apache.spark.sql.execution.datasource.stac.TemporalFilter.{AndFilter => TemporalAndFilter} @@ -79,10 +79,25 @@ class SpatialTemporalFilterPushDownForStacScan(sparkSession: SparkSession) filter.copy() } filter.copy() + case lr: DataSourceV2ScanRelation if isStacScanRelation(lr) => + val scan = lr.scan.asInstanceOf[StacScan] + val limit = extractLimit(plan) + limit match { + case Some(n) => scan.setLimit(n) + case None => + } + lr } } } + def 
extractLimit(plan: LogicalPlan): Option[Int] = { + plan.collectFirst { + case GlobalLimit(Literal(limit: Int, _), _) => limit + case LocalLimit(Literal(limit: Int, _), _) => limit + } + } + private def isStacScanRelation(lr: DataSourceV2ScanRelation): Boolean = lr.scan.isInstanceOf[StacScan] diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatchTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatchTest.scala index 0765b3950f..612ce2bb52 100644 --- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatchTest.scala +++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatchTest.scala @@ -22,6 +22,8 @@ import org.apache.sedona.sql.TestBaseScala import org.apache.spark.sql.connector.read.InputPartition import org.apache.spark.sql.types.StructType +import java.time.format.DateTimeFormatter +import java.time.{LocalDate, ZoneOffset} import scala.io.Source import scala.collection.mutable @@ -40,6 +42,41 @@ class StacBatchTest extends TestBaseScala { } } + it("collectItemLinks should collect correct item links") { + val collectionUrl = + "https://earth-search.aws.element84.com/v1/collections/sentinel-2-pre-c1-l2a" + val stacCollectionJson = StacUtils.loadStacCollectionToJson(collectionUrl) + val opts = mutable + .Map( + "itemsLimitMax" -> "1000", + "itemsLimitPerRequest" -> "200", + "itemsLoadProcessReportThreshold" -> "1000000") + .toMap + + val stacBatch = + StacBatch( + null, + collectionUrl, + stacCollectionJson, + StructType(Seq()), + opts, + None, + None, + None) + stacBatch.setItemMaxLeft(1000) + val itemLinks = mutable.ArrayBuffer[String]() + val needCountNextItems = true + + val startTime = System.nanoTime() + stacBatch.collectItemLinks(collectionUrl, stacCollectionJson, itemLinks, needCountNextItems) + val endTime = System.nanoTime() + val duration = (endTime - startTime) / 1e6 // Convert to milliseconds + + assert(itemLinks.nonEmpty) + assert(itemLinks.length == 5) + assert(duration > 0) + } + it("planInputPartitions should create correct number of partitions") { val stacCollectionJson = """ @@ -48,18 +85,26 @@ class StacBatchTest extends TestBaseScala { | "id": "sample-collection", | "description": "A sample STAC collection", | "links": [ - | {"rel": "item", "href": "https://path/to/item1.json"}, - | {"rel": "item", "href": "https://path/to/item2.json"}, - | {"rel": "item", "href": "https://path/to/item3.json"} + | {"rel": "item", "href": "https://storage.googleapis.com/cfo-public/vegetation/California-Vegetation-CanopyBaseHeight-2016-Summer-00010m.json"}, + | {"rel": "item", "href": "https://storage.googleapis.com/cfo-public/vegetation/California-Vegetation-CanopyBaseHeight-2016-Summer-00010m.json"}, + | {"rel": "item", "href": "https://storage.googleapis.com/cfo-public/vegetation/California-Vegetation-CanopyBaseHeight-2016-Summer-00010m.json"} | ] |} """.stripMargin - val opts = mutable.Map("numPartitions" -> "2").toMap - val collectionUrl = "https://path/to/collection.json" + val opts = mutable.Map("numPartitions" -> "2", "itemsLimitMax" -> "20").toMap + val collectionUrl = "https://storage.googleapis.com/cfo-public/vegetation/collection.json" val stacBatch = - StacBatch(collectionUrl, stacCollectionJson, StructType(Seq()), opts, None, None) + StacBatch( + null, + collectionUrl, + stacCollectionJson, + StructType(Seq()), + opts, + None, + None, + None) val partitions: Array[InputPartition] = stacBatch.planInputPartitions() assert(partitions.length == 2) @@ -75,11 
+120,19 @@ class StacBatchTest extends TestBaseScala { |} """.stripMargin - val opts = mutable.Map("numPartitions" -> "2").toMap + val opts = mutable.Map("numPartitions" -> "2", "itemsLimitMax" -> "20").toMap val collectionUrl = "https://path/to/collection.json" val stacBatch = - StacBatch(collectionUrl, stacCollectionJson, StructType(Seq()), opts, None, None) + StacBatch( + null, + collectionUrl, + stacCollectionJson, + StructType(Seq()), + opts, + None, + None, + None) val partitions: Array[InputPartition] = stacBatch.planInputPartitions() assert(partitions.isEmpty) @@ -88,11 +141,19 @@ class StacBatchTest extends TestBaseScala { it("planInputPartitions should create correct number of partitions with real collection.json") { val rootJsonFile = "datasource_stac/collection.json" val stacCollectionJson = loadJsonFromResource(rootJsonFile) - val opts = mutable.Map("numPartitions" -> "3").toMap + val opts = mutable.Map("numPartitions" -> "3", "itemsLimitMax" -> "20").toMap val collectionUrl = getAbsolutePathOfResource(rootJsonFile) val stacBatch = - StacBatch(collectionUrl, stacCollectionJson, StructType(Seq()), opts, None, None) + StacBatch( + null, + collectionUrl, + stacCollectionJson, + StructType(Seq()), + opts, + None, + None, + None) val partitions: Array[InputPartition] = stacBatch.planInputPartitions() assert(partitions.length == 3) diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala index 2cab18b694..5603b3e551 100644 --- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala +++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala @@ -19,8 +19,8 @@ package org.apache.spark.sql.sedona_sql.io.stac import org.apache.sedona.sql.TestBaseScala -import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT -import org.apache.spark.sql.types.{ArrayType, DoubleType, MapType, StringType, StructField, StructType, TimestampType} +import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT} +import org.apache.spark.sql.types._ class StacDataSourceTest extends TestBaseScala { @@ -88,7 +88,7 @@ class StacDataSourceTest extends TestBaseScala { val dfSelect = sparkSession.sql( "SELECT id, geometry " + "FROM STACTBL " + - "WHERE st_contains(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)") + "WHERE st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)") val physicalPlan = dfSelect.queryExecution.executedPlan.toString() assert(physicalPlan.contains( @@ -102,10 +102,11 @@ class StacDataSourceTest extends TestBaseScala { val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL) dfStac.createOrReplaceTempView("STACTBL") - val dfSelect = sparkSession.sql("SELECT id, datetime as dt, geometry, bbox " + - "FROM STACTBL " + - "WHERE datetime BETWEEN '2020-01-01T00:00:00Z' AND '2020-12-13T00:00:00Z' " + - "AND st_contains(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)") + val dfSelect = sparkSession.sql( + "SELECT id, datetime as dt, geometry, bbox " + + "FROM STACTBL " + + "WHERE datetime BETWEEN '2020-01-01T00:00:00Z' AND '2020-12-13T00:00:00Z' " + + "AND st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)") val physicalPlan = dfSelect.queryExecution.executedPlan.toString() assert(physicalPlan.contains( @@ -137,11 +138,12 @@ class 
StacDataSourceTest extends TestBaseScala { val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL) dfStac.createOrReplaceTempView("STACTBL") - val dfSelect = sparkSession.sql("SELECT id, datetime as dt, geometry, bbox " + - "FROM STACTBL " + - "WHERE id = 'some-id' " + - "AND datetime BETWEEN '2020-01-01T00:00:00Z' AND '2020-12-13T00:00:00Z' " + - "AND st_contains(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)") + val dfSelect = sparkSession.sql( + "SELECT id, datetime as dt, geometry, bbox " + + "FROM STACTBL " + + "WHERE id = 'some-id' " + + "AND datetime BETWEEN '2020-01-01T00:00:00Z' AND '2020-12-13T00:00:00Z' " + + "AND st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)") val physicalPlan = dfSelect.queryExecution.executedPlan.toString() assert(physicalPlan.contains( diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReaderTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReaderTest.scala index fc6f8dcfdc..428327e7cc 100644 --- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReaderTest.scala +++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReaderTest.scala @@ -18,8 +18,10 @@ */ package org.apache.spark.sql.sedona_sql.io.stac +import org.apache.hadoop.conf.Configuration import org.apache.sedona.sql.TestBaseScala import org.apache.spark.sql.catalyst.InternalRow +import org.apache.spark.util.SerializableConfiguration import scala.jdk.CollectionConverters._ @@ -40,6 +42,7 @@ class StacPartitionReaderTest extends TestBaseScala { val partition = StacPartition(0, jsonFiles, Map.empty[String, String].asJava) val reader = new StacPartitionReader( + sparkSession.sparkContext.broadcast(new SerializableConfiguration(new Configuration())), partition, StacTable.SCHEMA_V1_1_0, Map.empty[String, String], @@ -61,6 +64,7 @@ class StacPartitionReaderTest extends TestBaseScala { val partition = StacPartition(0, jsonFiles, Map.empty[String, String].asJava) val reader = new StacPartitionReader( + sparkSession.sparkContext.broadcast(new SerializableConfiguration(new Configuration())), partition, StacTable.SCHEMA_V1_1_0, Map.empty[String, String], @@ -82,6 +86,7 @@ class StacPartitionReaderTest extends TestBaseScala { val partition = StacPartition(0, jsonFiles, Map.empty[String, String].asJava) val reader = new StacPartitionReader( + sparkSession.sparkContext.broadcast(new SerializableConfiguration(new Configuration())), partition, StacTable.SCHEMA_V1_1_0, Map.empty[String, String], diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala index 75542c760d..d2940abf41 100644 --- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala +++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala @@ -21,16 +21,25 @@ package org.apache.spark.sql.sedona_sql.io.stac import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper} import com.fasterxml.jackson.module.scala.DefaultScalaModule import org.apache.hadoop.fs.{FileSystem, Path} +import org.apache.sedona.core.spatialOperator.SpatialPredicate import org.apache.spark.sql.SparkSession -import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.getNumPartitions +import org.apache.spark.sql.execution.datasource.stac.TemporalFilter 
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter +import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{getFilterBBox, getFilterTemporal, getNumPartitions} +import org.locationtech.jts.geom.{Envelope, GeometryFactory, Polygon} +import org.locationtech.jts.io.WKTReader import org.scalatest.funsuite.AnyFunSuite import java.io.{File, PrintWriter} +import java.time.LocalDateTime import scala.io.Source import scala.jdk.CollectionConverters._ class StacUtilsTest extends AnyFunSuite { + val geometryFactory = new GeometryFactory() + val wktReader = new WKTReader(geometryFactory) + test("getStacCollectionBasePath should return base URL for HTTP URL") { val opts = Map("path" -> "https://service_url/collections/collection.json") val result = StacUtils.getStacCollectionBasePath(opts) @@ -591,4 +600,121 @@ class StacUtilsTest extends AnyFunSuite { writer.close() } } + + test("getFilterBBox with LeafFilter") { + val queryWindow: Polygon = + wktReader.read("POLYGON((10 10, 20 10, 20 20, 10 20, 10 10))").asInstanceOf[Polygon] + val leafFilter = + GeoParquetSpatialFilter.LeafFilter("geometry", SpatialPredicate.INTERSECTS, queryWindow) + val bbox = getFilterBBox(leafFilter) + assert(bbox == "bbox=10.0%2C10.0%2C20.0%2C20.0") + } + + test("getFilterBBox with AndFilter") { + val queryWindow1: Polygon = + wktReader.read("POLYGON((10 10, 20 10, 20 20, 10 20, 10 10))").asInstanceOf[Polygon] + val queryWindow2: Polygon = + wktReader.read("POLYGON((30 30, 40 30, 40 40, 30 40, 30 30))").asInstanceOf[Polygon] + val leafFilter1 = + GeoParquetSpatialFilter.LeafFilter("geometry", SpatialPredicate.INTERSECTS, queryWindow1) + val leafFilter2 = + GeoParquetSpatialFilter.LeafFilter("geometry", SpatialPredicate.INTERSECTS, queryWindow2) + val andFilter = GeoParquetSpatialFilter.AndFilter(leafFilter1, leafFilter2) + val bbox = getFilterBBox(andFilter) + assert(bbox == "bbox=10.0%2C10.0%2C40.0%2C40.0") + } + + test("getFilterBBox with OrFilter") { + val queryWindow1: Polygon = + wktReader.read("POLYGON((10 10, 20 10, 20 20, 10 20, 10 10))").asInstanceOf[Polygon] + val queryWindow2: Polygon = + wktReader.read("POLYGON((30 30, 40 30, 40 40, 30 40, 30 30))").asInstanceOf[Polygon] + val leafFilter1 = + GeoParquetSpatialFilter.LeafFilter("geometry", SpatialPredicate.INTERSECTS, queryWindow1) + val leafFilter2 = + GeoParquetSpatialFilter.LeafFilter("geometry", SpatialPredicate.INTERSECTS, queryWindow2) + val orFilter = GeoParquetSpatialFilter.OrFilter(leafFilter1, leafFilter2) + val bbox = getFilterBBox(orFilter) + assert(bbox == "bbox=10.0%2C10.0%2C40.0%2C40.0") + } + + test("getFilterTemporal with LessThanFilter") { + val dateTime = LocalDateTime.parse("2025-03-07T00:00:00") + val filter = TemporalFilter.LessThanFilter("timestamp", dateTime) + val result = getFilterTemporal(filter) + assert(result == "datetime=../2025-03-07T00:00:00.000Z") + } + + test("getFilterTemporal with GreaterThanFilter") { + val dateTime = LocalDateTime.parse("2025-03-06T00:00:00") + val filter = TemporalFilter.GreaterThanFilter("timestamp", dateTime) + val result = getFilterTemporal(filter) + assert(result == "datetime=2025-03-06T00:00:00.000Z/..") + } + + test("getFilterTemporal with EqualFilter") { + val dateTime = LocalDateTime.parse("2025-03-06T00:00:00") + val filter = TemporalFilter.EqualFilter("timestamp", dateTime) + val result = getFilterTemporal(filter) + assert(result == "datetime=2025-03-06T00:00:00.000Z/2025-03-06T00:00:00.000Z") + } + + test("getFilterTemporal with AndFilter") { + val dateTime1 = 
LocalDateTime.parse("2025-03-06T00:00:00") + val dateTime2 = LocalDateTime.parse("2025-03-07T00:00:00") + val filter1 = TemporalFilter.GreaterThanFilter("timestamp", dateTime1) + val filter2 = TemporalFilter.LessThanFilter("timestamp", dateTime2) + val andFilter = TemporalFilter.AndFilter(filter1, filter2) + val result = getFilterTemporal(andFilter) + assert(result == "datetime=2025-03-06T00:00:00.000Z/2025-03-07T00:00:00.000Z") + } + + test("getFilterTemporal with OrFilter") { + val dateTime1 = LocalDateTime.parse("2025-03-06T00:00:00") + val dateTime2 = LocalDateTime.parse("2025-03-07T00:00:00") + val filter1 = TemporalFilter.GreaterThanFilter("timestamp", dateTime1) + val filter2 = TemporalFilter.LessThanFilter("timestamp", dateTime2) + val orFilter = TemporalFilter.OrFilter(filter1, filter2) + val result = getFilterTemporal(orFilter) + assert(result == "datetime=2025-03-06T00:00:00.000Z/2025-03-07T00:00:00.000Z") + } + + test("addFiltersToUrl with no filters") { + val baseUrl = "http://example.com/stac" + val result = StacUtils.addFiltersToUrl(baseUrl, None, None) + assert(result == "http://example.com/stac") + } + + test("addFiltersToUrl with spatial filter") { + val baseUrl = "http://example.com/stac" + val envelope = new Envelope(1.0, 2.0, 3.0, 4.0) + val queryWindow = geometryFactory.toGeometry(envelope) + val spatialFilter = Some( + GeoParquetSpatialFilter.LeafFilter("geometry", SpatialPredicate.INTERSECTS, queryWindow)) + val result = StacUtils.addFiltersToUrl(baseUrl, spatialFilter, None) + val expectedUrl = s"$baseUrl&bbox=1.0%2C3.0%2C2.0%2C4.0" + assert(result == expectedUrl) + } + + test("addFiltersToUrl with temporal filter") { + val baseUrl = "http://example.com/stac" + val temporalFilter = Some( + TemporalFilter.GreaterThanFilter("timestamp", LocalDateTime.parse("2025-03-06T00:00:00"))) + val result = StacUtils.addFiltersToUrl(baseUrl, None, temporalFilter) + val expectedUrl = s"$baseUrl&datetime=2025-03-06T00:00:00.000Z/.." + assert(result == expectedUrl) + } + + test("addFiltersToUrl with both spatial and temporal filters") { + val baseUrl = "http://example.com/stac" + val envelope = new Envelope(1.0, 2.0, 3.0, 4.0) + val queryWindow = geometryFactory.toGeometry(envelope) + val spatialFilter = Some( + GeoParquetSpatialFilter.LeafFilter("geometry", SpatialPredicate.INTERSECTS, queryWindow)) + val temporalFilter = Some( + TemporalFilter.GreaterThanFilter("timestamp", LocalDateTime.parse("2025-03-06T00:00:00"))) + val result = StacUtils.addFiltersToUrl(baseUrl, spatialFilter, temporalFilter) + val expectedUrl = s"$baseUrl&bbox=1.0%2C3.0%2C2.0%2C4.0&datetime=2025-03-06T00:00:00.000Z/.." + assert(result == expectedUrl) + } }
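Editor's note on the Python client changes above: `get_items()` and `get_dataframe()` in `collection_client.py` now share the new `load_items_df` helper, and when only `max_items` is supplied the cap is additionally pushed to the reader via the `itemsLimitMax` option. A minimal usage sketch follows; it assumes a reachable STAC API, and the endpoint, collection id, and the `Client` import path / `get_collection` call are illustrative (mirroring the tests), not part of this commit.

```python
from sedona.stac.client import Client  # assumed import path for the Sedona STAC client

# Illustrative endpoint and collection id (the tests hit the Planetary Computer API).
client = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
collection = client.get_collection("aster-l1t")

# Spatial + temporal filtering; both calls below go through load_items_df and
# _apply_spatial_temporal_filters (st_intersects on a bbox polygon + a datetime range).
items = collection.get_items(
    bbox=[[-100.0, -72.0, 105.0, -69.0]],
    datetime=[["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]],
)

# With only a row cap, itemsLimitMax is also passed to the "stac" reader,
# so fewer items are fetched while planning partitions.
df = collection.get_dataframe(max_items=10)
df.show()
```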
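Editor's note: the new limit handling is also visible from plain PySpark. Below is a hedged sketch, assuming a Sedona-enabled session; the collection URL is illustrative. `itemsLimitMax` can be set per read or session-wide through `spark.sedona.stac.load.itemsLimitMax`, and a SQL `LIMIT` is now picked up by the pushdown rule and forwarded to `StacScan` as the item cap.

```python
from sedona.spark import SedonaContext

# Assumed setup: any Sedona-enabled SparkSession works.
config = SedonaContext.builder().getOrCreate()
sedona = SedonaContext.create(config)

# Illustrative collection URL.
collection_url = "https://planetarycomputer.microsoft.com/api/stac/v1/collections/aster-l1t"

# Cap the number of items fetched while planning partitions.
df = (
    sedona.read.format("stac")
    .option("itemsLimitMax", "20")
    .load(collection_url)
)

# Equivalent session-wide default:
# sedona.conf.set("spark.sedona.stac.load.itemsLimitMax", "20")

# A SQL LIMIT is now extracted (GlobalLimit/LocalLimit) and pushed into the scan,
# surfacing in the scan's pushed-filter metadata.
df.createOrReplaceTempView("STACTBL")
sedona.sql("SELECT id, datetime, geometry FROM STACTBL LIMIT 10").show()
```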
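Editor's note: pushed spatial and temporal predicates are now also serialized onto the STAC items request as `bbox=` and `datetime=` query parameters (see `StacUtils.addFiltersToUrl` and its tests above), so the server prunes items before Sedona does. Continuing the previous sketch (same `sedona` session and `STACTBL` view), a query that exercises both pushdowns, mirroring the SQL in the docs and tests:

```python
# Per addFiltersToUrl, the filters are appended to the items URL roughly as
#   ...?limit=<n>&bbox=<minX>%2C<minY>%2C<maxX>%2C<maxY>&datetime=<start>/<end or ..>
filtered = sedona.sql(
    """
    SELECT id, datetime AS dt, geometry, bbox
    FROM STACTBL
    WHERE datetime BETWEEN '2020-01-01T00:00:00Z' AND '2020-12-13T00:00:00Z'
      AND st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)
    """
)
filtered.explain()  # the physical plan lists the pushed spatial and temporal filters
```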
