This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 718fa5b57a [SEDONA-704] Add the grid extension to the stac reader (#1883)
718fa5b57a is described below
commit 718fa5b57a0a1c10eb06f2a0358ea3809ffec373
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Mar 27 21:56:44 2025 -0700
[SEDONA-704] Add the grid extension to the stac reader (#1883)
* [SEDONA-704] Add the grid extension to the stac reader
* In Scala 2.13, the standard library collections were restructured, and scala.collection.Seq is now different from Seq.
* Update spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
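To illustrate the Scala 2.13 point above, here is a minimal sketch (not part of this patch; the object and value names are placeholders): in 2.13 the unqualified Seq aliases scala.collection.immutable.Seq, so a scala.collection.Seq value needs an explicit conversion before being passed where a plain Seq is expected.

    import scala.collection.mutable.ArrayBuffer

    object SeqCompatSketch {
      // In Scala 2.13, Seq is scala.collection.immutable.Seq.
      def takesImmutable(xs: Seq[Int]): Int = xs.sum
      // scala.collection.Seq is the common supertype of mutable and immutable sequences.
      def takesGeneric(xs: scala.collection.Seq[Int]): Int = xs.sum

      def main(args: Array[String]): Unit = {
        val buf: scala.collection.Seq[Int] = ArrayBuffer(1, 2, 3)
        // takesImmutable(buf)             // does not compile on 2.13
        println(takesImmutable(buf.toSeq)) // explicit conversion works on both 2.12 and 2.13
        println(takesGeneric(buf))
      }
    }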
---------
Co-authored-by: Jia Yu <[email protected]>
---
docs/tutorial/files/stac-sedona-spark.md | 1 +
.../sql/sedona_sql/io/stac/StacExtension.scala | 57 ++++++++++
.../spark/sql/sedona_sql/io/stac/StacUtils.scala | 69 +++++++++++-
.../resources/datasource_stac/extended-item.json | 3 +-
.../sedona_sql/io/stac/StacDataSourceTest.scala | 120 ++++++++++++++-------
5 files changed, 206 insertions(+), 44 deletions(-)
diff --git a/docs/tutorial/files/stac-sedona-spark.md b/docs/tutorial/files/stac-sedona-spark.md
index aff0e29f0f..864b6ef9d6 100644
--- a/docs/tutorial/files/stac-sedona-spark.md
+++ b/docs/tutorial/files/stac-sedona-spark.md
@@ -77,6 +77,7 @@ root
| |-- element: string (containsNull = true)
|-- constellation: string (nullable = true)
|-- mission: string (nullable = true)
+ |-- grid:code: string (nullable = true)
|-- gsd: double (nullable = true)
|-- collection: string (nullable = true)
|-- links: array (nullable = true)
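For reference, a minimal sketch (not from the patch; the session name and path are placeholders) of the reader call that produces the schema above, now including the promoted grid:code column:

    // Assumes a Sedona-enabled SparkSession named `sparkSession` and a local STAC collection JSON.
    val dfStac = sparkSession.read.format("stac").load("/path/to/collection.json")
    dfStac.printSchema() // shows grid:code: string (nullable = true) alongside the core fields
    dfStac.filter(dfStac.col("grid:code").isNotNull)
      .select(dfStac.col("id"), dfStac.col("grid:code"))
      .show(5)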
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala
new file mode 100644
index 0000000000..c1f78268ac
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ * http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.stac
+
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+/**
+ * Defines a STAC extension and the schema of the properties it contributes
+ *
+ * @param name
+ * The name of the STAC extension (e.g., "eo", "proj")
+ * @param schema
+ * The schema for the extension properties
+ */
+case class StacExtension(name: String, schema: StructType)
+
+object StacExtension {
+
+ /**
+ * Returns an array of STAC extension definitions, each containing:
+ * - Extension name
+ * - Extension schema (the fields added to the item properties)
+ *
+ * @return
+ * Array of StacExtension definitions
+ */
+ def getStacExtensionDefinitions(): Array[StacExtension] = {
+ Array(
+ // Grid extension - https://stac-extensions.github.io/grid/v1.1.0/schema.json
+ StacExtension(
+ name = "grid",
+ // Schema for the grid extension, add all required fields here
+ schema = StructType(Seq(StructField("grid:code", StringType, nullable = true))))
+
+ // Add other extensions here...
+ )
+ }
+}
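A short usage sketch for the new extension registry (not part of the patch; the object name is a placeholder), listing the fields each registered extension contributes:

    import org.apache.spark.sql.sedona_sql.io.stac.StacExtension

    object ListStacExtensionsSketch {
      def main(args: Array[String]): Unit = {
        // Only the grid extension is registered above; others can be appended to the array.
        StacExtension.getStacExtensionDefinitions().foreach { ext =>
          val fields = ext.schema.fields.map(_.name).mkString(", ")
          println(s"extension '${ext.name}' adds: $fields") // e.g. extension 'grid' adds: grid:code
        }
      }
    }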
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
index 95c86c24c0..cfb0d33abc 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession
import org.apache.spark.sql.catalyst.InternalRow
import org.apache.spark.sql.execution.datasource.stac.TemporalFilter
import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
import org.locationtech.jts.geom.Envelope
import java.time.LocalDateTime
@@ -129,14 +129,79 @@ object StacUtils {
val stacVersion = collection.get("stac_version").asText()
// Return the corresponding schema based on the stac_version
- stacVersion match {
+ val coreSchema = stacVersion match {
case "1.0.0" => StacTable.SCHEMA_V1_0_0
case version if version.matches("1\\.[1-9]\\d*\\.\\d*") =>
StacTable.SCHEMA_V1_1_0
// Add more cases here for other versions if needed
case _ => throw new IllegalArgumentException(s"Unsupported STAC version: $stacVersion")
}
+ val extensions = StacExtension.getStacExtensionDefinitions()
+ val schemaWithExtensions = addExtensionFieldsToSchema(coreSchema, extensions)
+ schemaWithExtensions
}
+ /**
+ * Adds STAC extension fields to the properties field in the schema
+ *
+ * @param schema
+ * The base STAC schema to enhance
+ * @param extensions
+ * Array of STAC extension definitions
+ * @return
+ * Enhanced schema with extension fields added to the properties struct
+ */
+ def addExtensionFieldsToSchema(
+ schema: StructType,
+ extensions: Array[StacExtension]): StructType = {
+ // Find the properties field in the schema
+ val propertiesFieldOpt = schema.fields.find(_.name == "properties")
+
+ if (propertiesFieldOpt.isEmpty) {
+ // If there's no properties field, return the original schema
+ return schema
+ }
+
+ // Get the properties field and its struct type
+ val propertiesField = propertiesFieldOpt.get
+ val propertiesStruct = propertiesField.dataType.asInstanceOf[StructType]
+
+ // Create extension fields with metadata indicating their source
+ val extensionFields = extensions.flatMap { extension =>
+ extension.schema.fields.map { field =>
+ StructField(
+ field.name,
+ field.dataType,
+ field.nullable,
+ new MetadataBuilder()
+ .withMetadata(field.metadata)
+ .putString("stac_extension", extension.name)
+ .build())
+ }
+ }
+
+ // Create a new properties struct that includes the extension fields
+ val updatedPropertiesStruct = StructType(propertiesStruct.fields ++ extensionFields)
+
+ // Create a new properties field with the updated struct
+ val updatedPropertiesField = StructField(
+ propertiesField.name,
+ updatedPropertiesStruct,
+ propertiesField.nullable,
+ propertiesField.metadata)
+
+ // Replace the properties field in the schema
+ val updatedFields = schema.fields.map {
+ case field if field.name == "properties" => updatedPropertiesField
+ case field => field
+ }
+
+ // Return the schema with the updated properties field
+ StructType(updatedFields)
+ }
+
+ /**
+ * Promote the properties field to the top level of the row.
+ */
def promotePropertiesToTop(row: InternalRow, schema: StructType): InternalRow = {
val propertiesIndex = schema.fieldIndex("properties")
val propertiesStruct = schema("properties").dataType.asInstanceOf[StructType]
diff --git a/spark/common/src/test/resources/datasource_stac/extended-item.json b/spark/common/src/test/resources/datasource_stac/extended-item.json
index b5f3a0a9df..14857da6b1 100644
--- a/spark/common/src/test/resources/datasource_stac/extended-item.json
+++ b/spark/common/src/test/resources/datasource_stac/extended-item.json
@@ -58,6 +58,7 @@
"cool_sensor_v2"
],
"gsd": 0.66,
+ "grid:code": "MSIN-2506",
"eo:cloud_cover": 1.2,
"eo:snow_cover": 0,
"statistics": {
@@ -207,4 +208,4 @@
"title": "Satellite Ephemeris Metadata"
}
}
-}
\ No newline at end of file
+}
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
index 5603b3e551..da9ff573e3 100644
--- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
@@ -19,7 +19,7 @@
package org.apache.spark.sql.sedona_sql.io.stac
import org.apache.sedona.sql.TestBaseScala
-import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
import org.apache.spark.sql.types._
class StacDataSourceTest extends TestBaseScala {
@@ -41,6 +41,14 @@ class StacDataSourceTest extends TestBaseScala {
assert(rowCount > 0)
}
+ it("basic df load from local file with extensions should work") {
+ val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL)
+ // Filter rows where grid:code equals "MSIN-2506"
+ val filteredDf = dfStac.filter(dfStac.col("grid:code") === "MSIN-2506")
+ val rowCount = filteredDf.count()
+ assert(rowCount > 0)
+ }
+
it("basic df load from remote service endpoints should work") {
STAC_COLLECTION_REMOTE.foreach { endpoint =>
val dfStac = sparkSession.read.format("stac").load(endpoint)
@@ -156,54 +164,84 @@ class StacDataSourceTest extends TestBaseScala {
}
def assertSchema(actualSchema: StructType): Unit = {
- val expectedSchema = StructType(
- Seq(
- StructField("stac_version", StringType, nullable = false),
- StructField(
- "stac_extensions",
- ArrayType(StringType, containsNull = true),
- nullable = true),
- StructField("type", StringType, nullable = false),
- StructField("id", StringType, nullable = false),
- StructField("bbox", ArrayType(DoubleType, containsNull = true),
nullable = true),
- StructField("geometry", new GeometryUDT(), nullable = true),
- StructField("title", StringType, nullable = true),
- StructField("description", StringType, nullable = true),
- StructField("datetime", TimestampType, nullable = true),
- StructField("start_datetime", TimestampType, nullable = true),
- StructField("end_datetime", TimestampType, nullable = true),
- StructField("created", TimestampType, nullable = true),
- StructField("updated", TimestampType, nullable = true),
- StructField("platform", StringType, nullable = true),
- StructField("instruments", ArrayType(StringType, containsNull = true),
nullable = true),
- StructField("constellation", StringType, nullable = true),
- StructField("mission", StringType, nullable = true),
- StructField("gsd", DoubleType, nullable = true),
- StructField("collection", StringType, nullable = true),
- StructField(
- "links",
- ArrayType(
- StructType(Seq(
+ // Base STAC fields that should always be present
+ val baseFields = Seq(
+ StructField("stac_version", StringType, nullable = false),
+ StructField("stac_extensions", ArrayType(StringType, containsNull =
true), nullable = true),
+ StructField("type", StringType, nullable = false),
+ StructField("id", StringType, nullable = false),
+ StructField("bbox", ArrayType(DoubleType, containsNull = true), nullable
= true),
+ StructField("geometry", new GeometryUDT(), nullable = true),
+ StructField("title", StringType, nullable = true),
+ StructField("description", StringType, nullable = true),
+ StructField("datetime", TimestampType, nullable = true),
+ StructField("start_datetime", TimestampType, nullable = true),
+ StructField("end_datetime", TimestampType, nullable = true),
+ StructField("created", TimestampType, nullable = true),
+ StructField("updated", TimestampType, nullable = true),
+ StructField("platform", StringType, nullable = true),
+ StructField("instruments", ArrayType(StringType, containsNull = true),
nullable = true),
+ StructField("constellation", StringType, nullable = true),
+ StructField("mission", StringType, nullable = true),
+ StructField("gsd", DoubleType, nullable = true),
+ StructField("collection", StringType, nullable = true),
+ StructField(
+ "links",
+ ArrayType(
+ StructType(
+ Seq(
StructField("rel", StringType, nullable = true),
StructField("href", StringType, nullable = true),
StructField("type", StringType, nullable = true),
StructField("title", StringType, nullable = true))),
- containsNull = true),
- nullable = true),
- StructField(
- "assets",
- MapType(
- StringType,
- StructType(Seq(
+ containsNull = true),
+ nullable = true),
+ StructField(
+ "assets",
+ MapType(
+ StringType,
+ StructType(
+ Seq(
StructField("href", StringType, nullable = true),
StructField("type", StringType, nullable = true),
StructField("title", StringType, nullable = true),
StructField("roles", ArrayType(StringType, containsNull = true),
nullable = true))),
- valueContainsNull = true),
- nullable = true)))
+ valueContainsNull = true),
+ nullable = true))
+
+ // Extension fields that may be present
+ val extensionFields = Seq(
+ // Grid extension fields
+ StructField("grid:code", StringType, nullable = true)
+ // Add other extension fields as needed
+ )
+
+ // Check that all base fields are present with correct types
+ baseFields.foreach { expectedField =>
+ val actualField = actualSchema.fields.find(_.name == expectedField.name)
+ assert(actualField.isDefined, s"Required field ${expectedField.name} not found in schema")
+ assert(
+ actualField.get.dataType == expectedField.dataType,
+ s"Field ${expectedField.name} has wrong type. Expected:
${expectedField.dataType}, Actual: ${actualField.get.dataType}")
+ assert(
+ actualField.get.nullable == expectedField.nullable,
+ s"Field ${expectedField.name} has wrong nullability. Expected:
${expectedField.nullable}, Actual: ${actualField.get.nullable}")
+ }
+
+ // Check extension fields if they are present
+ val actualFieldNames = actualSchema.fields.map(_.name).toSet
+ extensionFields.foreach { extensionField =>
+ if (actualFieldNames.contains(extensionField.name)) {
+ val actualField = actualSchema.fields.find(_.name == extensionField.name).get
+ assert(
+ actualField.dataType == extensionField.dataType,
+ s"Extension field ${extensionField.name} has wrong type. Expected:
${extensionField.dataType}, Actual: ${actualField.dataType}")
+ }
+ }
- assert(
- actualSchema == expectedSchema,
- s"Schema does not match. Expected: $expectedSchema, Actual:
$actualSchema")
+ // Check that there are no unexpected fields
+ val expectedFieldNames = (baseFields ++ extensionFields).map(_.name).toSet
+ val unexpectedFields = actualFieldNames.diff(expectedFieldNames)
+ assert(unexpectedFields.isEmpty, s"Schema contains unexpected fields: $unexpectedFields")
}
}