This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch prepare-1.7.2
in repository https://gitbox.apache.org/repos/asf/sedona.git

commit 173f28ce96ac7e9cdc687984d4604b40d92f6ee8
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Mar 27 21:56:44 2025 -0700

    [SEDONA-704] Add the grid extension to the stac reader (#1883)

    * [SEDONA-704] Add the grid extension to the stac reader

    * In Scala 2.13, the standard library collections were restructured, and
      scala.collection.Seq is now different from Seq.

    * Update spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala

    ---------

    Co-authored-by: Jia Yu <[email protected]>
---
 docs/tutorial/files/stac-sedona-spark.md           |   1 +
 .../sql/sedona_sql/io/stac/StacExtension.scala     |  57 ++++++++++
 .../spark/sql/sedona_sql/io/stac/StacUtils.scala   |  69 +++++++++++-
 .../resources/datasource_stac/extended-item.json   |   3 +-
 .../sedona_sql/io/stac/StacDataSourceTest.scala    | 120 ++++++++++++++-------
 5 files changed, 206 insertions(+), 44 deletions(-)

diff --git a/docs/tutorial/files/stac-sedona-spark.md b/docs/tutorial/files/stac-sedona-spark.md
index aff0e29f0f..864b6ef9d6 100644
--- a/docs/tutorial/files/stac-sedona-spark.md
+++ b/docs/tutorial/files/stac-sedona-spark.md
@@ -77,6 +77,7 @@ root
  |    |-- element: string (containsNull = true)
  |-- constellation: string (nullable = true)
  |-- mission: string (nullable = true)
+ |-- grid:code: string (nullable = true)
  |-- gsd: double (nullable = true)
  |-- collection: string (nullable = true)
  |-- links: array (nullable = true)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala
new file mode 100644
index 0000000000..c1f78268ac
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements. See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership. The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License. You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied. See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.stac
+
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+/**
+ * Defines a STAC extension with its schema
+ *
+ * @param name
+ *   The name of the STAC extension (e.g., "eo", "proj")
+ * @param schema
+ *   The schema for the extension properties; its fields are promoted to
+ *   top-level columns alongside the other item properties
+ */
+case class StacExtension(name: String, schema: StructType)
+
+object StacExtension {
+
+  /**
+   * Returns an array of STAC extension definitions, each containing:
+   *   - Extension name
+   *   - Extension schema whose fields are promoted to top-level properties
+   *
+   * @return
+   *   Array of StacExtension definitions
+   */
+  def getStacExtensionDefinitions(): Array[StacExtension] = {
+    Array(
+      // Grid extension - https://stac-extensions.github.io/grid/v1.1.0/schema.json
+      StacExtension(
+        name = "grid",
+        // Schema for the grid extension, add all required fields here
+        schema = StructType(Seq(StructField("grid:code", StringType, nullable = true))))
+
+      // Add other extensions here...
+    )
+  }
+}
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
index 95c86c24c0..cfb0d33abc 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasource.stac.TemporalFilter
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
 import org.locationtech.jts.geom.Envelope
 
 import java.time.LocalDateTime
@@ -129,14 +129,79 @@ object StacUtils {
     val stacVersion = collection.get("stac_version").asText()
 
     // Return the corresponding schema based on the stac_version
-    stacVersion match {
+    val coreSchema = stacVersion match {
       case "1.0.0" => StacTable.SCHEMA_V1_0_0
       case version if version.matches("1\\.[1-9]\\d*\\.\\d*") => StacTable.SCHEMA_V1_1_0
       // Add more cases here for other versions if needed
       case _ => throw new IllegalArgumentException(s"Unsupported STAC version: $stacVersion")
     }
+    val extensions = StacExtension.getStacExtensionDefinitions()
+    val schemaWithExtensions = addExtensionFieldsToSchema(coreSchema, extensions)
+    schemaWithExtensions
   }
 
+  /**
+   * Adds STAC extension fields to the properties field in the schema
+   *
+   * @param schema
+   *   The base STAC schema to enhance
+   * @param extensions
+   *   Array of STAC extension definitions
+   * @return
+   *   Enhanced schema with extension fields added to the properties struct
+   */
+  def addExtensionFieldsToSchema(
+      schema: StructType,
+      extensions: Array[StacExtension]): StructType = {
+    // Find the properties field in the schema
+    val propertiesFieldOpt = schema.fields.find(_.name == "properties")
+
+    if (propertiesFieldOpt.isEmpty) {
+      // If there's no properties field, return the original schema
+      return schema
+    }
+
+    // Get the properties field and its struct type
+    val propertiesField = propertiesFieldOpt.get
+    val propertiesStruct = propertiesField.dataType.asInstanceOf[StructType]
+
+    // Create extension fields with metadata indicating their source
+    val extensionFields = extensions.flatMap { extension =>
+      extension.schema.fields.map { field =>
+        StructField(
+          field.name,
+          field.dataType,
+          field.nullable,
+          new MetadataBuilder()
+            .withMetadata(field.metadata)
+            .putString("stac_extension", extension.name)
+            .build())
+      }
+    }
+
+    // Create a new properties struct that includes the extension fields
+    val updatedPropertiesStruct = StructType(propertiesStruct.fields ++ extensionFields)
+
+    // Create a new properties field with the updated struct
+    val updatedPropertiesField = StructField(
+      propertiesField.name,
+      updatedPropertiesStruct,
+      propertiesField.nullable,
+      propertiesField.metadata)
+
+    // Replace the properties field in the schema
+    val updatedFields = schema.fields.map {
+      case field if field.name == "properties" => updatedPropertiesField
+      case field => field
+    }
+
+    // Return the schema with the updated properties field
+    StructType(updatedFields)
+  }
+
+  /**
+   * Promote the properties field to the top level of the row.
+   */
   def promotePropertiesToTop(row: InternalRow, schema: StructType): InternalRow = {
     val propertiesIndex = schema.fieldIndex("properties")
     val propertiesStruct = schema("properties").dataType.asInstanceOf[StructType]
diff --git a/spark/common/src/test/resources/datasource_stac/extended-item.json b/spark/common/src/test/resources/datasource_stac/extended-item.json
index b5f3a0a9df..14857da6b1 100644
--- a/spark/common/src/test/resources/datasource_stac/extended-item.json
+++ b/spark/common/src/test/resources/datasource_stac/extended-item.json
@@ -58,6 +58,7 @@
       "cool_sensor_v2"
     ],
     "gsd": 0.66,
+    "grid:code": "MSIN-2506",
     "eo:cloud_cover": 1.2,
     "eo:snow_cover": 0,
     "statistics": {
@@ -207,4 +208,4 @@
       "title": "Satellite Ephemeris Metadata"
     }
   }
-}
\ No newline at end of file
+}
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
index 5603b3e551..da9ff573e3 100644
--- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
@@ -19,7 +19,7 @@
 package org.apache.spark.sql.sedona_sql.io.stac
 
 import org.apache.sedona.sql.TestBaseScala
-import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types._
 
 class StacDataSourceTest extends TestBaseScala {
@@ -41,6 +41,14 @@ class StacDataSourceTest extends TestBaseScala {
     assert(rowCount > 0)
   }
 
+  it("basic df load from local file with extensions should work") {
+    val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL)
+    // Filter rows where grid:code equals "MSIN-2506"
+    val filteredDf = dfStac.filter(dfStac.col("grid:code") === "MSIN-2506")
+    val rowCount = filteredDf.count()
+    assert(rowCount > 0)
+  }
+
   it("basic df load from remote service endpoints should work") {
     STAC_COLLECTION_REMOTE.foreach { endpoint =>
       val dfStac = sparkSession.read.format("stac").load(endpoint)
@@ -156,54 +164,84 @@ class StacDataSourceTest extends TestBaseScala {
   }
 
   def assertSchema(actualSchema: StructType): Unit = {
-    val expectedSchema = StructType(
-      Seq(
-        StructField("stac_version", StringType, nullable = false),
-        StructField(
-          "stac_extensions",
-          ArrayType(StringType, containsNull = true),
-          nullable = true),
-        StructField("type", StringType, nullable = false),
StructField("id", StringType, nullable = false), - StructField("bbox", ArrayType(DoubleType, containsNull = true), nullable = true), - StructField("geometry", new GeometryUDT(), nullable = true), - StructField("title", StringType, nullable = true), - StructField("description", StringType, nullable = true), - StructField("datetime", TimestampType, nullable = true), - StructField("start_datetime", TimestampType, nullable = true), - StructField("end_datetime", TimestampType, nullable = true), - StructField("created", TimestampType, nullable = true), - StructField("updated", TimestampType, nullable = true), - StructField("platform", StringType, nullable = true), - StructField("instruments", ArrayType(StringType, containsNull = true), nullable = true), - StructField("constellation", StringType, nullable = true), - StructField("mission", StringType, nullable = true), - StructField("gsd", DoubleType, nullable = true), - StructField("collection", StringType, nullable = true), - StructField( - "links", - ArrayType( - StructType(Seq( + // Base STAC fields that should always be present + val baseFields = Seq( + StructField("stac_version", StringType, nullable = false), + StructField("stac_extensions", ArrayType(StringType, containsNull = true), nullable = true), + StructField("type", StringType, nullable = false), + StructField("id", StringType, nullable = false), + StructField("bbox", ArrayType(DoubleType, containsNull = true), nullable = true), + StructField("geometry", new GeometryUDT(), nullable = true), + StructField("title", StringType, nullable = true), + StructField("description", StringType, nullable = true), + StructField("datetime", TimestampType, nullable = true), + StructField("start_datetime", TimestampType, nullable = true), + StructField("end_datetime", TimestampType, nullable = true), + StructField("created", TimestampType, nullable = true), + StructField("updated", TimestampType, nullable = true), + StructField("platform", StringType, nullable = true), + StructField("instruments", ArrayType(StringType, containsNull = true), nullable = true), + StructField("constellation", StringType, nullable = true), + StructField("mission", StringType, nullable = true), + StructField("gsd", DoubleType, nullable = true), + StructField("collection", StringType, nullable = true), + StructField( + "links", + ArrayType( + StructType( + Seq( StructField("rel", StringType, nullable = true), StructField("href", StringType, nullable = true), StructField("type", StringType, nullable = true), StructField("title", StringType, nullable = true))), - containsNull = true), - nullable = true), - StructField( - "assets", - MapType( - StringType, - StructType(Seq( + containsNull = true), + nullable = true), + StructField( + "assets", + MapType( + StringType, + StructType( + Seq( StructField("href", StringType, nullable = true), StructField("type", StringType, nullable = true), StructField("title", StringType, nullable = true), StructField("roles", ArrayType(StringType, containsNull = true), nullable = true))), - valueContainsNull = true), - nullable = true))) + valueContainsNull = true), + nullable = true)) + + // Extension fields that may be present + val extensionFields = Seq( + // Grid extension fields + StructField("grid:code", StringType, nullable = true) + // Add other extension fields as needed + ) + + // Check that all base fields are present with correct types + baseFields.foreach { expectedField => + val actualField = actualSchema.fields.find(_.name == expectedField.name) + assert(actualField.isDefined, 
s"Required field ${expectedField.name} not found in schema") + assert( + actualField.get.dataType == expectedField.dataType, + s"Field ${expectedField.name} has wrong type. Expected: ${expectedField.dataType}, Actual: ${actualField.get.dataType}") + assert( + actualField.get.nullable == expectedField.nullable, + s"Field ${expectedField.name} has wrong nullability. Expected: ${expectedField.nullable}, Actual: ${actualField.get.nullable}") + } + + // Check extension fields if they are present + val actualFieldNames = actualSchema.fields.map(_.name).toSet + extensionFields.foreach { extensionField => + if (actualFieldNames.contains(extensionField.name)) { + val actualField = actualSchema.fields.find(_.name == extensionField.name).get + assert( + actualField.dataType == extensionField.dataType, + s"Extension field ${extensionField.name} has wrong type. Expected: ${extensionField.dataType}, Actual: ${actualField.dataType}") + } + } - assert( - actualSchema == expectedSchema, - s"Schema does not match. Expected: $expectedSchema, Actual: $actualSchema") + // Check that there are no unexpected fields + val expectedFieldNames = (baseFields ++ extensionFields).map(_.name).toSet + val unexpectedFields = actualFieldNames.diff(expectedFieldNames) + assert(unexpectedFields.isEmpty, s"Schema contains unexpected fields: $unexpectedFields") } }

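Supporting another extension later should only require a new entry in getStacExtensionDefinitions(). A sketch using the STAC view extension as an example; the extension choice and the view:off_nadir field are illustrative, not part of this commit:

    import org.apache.spark.sql.types.{DoubleType, StructField, StructType}

    // Hypothetical extra entry for the Array in getStacExtensionDefinitions():
    // View extension - https://stac-extensions.github.io/view/v1.0.0/schema.json
    StacExtension(
      name = "view",
      schema = StructType(Seq(StructField("view:off_nadir", DoubleType, nullable = true))))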