This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new a5504fed5c [SEDONA-704] Optimize STAC reader and fix few issues (#1861)
a5504fed5c is described below
commit a5504fed5c2253ae34ad14729fe35425f6ce1fe6
Author: Feng Zhang <[email protected]>
AuthorDate: Mon Mar 17 14:30:18 2025 -0700
[SEDONA-704] Optimize STAC reader and fix few issues (#1861)
* [SEDONA-704] Optimize STAC reader and fix few issues
* fix formatting issue
* fix compiling error
* another fix
---
docs/tutorial/files/stac-sedona-spark.md | 4 +-
python/sedona/stac/collection_client.py | 93 ++++++--------
python/tests/stac/test_client.py | 5 +-
python/tests/stac/test_collection_client.py | 6 +-
.../spark/sql/sedona_sql/io/stac/StacBatch.scala | 36 +++++-
.../sql/sedona_sql/io/stac/StacDataSource.scala | 17 ++-
.../sedona_sql/io/stac/StacPartitionReader.scala | 8 +-
.../spark/sql/sedona_sql/io/stac/StacScan.scala | 26 +++-
.../sql/sedona_sql/io/stac/StacScanBuilder.scala | 10 +-
.../spark/sql/sedona_sql/io/stac/StacTable.scala | 9 +-
.../spark/sql/sedona_sql/io/stac/StacUtils.scala | 134 ++++++++++++++-------
.../SpatialTemporalFilterPushDownForStacScan.scala | 17 ++-
.../sql/sedona_sql/io/stac/StacBatchTest.scala | 81 +++++++++++--
.../sedona_sql/io/stac/StacDataSourceTest.scala | 26 ++--
.../io/stac/StacPartitionReaderTest.scala | 5 +
.../sql/sedona_sql/io/stac/StacUtilsTest.scala | 128 +++++++++++++++++++-
16 files changed, 450 insertions(+), 155 deletions(-)
diff --git a/docs/tutorial/files/stac-sedona-spark.md
b/docs/tutorial/files/stac-sedona-spark.md
index 062e6c5f55..aff0e29f0f 100644
--- a/docs/tutorial/files/stac-sedona-spark.md
+++ b/docs/tutorial/files/stac-sedona-spark.md
@@ -116,7 +116,7 @@ The STAC data source supports predicate pushdown for spatial and temporal filter
### Spatial Filter Pushdown
-Spatial filter pushdown allows the data source to apply spatial predicates (e.g., st_contains, st_intersects) directly at the data source level, reducing the amount of data transferred and processed.
+Spatial filter pushdown allows the data source to apply spatial predicates (e.g., st_intersects) directly at the data source level, reducing the amount of data transferred and processed.
### Temporal Filter Pushdown
@@ -147,7 +147,7 @@ In this example, the data source will push down the temporal filter to the under
```sql
SELECT id, geometry
FROM STAC_TABLE
- WHERE st_contains(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)
+ WHERE st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)
```
In this example, the data source will push down the spatial filter to the underlying data source.
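[Editor's note, not part of the diff: a minimal PySpark sketch of issuing such a pushed-down spatial query through the STAC reader. The collection URL is the one used in the Scala tests further down; treat it as a placeholder.]

```python
from sedona.spark import SedonaContext

# Placeholder STAC collection endpoint (assumption; any STAC collection URL works).
collection_url = "https://earth-search.aws.element84.com/v1/collections/sentinel-2-pre-c1-l2a"

sedona = SedonaContext.create(SedonaContext.builder().getOrCreate())
df = sedona.read.format("stac").load(collection_url)
df.createOrReplaceTempView("STAC_TABLE")

# st_intersects is the predicate the reader pushes down as a bbox filter.
result = sedona.sql("""
    SELECT id, geometry
    FROM STAC_TABLE
    WHERE st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11, 17 10))'), geometry)
""")
result.show()
```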
diff --git a/python/sedona/stac/collection_client.py
b/python/sedona/stac/collection_client.py
index b1cae6df39..7670c6291d 100644
--- a/python/sedona/stac/collection_client.py
+++ b/python/sedona/stac/collection_client.py
@@ -119,7 +119,7 @@ class CollectionClient:
- DataFrame: The filtered Spark DataFrame.
The function constructs SQL conditions for spatial and temporal filters and applies them to the DataFrame.
- If bbox is provided, it constructs spatial conditions using st_contains and ST_GeomFromText.
+ If bbox is provided, it constructs spatial conditions using st_intersects and ST_GeomFromText.
If datetime is provided, it constructs temporal conditions using the datetime column.
The conditions are combined using OR logic.
"""
@@ -131,7 +131,7 @@ class CollectionClient:
f"{bbox[2]} {bbox[3]}, {bbox[0]} {bbox[3]}, {bbox[0]}
{bbox[1]}))"
)
bbox_conditions.append(
- f"st_contains(ST_GeomFromText('{polygon_wkt}'), geometry)"
+ f"st_intersects(ST_GeomFromText('{polygon_wkt}'),
geometry)"
)
bbox_sql_condition = " OR ".join(bbox_conditions)
df = df.filter(bbox_sql_condition)
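[Editor's note: a self-contained sketch of the bbox-to-condition construction this hunk touches, assuming each bbox is [min_lon, min_lat, max_lon, max_lat]; the function name is illustrative, not from the file.]

```python
def bbox_filter_condition(bboxes):
    """Build an OR-combined st_intersects condition for a list of bboxes."""
    conditions = []
    for bbox in bboxes:  # bbox = [min_lon, min_lat, max_lon, max_lat]
        polygon_wkt = (
            f"POLYGON(({bbox[0]} {bbox[1]}, {bbox[2]} {bbox[1]}, "
            f"{bbox[2]} {bbox[3]}, {bbox[0]} {bbox[3]}, {bbox[0]} {bbox[1]}))"
        )
        conditions.append(f"st_intersects(ST_GeomFromText('{polygon_wkt}'), geometry)")
    return " OR ".join(conditions)

# e.g. df = df.filter(bbox_filter_condition([[-100.0, -72.0, 105.0, -69.0]]))
```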
@@ -217,34 +217,7 @@ class CollectionClient:
is raised with a message indicating the failure.
"""
try:
- # Load the collection data from the specified collection URL
- df = self.spark.read.format("stac").load(self.collection_url)
-
- # Apply ID filters if provided
- if ids:
- if isinstance(ids, tuple):
- ids = list(ids)
- if isinstance(ids, str):
- ids = [ids]
- df = df.filter(df.id.isin(ids))
-
- # Ensure bbox is a list of lists
- if bbox and isinstance(bbox[0], float):
- bbox = [bbox]
-
- # Handle datetime parameter
- if datetime:
- if isinstance(datetime, (str, python_datetime.datetime)):
- datetime = [self._expand_date(str(datetime))]
- elif isinstance(datetime, list) and isinstance(datetime[0], str):
- datetime = [datetime]
-
- # Apply spatial and temporal filters
- df = self._apply_spatial_temporal_filters(df, bbox, datetime)
-
- # Limit the number of items if max_items is specified
- if max_items is not None:
- df = df.limit(max_items)
+ df = self.load_items_df(bbox, datetime, ids, max_items)
# Collect the filtered rows and convert them to PyStacItem objects
items = []
@@ -291,32 +264,7 @@ class CollectionClient:
is raised with a message indicating the failure.
"""
try:
- df = self.spark.read.format("stac").load(self.collection_url)
-
- # Apply ID filters if provided
- if ids:
- if isinstance(ids, tuple):
- ids = list(ids)
- if isinstance(ids, str):
- ids = [ids]
- df = df.filter(df.id.isin(ids))
-
- # Ensure bbox is a list of lists
- if bbox and isinstance(bbox[0], float):
- bbox = [bbox]
-
- # Handle datetime parameter
- if datetime:
- if isinstance(datetime, (str, python_datetime.datetime)):
- datetime = [[str(datetime), str(datetime)]]
- elif isinstance(datetime, list) and isinstance(datetime[0], str):
- datetime = [datetime]
-
- df = self._apply_spatial_temporal_filters(df, bbox, datetime)
-
- # Limit the number of items if max_items is specified
- if max_items is not None:
- df = df.limit(max_items)
+ df = self.load_items_df(bbox, datetime, ids, max_items)
return df
except Exception as e:
@@ -394,5 +342,38 @@ class CollectionClient:
return df
+ def load_items_df(self, bbox, datetime, ids, max_items):
+ # Load the collection data from the specified collection URL
+ if not ids and not bbox and not datetime and max_items is not None:
+ df = (
+ self.spark.read.format("stac")
+ .option("itemsLimitMax", max_items)
+ .load(self.collection_url)
+ )
+ else:
+ df = self.spark.read.format("stac").load(self.collection_url)
+ # Apply ID filters if provided
+ if ids:
+ if isinstance(ids, tuple):
+ ids = list(ids)
+ if isinstance(ids, str):
+ ids = [ids]
+ df = df.filter(df.id.isin(ids))
+ # Ensure bbox is a list of lists
+ if bbox and isinstance(bbox[0], float):
+ bbox = [bbox]
+ # Handle datetime parameter
+ if datetime:
+ if isinstance(datetime, (str, python_datetime.datetime)):
+ datetime = [self._expand_date(str(datetime))]
+ elif isinstance(datetime, list) and isinstance(datetime[0], str):
+ datetime = [datetime]
+ # Apply spatial and temporal filters
+ df = self._apply_spatial_temporal_filters(df, bbox, datetime)
+ # Limit the number of items if max_items is specified
+ if max_items is not None:
+ df = df.limit(max_items)
+ return df
+
def __str__(self):
return f"<CollectionClient id={self.collection_id}>"
diff --git a/python/tests/stac/test_client.py b/python/tests/stac/test_client.py
index 4f9919e1c5..b8b2beed8a 100644
--- a/python/tests/stac/test_client.py
+++ b/python/tests/stac/test_client.py
@@ -36,7 +36,7 @@ class TestStacClient(TestBase):
return_dataframe=False,
)
assert items is not None
- assert len(list(items)) == 2
+ assert len(list(items)) > 0
def test_search_with_ids(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -82,7 +82,7 @@ class TestStacClient(TestBase):
return_dataframe=False,
)
assert items is not None
- assert len(list(items)) == 4
+ assert len(list(items)) > 0
def test_search_with_bbox_and_non_overlapping_intervals(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -144,6 +144,7 @@ class TestStacClient(TestBase):
datetime=["2006-01-01T00:00:00Z", "2007-01-01T00:00:00Z"],
)
assert df is not None
+ assert df.count() == 20
assert isinstance(df, DataFrame)
def test_search_with_catalog_url(self) -> None:
diff --git a/python/tests/stac/test_collection_client.py
b/python/tests/stac/test_collection_client.py
index 24226f86ca..aa9a15f4eb 100644
--- a/python/tests/stac/test_collection_client.py
+++ b/python/tests/stac/test_collection_client.py
@@ -71,7 +71,7 @@ class TestStacReader(TestBase):
bbox = [[-100.0, -72.0, 105.0, -69.0]]
items = list(collection.get_items(bbox=bbox))
assert items is not None
- assert len(items) == 2
+ assert len(items) > 0
def test_get_items_with_temporal_extent(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -88,7 +88,7 @@ class TestStacReader(TestBase):
datetime = [["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]]
items = list(collection.get_items(bbox=bbox, datetime=datetime))
assert items is not None
- assert len(items) == 4
+ assert len(items) > 0
def test_get_items_with_multiple_bboxes_and_interval(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
@@ -111,7 +111,7 @@ class TestStacReader(TestBase):
datetime = [["2006-12-01T00:00:00Z", "2006-12-27T03:00:00Z"]]
items = list(collection.get_items(bbox=bbox, datetime=datetime))
assert items is not None
- assert len(items) == 4
+ assert len(items) > 0
def test_get_items_with_ids(self) -> None:
client = Client.open(STAC_URLS["PLANETARY-COMPUTER"])
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
index 5994c463fd..4ef53756d2 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
@@ -19,16 +19,19 @@
package org.apache.spark.sql.sedona_sql.io.stac
import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.connector.read.{Batch, InputPartition, PartitionReaderFactory}
import org.apache.spark.sql.execution.datasource.stac.TemporalFilter
import org.apache.spark.sql.execution.datasources.parquet.{GeoParquetSpatialFilter, GeometryFieldMetaData}
import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.getNumPartitions
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
import java.time.LocalDateTime
import java.time.format.DateTimeFormatterBuilder
import java.time.temporal.ChronoField
import scala.jdk.CollectionConverters._
+import scala.util.Random
import scala.util.control.Breaks.breakable
/**
@@ -40,12 +43,14 @@ import scala.util.control.Breaks.breakable
* which are necessary for batch data processing.
*/
case class StacBatch(
+ broadcastConf: Broadcast[SerializableConfiguration],
stacCollectionUrl: String,
stacCollectionJson: String,
schema: StructType,
opts: Map[String, String],
spatialFilter: Option[GeoParquetSpatialFilter],
- temporalFilter: Option[TemporalFilter])
+ temporalFilter: Option[TemporalFilter],
+ limitFilter: Option[Int])
extends Batch {
private val defaultItemsLimitPerRequest: Int = {
@@ -82,10 +87,16 @@ case class StacBatch(
// Initialize the itemLinks array
val itemLinks = scala.collection.mutable.ArrayBuffer[String]()
- // Start the recursive collection of item links
- val itemsLimitMax = opts.getOrElse("itemsLimitMax", "-1").toInt
+ // Get the maximum number of items to process
+ val itemsLimitMax = limitFilter match {
+ case Some(limit) if limit >= 0 => limit
+ case _ => opts.getOrElse("itemsLimitMax", "-1").toInt
+ }
val checkItemsLimitMax = itemsLimitMax > 0
+
+ // Start the recursive collection of item links
setItemMaxLeft(itemsLimitMax)
+
collectItemLinks(stacCollectionBasePath, stacCollectionJson, itemLinks, checkItemsLimitMax)
// Handle when the number of items is less than 1
@@ -109,8 +120,9 @@ case class StacBatch(
// Determine how many items to put in each partition
val partitionSize = Math.ceil(itemLinks.length.toDouble / numPartitions).toInt
- // Group the item links into partitions
- itemLinks
+ // Group the item links into partitions, but randomize first for better load balancing
+ Random
+ .shuffle(itemLinks)
.grouped(partitionSize)
.zipWithIndex
.map { case (items, index) =>
@@ -229,7 +241,7 @@ case class StacBatch(
} else if (rel == "items" && href.startsWith("http")) {
// iterate through the items and check if the limit is reached (if needed)
if (iterateItemsWithLimit(
- itemUrl + "?limit=" + defaultItemsLimitPerRequest,
+ getItemLink(itemUrl, defaultItemsLimitPerRequest, spatialFilter, temporalFilter),
needCountNextItems)) return
}
}
@@ -256,6 +268,17 @@ case class StacBatch(
}
}
+ /** Builds the item request URL with the page limit and any pushed-down spatial/temporal filters. */
+ def getItemLink(
+ itemUrl: String,
+ defaultItemsLimitPerRequest: Int,
+ spatialFilter: Option[GeoParquetSpatialFilter],
+ temporalFilter: Option[TemporalFilter]): String = {
+ val baseUrl = itemUrl + "?limit=" + defaultItemsLimitPerRequest
+ val urlWithFilters = StacUtils.addFiltersToUrl(baseUrl, spatialFilter, temporalFilter)
+ urlWithFilters
+ }
+
/**
* Filters a collection based on the provided spatial and temporal filters.
*
@@ -361,6 +384,7 @@ case class StacBatch(
override def createReaderFactory(): PartitionReaderFactory = { (partition: InputPartition) =>
{
new StacPartitionReader(
+ broadcastConf,
partition.asInstanceOf[StacPartition],
schema,
opts,
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala
index dc2dc3e6dd..154adf8eca 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSource.scala
@@ -18,15 +18,16 @@
*/
package org.apache.spark.sql.sedona_sql.io.stac
-import StacUtils.{inferStacSchema, updatePropertiesPromotedSchema}
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.connector.catalog.{Table, TableProvider}
import org.apache.spark.sql.connector.expressions.Transform
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.sedona_sql.io.geojson.GeoJSONUtils
+import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{inferStacSchema, updatePropertiesPromotedSchema}
import org.apache.spark.sql.sources.DataSourceRegister
import org.apache.spark.sql.types.StructType
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
import java.util
import java.util.concurrent.ConcurrentHashMap
@@ -100,6 +101,7 @@ class StacDataSource() extends TableProvider with DataSourceRegister {
partitioning: Array[Transform],
properties: util.Map[String, String]): Table = {
val opts = new CaseInsensitiveStringMap(properties)
+ val sparkSession = SparkSession.active
val optsMap: Map[String, String] = opts.asCaseSensitiveMap().asScala.toMap ++ Map(
"sessionLocalTimeZone" -> SparkSession.active.sessionState.conf.sessionLocalTimeZone,
@@ -107,17 +109,20 @@ class StacDataSource() extends TableProvider with DataSourceRegister {
"defaultParallelism" -> SparkSession.active.sparkContext.defaultParallelism.toString,
"maxPartitionItemFiles" -> SparkSession.active.conf
.get("spark.sedona.stac.load.maxPartitionItemFiles", "0"),
- "numPartitions" -> SparkSession.active.conf
+ "numPartitions" -> sparkSession.conf
.get("spark.sedona.stac.load.numPartitions", "-1"),
"itemsLimitMax" -> opts
.asCaseSensitiveMap()
.asScala
.toMap
- .get("itemsLimitMax")
- .filter(_.toInt > 0)
- .getOrElse(SparkSession.active.conf.get("spark.sedona.stac.load.itemsLimitMax", "-1")))
+ .getOrElse(
+ "itemsLimitMax",
+ sparkSession.conf.get("spark.sedona.stac.load.itemsLimitMax", "-1")))
val stacCollectionJsonString = StacUtils.loadStacCollectionToJson(optsMap)
+ val hadoopConf = sparkSession.sessionState.newHadoopConfWithOptions(opts.asScala.toMap)
+ val broadcastedConf =
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(hadoopConf))
- new StacTable(stacCollectionJson = stacCollectionJsonString, opts = optsMap)
+ new StacTable(stacCollectionJson = stacCollectionJsonString, opts = optsMap, broadcastedConf)
}
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala
index a545eb232f..b72ac526ca 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala
@@ -20,6 +20,7 @@ package org.apache.spark.sql.sedona_sql.io.stac
import com.fasterxml.jackson.databind.ObjectMapper
import org.apache.hadoop.conf.Configuration
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.catalyst.json.JSONOptionsInRead
import org.apache.spark.sql.connector.read.PartitionReader
@@ -28,14 +29,16 @@ import org.apache.spark.sql.execution.datasources.PartitionedFile
import org.apache.spark.sql.execution.datasources.json.JsonDataSource
import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
import org.apache.spark.sql.sedona_sql.io.geojson.{GeoJSONUtils, SparkCompatUtil}
-import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{buildOutDbRasterFields, promotePropertiesToTop}
+import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.promotePropertiesToTop
import org.apache.spark.sql.types.{StringType, StructType}
+import org.apache.spark.util.SerializableConfiguration
import java.io.{File, PrintWriter}
import java.lang.reflect.Constructor
import scala.io.Source
class StacPartitionReader(
+ broadcast: Broadcast[SerializableConfiguration],
partition: StacPartition,
schema: StructType,
opts: Map[String, String],
@@ -120,8 +123,7 @@ class StacPartitionReader(
rows.map(row => {
val geometryConvertedRow =
GeoJSONUtils.convertGeoJsonToGeometry(row, alteredSchema)
- val rasterAddedRow = buildOutDbRasterFields(geometryConvertedRow, alteredSchema)
- val propertiesPromotedRow = promotePropertiesToTop(rasterAddedRow, alteredSchema)
+ val propertiesPromotedRow = promotePropertiesToTop(geometryConvertedRow, alteredSchema)
propertiesPromotedRow
})
} else {
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScan.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScan.scala
index 2edf082912..9ccc89fba8 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScan.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScan.scala
@@ -18,14 +18,19 @@
*/
package org.apache.spark.sql.sedona_sql.io.stac
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.connector.read.{Batch, Scan}
import org.apache.spark.sql.execution.datasource.stac.TemporalFilter
import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
import org.apache.spark.sql.internal.connector.SupportsMetadata
import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{getFullCollectionUrl, inferStacSchema}
import org.apache.spark.sql.types.StructType
+import org.apache.spark.util.SerializableConfiguration
-class StacScan(stacCollectionJson: String, opts: Map[String, String])
+class StacScan(
+ stacCollectionJson: String,
+ opts: Map[String, String],
+ broadcastConf: Broadcast[SerializableConfiguration])
extends Scan
with SupportsMetadata {
@@ -35,6 +40,8 @@ class StacScan(stacCollectionJson: String, opts: Map[String, String])
// The temporal filter to be pushed down to the data source
var temporalFilter: Option[TemporalFilter] = None
+ var limit: Option[Int] = None
+
/**
* Returns the schema of the data to be read.
*
@@ -62,12 +69,14 @@ class StacScan(stacCollectionJson: String, opts: Map[String, String])
override def toBatch: Batch = {
val stacCollectionUrl = getFullCollectionUrl(opts)
StacBatch(
+ broadcastConf,
stacCollectionUrl,
stacCollectionJson,
readSchema(),
opts,
spatialFilter,
- temporalFilter)
+ temporalFilter,
+ limit)
}
/**
@@ -90,6 +99,16 @@ class StacScan(stacCollectionJson: String, opts: Map[String, String])
temporalFilter = Some(combineTemporalFilter)
}
+ /**
+ * Sets the limit on the number of items to be read.
+ *
+ * @param n
+ * The limit on the number of items to be read.
+ */
+ def setLimit(n: Int) = {
+ limit = Some(n)
+ }
+
/**
* Returns metadata about the data to be read.
*
@@ -101,7 +120,8 @@ class StacScan(stacCollectionJson: String, opts: Map[String, String])
override def getMetaData(): Map[String, String] = {
Map(
"PushedSpatialFilters" -> spatialFilter.map(_.toString).getOrElse("None"),
- "PushedTemporalFilters" -> temporalFilter.map(_.toString).getOrElse("None"))
+ "PushedTemporalFilters" -> temporalFilter.map(_.toString).getOrElse("None"),
+ "PushedLimit" -> limit.map(_.toString).getOrElse("None"))
}
/**
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScanBuilder.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScanBuilder.scala
index ebaab87dda..0fcbde9b4d 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScanBuilder.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacScanBuilder.scala
@@ -18,7 +18,9 @@
*/
package org.apache.spark.sql.sedona_sql.io.stac
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
+import org.apache.spark.util.SerializableConfiguration
/**
* The `StacScanBuilder` class represents the builder for creating a `Scan` instance in the
@@ -28,7 +30,11 @@ import org.apache.spark.sql.connector.read.{Scan, ScanBuilder}
* bridge between Spark's data source API and the specific implementation of the STAC data read
* operation.
*/
-class StacScanBuilder(stacCollectionJson: String, opts: Map[String, String]) extends ScanBuilder {
+class StacScanBuilder(
+ stacCollectionJson: String,
+ opts: Map[String, String],
+ broadcastConf: Broadcast[SerializableConfiguration])
+ extends ScanBuilder {
/**
* Builds and returns a `Scan` instance. The `Scan` defines the schema and batch reading methods
@@ -37,5 +43,5 @@ class StacScanBuilder(stacCollectionJson: String, opts: Map[String, String]) ext
* @return
* A `Scan` instance that defines how to read STAC data.
*/
- override def build(): Scan = new StacScan(stacCollectionJson, opts)
+ override def build(): Scan = new StacScan(stacCollectionJson, opts, broadcastConf)
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala
index bd536f6de6..ca5f32663c 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacTable.scala
@@ -18,6 +18,7 @@
*/
package org.apache.spark.sql.sedona_sql.io.stac
+import org.apache.spark.broadcast.Broadcast
import org.apache.spark.sql.connector.catalog.{SupportsRead, Table, TableCapability}
import org.apache.spark.sql.connector.read.ScanBuilder
import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
@@ -25,6 +26,7 @@ import org.apache.spark.sql.sedona_sql.io.geojson.GeoJSONUtils
import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{inferStacSchema, updatePropertiesPromotedSchema}
import org.apache.spark.sql.types._
import org.apache.spark.sql.util.CaseInsensitiveStringMap
+import org.apache.spark.util.SerializableConfiguration
import java.util.concurrent.ConcurrentHashMap
@@ -38,7 +40,10 @@ import java.util.concurrent.ConcurrentHashMap
* @constructor
* Creates a new instance of the `StacTable` class.
*/
-class StacTable(stacCollectionJson: String, opts: Map[String, String])
+class StacTable(
+ stacCollectionJson: String,
+ opts: Map[String, String],
+ broadcastConf: Broadcast[SerializableConfiguration])
extends Table
with SupportsRead {
@@ -84,7 +89,7 @@ class StacTable(stacCollectionJson: String, opts: Map[String, String])
* A new instance of ScanBuilder.
*/
override def newScanBuilder(options: CaseInsensitiveStringMap): ScanBuilder =
- new StacScanBuilder(stacCollectionJson, opts)
+ new StacScanBuilder(stacCollectionJson, opts, broadcastConf)
}
object StacTable {
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
index 508d6986b5..95c86c24c0 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
@@ -22,9 +22,13 @@ import com.fasterxml.jackson.databind.ObjectMapper
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
-import org.apache.spark.sql.catalyst.util.ArrayBasedMapData
-import org.apache.spark.sql.types.{MapType, StringType, StructField, StructType}
+import org.apache.spark.sql.execution.datasource.stac.TemporalFilter
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.types.{StructField, StructType}
+import org.locationtech.jts.geom.Envelope
+import java.time.LocalDateTime
+import java.time.format.DateTimeFormatter
import scala.io.Source
object StacUtils {
@@ -133,9 +137,6 @@ object StacUtils {
}
}
- /**
- * Promote the properties field to the top level of the row.
- */
def promotePropertiesToTop(row: InternalRow, schema: StructType): InternalRow = {
val propertiesIndex = schema.fieldIndex("properties")
val propertiesStruct = schema("properties").dataType.asInstanceOf[StructType]
@@ -167,47 +168,6 @@ object StacUtils {
StructType(newFields)
}
- /**
- * Builds the output row with the raster field in the assets map.
- *
- * @param row
- * The input row.
- * @param schema
- * The schema of the input row.
- * @return
- * The output row with the raster field in the assets map.
- */
- def buildOutDbRasterFields(row: InternalRow, schema: StructType): InternalRow = {
- val newValues = new Array[Any](schema.fields.length)
-
- schema.fields.zipWithIndex.foreach {
- case (StructField("assets", MapType(StringType, valueType: StructType, _), _, _), index) =>
- val assetsMap = row.getMap(index)
- if (assetsMap != null) {
- val updatedAssets = assetsMap
- .keyArray()
- .array
- .zip(assetsMap.valueArray().array)
- .map { case (key, value) =>
- val assetRow = value.asInstanceOf[InternalRow]
- if (assetRow != null) {
- key -> assetRow
- } else {
- key -> null
- }
- }
- .toMap
- newValues(index) = ArrayBasedMapData(updatedAssets)
- } else {
- newValues(index) = null
- }
- case (_, index) =>
- newValues(index) = row.get(index, schema.fields(index).dataType)
- }
-
- InternalRow.fromSeq(newValues)
- }
-
/**
* Returns the number of partitions to use for reading the data.
*
@@ -241,4 +201,86 @@ object StacUtils {
Math.max(1, Math.ceil(itemCount.toDouble / maxSplitFiles).toInt)
}
}
+
+ /** Returns the bbox filter string based on the spatial filter. */
+ def getFilterBBox(filter: GeoParquetSpatialFilter): String = {
+ def calculateUnionBBox(filter: GeoParquetSpatialFilter): Envelope = {
+ filter match {
+ case GeoParquetSpatialFilter.AndFilter(left, right) =>
+ val leftEnvelope = calculateUnionBBox(left)
+ val rightEnvelope = calculateUnionBBox(right)
+ leftEnvelope.expandToInclude(rightEnvelope)
+ leftEnvelope
+ case GeoParquetSpatialFilter.OrFilter(left, right) =>
+ val leftEnvelope = calculateUnionBBox(left)
+ val rightEnvelope = calculateUnionBBox(right)
+ leftEnvelope.expandToInclude(rightEnvelope)
+ leftEnvelope
+ case leaf: GeoParquetSpatialFilter.LeafFilter =>
+ leaf.queryWindow.getEnvelopeInternal
+ }
+ }
+
+ val unionEnvelope = calculateUnionBBox(filter)
+ s"bbox=${unionEnvelope.getMinX}%2C${unionEnvelope.getMinY}%2C${unionEnvelope.getMaxX}%2C${unionEnvelope.getMaxY}"
+ }
+
+ /** Returns the temporal filter string based on the temporal filter. */
+ def getFilterTemporal(filter: TemporalFilter): String = {
+ val formatter = DateTimeFormatter.ofPattern("yyyy-MM-dd'T'HH:mm:ss.SSS'Z'")
+
+ def formatDateTime(dateTime: LocalDateTime): String = {
+ if (dateTime == null) ".." else dateTime.format(formatter)
+ }
+
+ def calculateUnionTemporal(filter: TemporalFilter): (LocalDateTime, LocalDateTime) = {
+ filter match {
+ case TemporalFilter.AndFilter(left, right) =>
+ val (leftStart, leftEnd) = calculateUnionTemporal(left)
+ val (rightStart, rightEnd) = calculateUnionTemporal(right)
+ val start =
+ if (leftStart == null || (rightStart != null && rightStart.isBefore(leftStart)))
+ rightStart
+ else leftStart
+ val end =
+ if (leftEnd == null || (rightEnd != null && rightEnd.isAfter(leftEnd))) rightEnd
+ else leftEnd
+ (start, end)
+ case TemporalFilter.OrFilter(left, right) =>
+ val (leftStart, leftEnd) = calculateUnionTemporal(left)
+ val (rightStart, rightEnd) = calculateUnionTemporal(right)
+ val start =
+ if (leftStart == null || (rightStart != null && rightStart.isBefore(leftStart)))
+ rightStart
+ else leftStart
+ val end =
+ if (leftEnd == null || (rightEnd != null && rightEnd.isAfter(leftEnd))) rightEnd
+ else leftEnd
+ (start, end)
+ case TemporalFilter.LessThanFilter(_, value) =>
+ (null, value)
+ case TemporalFilter.GreaterThanFilter(_, value) =>
+ (value, null)
+ case TemporalFilter.EqualFilter(_, value) =>
+ (value, value)
+ }
+ }
+
+ val (start, end) = calculateUnionTemporal(filter)
+ if (end == null) s"datetime=${formatDateTime(start)}/.."
+ else s"datetime=${formatDateTime(start)}/${formatDateTime(end)}"
+ }
+
+ /** Adds the spatial and temporal filters to the base URL. */
+ def addFiltersToUrl(
+ baseUrl: String,
+ spatialFilter: Option[GeoParquetSpatialFilter],
+ temporalFilter: Option[TemporalFilter]): String = {
+ val spatialFilterStr = spatialFilter.map(StacUtils.getFilterBBox).getOrElse("")
+ val temporalFilterStr = temporalFilter.map(StacUtils.getFilterTemporal).getOrElse("")
+
+ val filters = Seq(spatialFilterStr, temporalFilterStr).filter(_.nonEmpty).mkString("&")
+ val urlWithFilters = if (filters.nonEmpty) s"&$filters" else ""
+ s"$baseUrl$urlWithFilters"
+ }
}
diff --git
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialTemporalFilterPushDownForStacScan.scala
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialTemporalFilterPushDownForStacScan.scala
index 566d368d69..c32af552b3 100644
---
a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialTemporalFilterPushDownForStacScan.scala
+++
b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/optimization/SpatialTemporalFilterPushDownForStacScan.scala
@@ -20,7 +20,7 @@ package org.apache.spark.sql.sedona_sql.optimization
import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.expressions.{And, EqualTo, Expression, GreaterThan, GreaterThanOrEqual, LessThan, LessThanOrEqual, Literal, Or, SubqueryExpression}
-import org.apache.spark.sql.catalyst.plans.logical.{Filter, LogicalPlan}
+import org.apache.spark.sql.catalyst.plans.logical.{Filter, GlobalLimit, LocalLimit, LogicalPlan}
import org.apache.spark.sql.connector.catalog.CatalogV2Implicits.parseColumnPath
import org.apache.spark.sql.execution.datasource.stac.TemporalFilter
import org.apache.spark.sql.execution.datasource.stac.TemporalFilter.{AndFilter => TemporalAndFilter}
@@ -79,10 +79,25 @@ class SpatialTemporalFilterPushDownForStacScan(sparkSession: SparkSession)
filter.copy()
}
filter.copy()
+ case lr: DataSourceV2ScanRelation if isStacScanRelation(lr) =>
+ val scan = lr.scan.asInstanceOf[StacScan]
+ val limit = extractLimit(plan)
+ limit match {
+ case Some(n) => scan.setLimit(n)
+ case None =>
+ }
+ lr
}
}
}
+ def extractLimit(plan: LogicalPlan): Option[Int] = {
+ plan.collectFirst {
+ case GlobalLimit(Literal(limit: Int, _), _) => limit
+ case LocalLimit(Literal(limit: Int, _), _) => limit
+ }
+ }
+
private def isStacScanRelation(lr: DataSourceV2ScanRelation): Boolean =
lr.scan.isInstanceOf[StacScan]
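[Editor's note: a hedged PySpark sketch of what this limit pushdown means in practice; the collection URL is a placeholder and the exact plan output is an assumption.]

```python
from sedona.spark import SedonaContext

sedona = SedonaContext.create(SedonaContext.builder().getOrCreate())

# Placeholder collection endpoint (assumption).
df = sedona.read.format("stac").load(
    "https://earth-search.aws.element84.com/v1/collections/sentinel-2-pre-c1-l2a")
df.createOrReplaceTempView("STAC_TABLE")

# With the rule above, the LIMIT literal is forwarded to StacScan.setLimit,
# which caps how many item links StacBatch collects from the service.
limited = sedona.sql("SELECT id, geometry FROM STAC_TABLE LIMIT 20")
limited.explain()  # scan metadata should report the pushed limit
```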
diff --git
a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatchTest.scala
b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatchTest.scala
index 0765b3950f..612ce2bb52 100644
---
a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatchTest.scala
+++
b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatchTest.scala
@@ -22,6 +22,8 @@ import org.apache.sedona.sql.TestBaseScala
import org.apache.spark.sql.connector.read.InputPartition
import org.apache.spark.sql.types.StructType
+import java.time.format.DateTimeFormatter
+import java.time.{LocalDate, ZoneOffset}
import scala.io.Source
import scala.collection.mutable
@@ -40,6 +42,41 @@ class StacBatchTest extends TestBaseScala {
}
}
+ it("collectItemLinks should collect correct item links") {
+ val collectionUrl =
+ "https://earth-search.aws.element84.com/v1/collections/sentinel-2-pre-c1-l2a"
+ val stacCollectionJson = StacUtils.loadStacCollectionToJson(collectionUrl)
+ val opts = mutable
+ .Map(
+ "itemsLimitMax" -> "1000",
+ "itemsLimitPerRequest" -> "200",
+ "itemsLoadProcessReportThreshold" -> "1000000")
+ .toMap
+
+ val stacBatch =
+ StacBatch(
+ null,
+ collectionUrl,
+ stacCollectionJson,
+ StructType(Seq()),
+ opts,
+ None,
+ None,
+ None)
+ stacBatch.setItemMaxLeft(1000)
+ val itemLinks = mutable.ArrayBuffer[String]()
+ val needCountNextItems = true
+
+ val startTime = System.nanoTime()
+ stacBatch.collectItemLinks(collectionUrl, stacCollectionJson, itemLinks, needCountNextItems)
+ val endTime = System.nanoTime()
+ val duration = (endTime - startTime) / 1e6 // Convert to milliseconds
+
+ assert(itemLinks.nonEmpty)
+ assert(itemLinks.length == 5)
+ assert(duration > 0)
+ }
+
it("planInputPartitions should create correct number of partitions") {
val stacCollectionJson =
"""
@@ -48,18 +85,26 @@ class StacBatchTest extends TestBaseScala {
| "id": "sample-collection",
| "description": "A sample STAC collection",
| "links": [
- | {"rel": "item", "href": "https://path/to/item1.json"},
- | {"rel": "item", "href": "https://path/to/item2.json"},
- | {"rel": "item", "href": "https://path/to/item3.json"}
+ | {"rel": "item", "href":
"https://storage.googleapis.com/cfo-public/vegetation/California-Vegetation-CanopyBaseHeight-2016-Summer-00010m.json"},
+ | {"rel": "item", "href":
"https://storage.googleapis.com/cfo-public/vegetation/California-Vegetation-CanopyBaseHeight-2016-Summer-00010m.json"},
+ | {"rel": "item", "href":
"https://storage.googleapis.com/cfo-public/vegetation/California-Vegetation-CanopyBaseHeight-2016-Summer-00010m.json"}
| ]
|}
""".stripMargin
- val opts = mutable.Map("numPartitions" -> "2").toMap
- val collectionUrl = "https://path/to/collection.json"
+ val opts = mutable.Map("numPartitions" -> "2", "itemsLimitMax" ->
"20").toMap
+ val collectionUrl =
"https://storage.googleapis.com/cfo-public/vegetation/collection.json"
val stacBatch =
- StacBatch(collectionUrl, stacCollectionJson, StructType(Seq()), opts, None, None)
+ StacBatch(
+ null,
+ collectionUrl,
+ stacCollectionJson,
+ StructType(Seq()),
+ opts,
+ None,
+ None,
+ None)
val partitions: Array[InputPartition] = stacBatch.planInputPartitions()
assert(partitions.length == 2)
@@ -75,11 +120,19 @@ class StacBatchTest extends TestBaseScala {
|}
""".stripMargin
- val opts = mutable.Map("numPartitions" -> "2").toMap
+ val opts = mutable.Map("numPartitions" -> "2", "itemsLimitMax" ->
"20").toMap
val collectionUrl = "https://path/to/collection.json"
val stacBatch =
- StacBatch(collectionUrl, stacCollectionJson, StructType(Seq()), opts, None, None)
+ StacBatch(
+ null,
+ collectionUrl,
+ stacCollectionJson,
+ StructType(Seq()),
+ opts,
+ None,
+ None,
+ None)
val partitions: Array[InputPartition] = stacBatch.planInputPartitions()
assert(partitions.isEmpty)
@@ -88,11 +141,19 @@ class StacBatchTest extends TestBaseScala {
it("planInputPartitions should create correct number of partitions with real
collection.json") {
val rootJsonFile = "datasource_stac/collection.json"
val stacCollectionJson = loadJsonFromResource(rootJsonFile)
- val opts = mutable.Map("numPartitions" -> "3").toMap
+ val opts = mutable.Map("numPartitions" -> "3", "itemsLimitMax" ->
"20").toMap
val collectionUrl = getAbsolutePathOfResource(rootJsonFile)
val stacBatch =
- StacBatch(collectionUrl, stacCollectionJson, StructType(Seq()), opts, None, None)
+ StacBatch(
+ null,
+ collectionUrl,
+ stacCollectionJson,
+ StructType(Seq()),
+ opts,
+ None,
+ None,
+ None)
val partitions: Array[InputPartition] = stacBatch.planInputPartitions()
assert(partitions.length == 3)
diff --git
a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
index 2cab18b694..5603b3e551 100644
---
a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
+++
b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
@@ -19,8 +19,8 @@
package org.apache.spark.sql.sedona_sql.io.stac
import org.apache.sedona.sql.TestBaseScala
-import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
-import org.apache.spark.sql.types.{ArrayType, DoubleType, MapType, StringType, StructField, StructType, TimestampType}
+import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
+import org.apache.spark.sql.types._
class StacDataSourceTest extends TestBaseScala {
@@ -88,7 +88,7 @@ class StacDataSourceTest extends TestBaseScala {
val dfSelect = sparkSession.sql(
"SELECT id, geometry " +
"FROM STACTBL " +
- "WHERE st_contains(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17
11, 17 10))'), geometry)")
+ "WHERE st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17
11, 17 10))'), geometry)")
val physicalPlan = dfSelect.queryExecution.executedPlan.toString()
assert(physicalPlan.contains(
@@ -102,10 +102,11 @@ class StacDataSourceTest extends TestBaseScala {
val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL)
dfStac.createOrReplaceTempView("STACTBL")
- val dfSelect = sparkSession.sql("SELECT id, datetime as dt, geometry, bbox
" +
- "FROM STACTBL " +
- "WHERE datetime BETWEEN '2020-01-01T00:00:00Z' AND
'2020-12-13T00:00:00Z' " +
- "AND st_contains(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11,
17 10))'), geometry)")
+ val dfSelect = sparkSession.sql(
+ "SELECT id, datetime as dt, geometry, bbox " +
+ "FROM STACTBL " +
+ "WHERE datetime BETWEEN '2020-01-01T00:00:00Z' AND
'2020-12-13T00:00:00Z' " +
+ "AND st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17
11, 17 10))'), geometry)")
val physicalPlan = dfSelect.queryExecution.executedPlan.toString()
assert(physicalPlan.contains(
@@ -137,11 +138,12 @@ class StacDataSourceTest extends TestBaseScala {
val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL)
dfStac.createOrReplaceTempView("STACTBL")
- val dfSelect = sparkSession.sql("SELECT id, datetime as dt, geometry, bbox
" +
- "FROM STACTBL " +
- "WHERE id = 'some-id' " +
- "AND datetime BETWEEN '2020-01-01T00:00:00Z' AND '2020-12-13T00:00:00Z'
" +
- "AND st_contains(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17 11,
17 10))'), geometry)")
+ val dfSelect = sparkSession.sql(
+ "SELECT id, datetime as dt, geometry, bbox " +
+ "FROM STACTBL " +
+ "WHERE id = 'some-id' " +
+ "AND datetime BETWEEN '2020-01-01T00:00:00Z' AND
'2020-12-13T00:00:00Z' " +
+ "AND st_intersects(ST_GeomFromText('POLYGON((17 10, 18 10, 18 11, 17
11, 17 10))'), geometry)")
val physicalPlan = dfSelect.queryExecution.executedPlan.toString()
assert(physicalPlan.contains(
diff --git
a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReaderTest.scala
b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReaderTest.scala
index fc6f8dcfdc..428327e7cc 100644
---
a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReaderTest.scala
+++
b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReaderTest.scala
@@ -18,8 +18,10 @@
*/
package org.apache.spark.sql.sedona_sql.io.stac
+import org.apache.hadoop.conf.Configuration
import org.apache.sedona.sql.TestBaseScala
import org.apache.spark.sql.catalyst.InternalRow
+import org.apache.spark.util.SerializableConfiguration
import scala.jdk.CollectionConverters._
@@ -40,6 +42,7 @@ class StacPartitionReaderTest extends TestBaseScala {
val partition = StacPartition(0, jsonFiles, Map.empty[String, String].asJava)
val reader =
new StacPartitionReader(
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(new Configuration())),
partition,
StacTable.SCHEMA_V1_1_0,
Map.empty[String, String],
@@ -61,6 +64,7 @@ class StacPartitionReaderTest extends TestBaseScala {
val partition = StacPartition(0, jsonFiles, Map.empty[String, String].asJava)
val reader =
new StacPartitionReader(
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(new Configuration())),
partition,
StacTable.SCHEMA_V1_1_0,
Map.empty[String, String],
@@ -82,6 +86,7 @@ class StacPartitionReaderTest extends TestBaseScala {
val partition = StacPartition(0, jsonFiles, Map.empty[String, String].asJava)
val reader =
new StacPartitionReader(
+ sparkSession.sparkContext.broadcast(new SerializableConfiguration(new Configuration())),
partition,
StacTable.SCHEMA_V1_1_0,
Map.empty[String, String],
diff --git
a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala
b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala
index 75542c760d..d2940abf41 100644
---
a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala
+++
b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala
@@ -21,16 +21,25 @@ package org.apache.spark.sql.sedona_sql.io.stac
import com.fasterxml.jackson.databind.{JsonNode, ObjectMapper}
import com.fasterxml.jackson.module.scala.DefaultScalaModule
import org.apache.hadoop.fs.{FileSystem, Path}
+import org.apache.sedona.core.spatialOperator.SpatialPredicate
import org.apache.spark.sql.SparkSession
-import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.getNumPartitions
+import org.apache.spark.sql.execution.datasource.stac.TemporalFilter
+import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
+import org.apache.spark.sql.sedona_sql.io.stac.StacUtils.{getFilterBBox, getFilterTemporal, getNumPartitions}
+import org.locationtech.jts.geom.{Envelope, GeometryFactory, Polygon}
+import org.locationtech.jts.io.WKTReader
import org.scalatest.funsuite.AnyFunSuite
import java.io.{File, PrintWriter}
+import java.time.LocalDateTime
import scala.io.Source
import scala.jdk.CollectionConverters._
class StacUtilsTest extends AnyFunSuite {
+ val geometryFactory = new GeometryFactory()
+ val wktReader = new WKTReader(geometryFactory)
+
test("getStacCollectionBasePath should return base URL for HTTP URL") {
val opts = Map("path" -> "https://service_url/collections/collection.json")
val result = StacUtils.getStacCollectionBasePath(opts)
@@ -591,4 +600,121 @@ class StacUtilsTest extends AnyFunSuite {
writer.close()
}
}
+
+ test("getFilterBBox with LeafFilter") {
+ val queryWindow: Polygon =
+ wktReader.read("POLYGON((10 10, 20 10, 20 20, 10 20, 10
10))").asInstanceOf[Polygon]
+ val leafFilter =
+ GeoParquetSpatialFilter.LeafFilter("geometry",
SpatialPredicate.INTERSECTS, queryWindow)
+ val bbox = getFilterBBox(leafFilter)
+ assert(bbox == "bbox=10.0%2C10.0%2C20.0%2C20.0")
+ }
+
+ test("getFilterBBox with AndFilter") {
+ val queryWindow1: Polygon =
+ wktReader.read("POLYGON((10 10, 20 10, 20 20, 10 20, 10
10))").asInstanceOf[Polygon]
+ val queryWindow2: Polygon =
+ wktReader.read("POLYGON((30 30, 40 30, 40 40, 30 40, 30
30))").asInstanceOf[Polygon]
+ val leafFilter1 =
+ GeoParquetSpatialFilter.LeafFilter("geometry",
SpatialPredicate.INTERSECTS, queryWindow1)
+ val leafFilter2 =
+ GeoParquetSpatialFilter.LeafFilter("geometry",
SpatialPredicate.INTERSECTS, queryWindow2)
+ val andFilter = GeoParquetSpatialFilter.AndFilter(leafFilter1, leafFilter2)
+ val bbox = getFilterBBox(andFilter)
+ assert(bbox == "bbox=10.0%2C10.0%2C40.0%2C40.0")
+ }
+
+ test("getFilterBBox with OrFilter") {
+ val queryWindow1: Polygon =
+ wktReader.read("POLYGON((10 10, 20 10, 20 20, 10 20, 10
10))").asInstanceOf[Polygon]
+ val queryWindow2: Polygon =
+ wktReader.read("POLYGON((30 30, 40 30, 40 40, 30 40, 30
30))").asInstanceOf[Polygon]
+ val leafFilter1 =
+ GeoParquetSpatialFilter.LeafFilter("geometry",
SpatialPredicate.INTERSECTS, queryWindow1)
+ val leafFilter2 =
+ GeoParquetSpatialFilter.LeafFilter("geometry",
SpatialPredicate.INTERSECTS, queryWindow2)
+ val orFilter = GeoParquetSpatialFilter.OrFilter(leafFilter1, leafFilter2)
+ val bbox = getFilterBBox(orFilter)
+ assert(bbox == "bbox=10.0%2C10.0%2C40.0%2C40.0")
+ }
+
+ test("getFilterTemporal with LessThanFilter") {
+ val dateTime = LocalDateTime.parse("2025-03-07T00:00:00")
+ val filter = TemporalFilter.LessThanFilter("timestamp", dateTime)
+ val result = getFilterTemporal(filter)
+ assert(result == "datetime=../2025-03-07T00:00:00.000Z")
+ }
+
+ test("getFilterTemporal with GreaterThanFilter") {
+ val dateTime = LocalDateTime.parse("2025-03-06T00:00:00")
+ val filter = TemporalFilter.GreaterThanFilter("timestamp", dateTime)
+ val result = getFilterTemporal(filter)
+ assert(result == "datetime=2025-03-06T00:00:00.000Z/..")
+ }
+
+ test("getFilterTemporal with EqualFilter") {
+ val dateTime = LocalDateTime.parse("2025-03-06T00:00:00")
+ val filter = TemporalFilter.EqualFilter("timestamp", dateTime)
+ val result = getFilterTemporal(filter)
+ assert(result == "datetime=2025-03-06T00:00:00.000Z/2025-03-06T00:00:00.000Z")
+ }
+
+ test("getFilterTemporal with AndFilter") {
+ val dateTime1 = LocalDateTime.parse("2025-03-06T00:00:00")
+ val dateTime2 = LocalDateTime.parse("2025-03-07T00:00:00")
+ val filter1 = TemporalFilter.GreaterThanFilter("timestamp", dateTime1)
+ val filter2 = TemporalFilter.LessThanFilter("timestamp", dateTime2)
+ val andFilter = TemporalFilter.AndFilter(filter1, filter2)
+ val result = getFilterTemporal(andFilter)
+ assert(result == "datetime=2025-03-06T00:00:00.000Z/2025-03-07T00:00:00.000Z")
+ }
+
+ test("getFilterTemporal with OrFilter") {
+ val dateTime1 = LocalDateTime.parse("2025-03-06T00:00:00")
+ val dateTime2 = LocalDateTime.parse("2025-03-07T00:00:00")
+ val filter1 = TemporalFilter.GreaterThanFilter("timestamp", dateTime1)
+ val filter2 = TemporalFilter.LessThanFilter("timestamp", dateTime2)
+ val orFilter = TemporalFilter.OrFilter(filter1, filter2)
+ val result = getFilterTemporal(orFilter)
+ assert(result == "datetime=2025-03-06T00:00:00.000Z/2025-03-07T00:00:00.000Z")
+ }
+
+ test("addFiltersToUrl with no filters") {
+ val baseUrl = "http://example.com/stac"
+ val result = StacUtils.addFiltersToUrl(baseUrl, None, None)
+ assert(result == "http://example.com/stac")
+ }
+
+ test("addFiltersToUrl with spatial filter") {
+ val baseUrl = "http://example.com/stac"
+ val envelope = new Envelope(1.0, 2.0, 3.0, 4.0)
+ val queryWindow = geometryFactory.toGeometry(envelope)
+ val spatialFilter = Some(
+ GeoParquetSpatialFilter.LeafFilter("geometry",
SpatialPredicate.INTERSECTS, queryWindow))
+ val result = StacUtils.addFiltersToUrl(baseUrl, spatialFilter, None)
+ val expectedUrl = s"$baseUrl&bbox=1.0%2C3.0%2C2.0%2C4.0"
+ assert(result == expectedUrl)
+ }
+
+ test("addFiltersToUrl with temporal filter") {
+ val baseUrl = "http://example.com/stac"
+ val temporalFilter = Some(
+ TemporalFilter.GreaterThanFilter("timestamp",
LocalDateTime.parse("2025-03-06T00:00:00")))
+ val result = StacUtils.addFiltersToUrl(baseUrl, None, temporalFilter)
+ val expectedUrl = s"$baseUrl&datetime=2025-03-06T00:00:00.000Z/.."
+ assert(result == expectedUrl)
+ }
+
+ test("addFiltersToUrl with both spatial and temporal filters") {
+ val baseUrl = "http://example.com/stac"
+ val envelope = new Envelope(1.0, 2.0, 3.0, 4.0)
+ val queryWindow = geometryFactory.toGeometry(envelope)
+ val spatialFilter = Some(
+ GeoParquetSpatialFilter.LeafFilter("geometry",
SpatialPredicate.INTERSECTS, queryWindow))
+ val temporalFilter = Some(
+ TemporalFilter.GreaterThanFilter("timestamp",
LocalDateTime.parse("2025-03-06T00:00:00")))
+ val result = StacUtils.addFiltersToUrl(baseUrl, spatialFilter,
temporalFilter)
+ val expectedUrl = s"$baseUrl&bbox=1.0%2C3.0%2C2.0%2C4.0&datetime=2025-03-06T00:00:00.000Z/.."
+ assert(result == expectedUrl)
+ }
}