This is an automated email from the ASF dual-hosted git repository.

jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git


The following commit(s) were added to refs/heads/master by this push:
     new 718fa5b57a [SEDONA-704] Add the grid extension to the stac reader (#1883)
718fa5b57a is described below

commit 718fa5b57a0a1c10eb06f2a0358ea3809ffec373
Author: Feng Zhang <[email protected]>
AuthorDate: Thu Mar 27 21:56:44 2025 -0700

    [SEDONA-704] Add the grid extension to the stac reader (#1883)
    
    * [SEDONA-704] Add the grid extension to the stac reader
    
    * In Scala 2.13, the standard library collections were restructured, and scala.collection.Seq is now different from Seq; see the sketch after the diffstat below.
    
    * Update spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
    
    ---------
    
    Co-authored-by: Jia Yu <[email protected]>
---
 docs/tutorial/files/stac-sedona-spark.md           |   1 +
 .../sql/sedona_sql/io/stac/StacExtension.scala     |  57 ++++++++++
 .../spark/sql/sedona_sql/io/stac/StacUtils.scala   |  69 +++++++++++-
 .../resources/datasource_stac/extended-item.json   |   3 +-
 .../sedona_sql/io/stac/StacDataSourceTest.scala    | 120 ++++++++++++++-------
 5 files changed, 206 insertions(+), 44 deletions(-)
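
A minimal standalone sketch of the Scala 2.13 change called out in the commit message above (not part of this commit): scala.Seq is now an alias for scala.collection.immutable.Seq, while scala.collection.Seq still admits mutable implementations.

    import scala.collection.mutable.ArrayBuffer

    // Compiles on both 2.12 and 2.13: scala.collection.Seq covers mutable buffers
    val general: scala.collection.Seq[Int] = ArrayBuffer(1, 2, 3)

    // Compiles on 2.12 but not on 2.13, where the default Seq is immutable.Seq:
    // val strict: Seq[Int] = ArrayBuffer(1, 2, 3)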

diff --git a/docs/tutorial/files/stac-sedona-spark.md b/docs/tutorial/files/stac-sedona-spark.md
index aff0e29f0f..864b6ef9d6 100644
--- a/docs/tutorial/files/stac-sedona-spark.md
+++ b/docs/tutorial/files/stac-sedona-spark.md
@@ -77,6 +77,7 @@ root
  |    |-- element: string (containsNull = true)
  |-- constellation: string (nullable = true)
  |-- mission: string (nullable = true)
+ |-- grid:code: string (nullable = true)
  |-- gsd: double (nullable = true)
  |-- collection: string (nullable = true)
  |-- links: array (nullable = true)
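
The tree above is what printSchema reports once grid:code is promoted; a quick way to reproduce it (path is a placeholder for a STAC catalog location):

    val dfStac = sparkSession.read.format("stac").load(path)
    dfStac.printSchema()
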
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala
new file mode 100644
index 0000000000..c1f78268ac
--- /dev/null
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacExtension.scala
@@ -0,0 +1,57 @@
+/*
+ * Licensed to the Apache Software Foundation (ASF) under one
+ * or more contributor license agreements.  See the NOTICE file
+ * distributed with this work for additional information
+ * regarding copyright ownership.  The ASF licenses this file
+ * to you under the Apache License, Version 2.0 (the
+ * "License"); you may not use this file except in compliance
+ * with the License.  You may obtain a copy of the License at
+ *
+ *   http://www.apache.org/licenses/LICENSE-2.0
+ *
+ * Unless required by applicable law or agreed to in writing,
+ * software distributed under the License is distributed on an
+ * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ * KIND, either express or implied.  See the License for the
+ * specific language governing permissions and limitations
+ * under the License.
+ */
+package org.apache.spark.sql.sedona_sql.io.stac
+
+import org.apache.spark.sql.types.{StringType, StructField, StructType}
+
+/**
+ * Defines a STAC extension with its schema
+ *
+ * @param name
+ *   The name of the STAC extension (e.g., "eo", "proj")
+ * @param schema
+ *   The schema for the extension properties
+ */
+case class StacExtension(name: String, schema: StructType)
+
+object StacExtension {
+
+  /**
+   * Returns an array of STAC extension definitions, each containing:
+   *   - Extension name
+   *   - Extension schema
+   *
+   * @return
+   *   Array of StacExtension definitions
+   */
+  def getStacExtensionDefinitions(): Array[StacExtension] = {
+    Array(
+      // Grid extension - https://stac-extensions.github.io/grid/v1.1.0/schema.json
+      StacExtension(
+        name = "grid",
+        // Schema for the grid extension, add all required fields here
+        schema = StructType(Seq(StructField("grid:code", StringType, nullable = true))))
+
+      // Add other extensions here...
+    )
+  }
+}
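
As the trailing comment suggests, further extensions would be registered by appending entries to the array above. A hypothetical sketch, assuming DoubleType is added to the types import (the eo property names are taken from the test fixture later in this diff, not from this commit):

    // Hypothetical entry promoting two EO extension properties
    StacExtension(
      name = "eo",
      schema = StructType(
        Seq(
          StructField("eo:cloud_cover", DoubleType, nullable = true),
          StructField("eo:snow_cover", DoubleType, nullable = true))))
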
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
index 95c86c24c0..cfb0d33abc 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
@@ -24,7 +24,7 @@ import org.apache.spark.sql.SparkSession
 import org.apache.spark.sql.catalyst.InternalRow
 import org.apache.spark.sql.execution.datasource.stac.TemporalFilter
 import org.apache.spark.sql.execution.datasources.parquet.GeoParquetSpatialFilter
-import org.apache.spark.sql.types.{StructField, StructType}
+import org.apache.spark.sql.types.{MetadataBuilder, StructField, StructType}
 import org.locationtech.jts.geom.Envelope
 
 import java.time.LocalDateTime
@@ -129,14 +129,79 @@ object StacUtils {
     val stacVersion = collection.get("stac_version").asText()
 
     // Return the corresponding schema based on the stac_version
-    stacVersion match {
+    val coreSchema = stacVersion match {
       case "1.0.0" => StacTable.SCHEMA_V1_0_0
       case version if version.matches("1\\.[1-9]\\d*\\.\\d*") => StacTable.SCHEMA_V1_1_0
       // Add more cases here for other versions if needed
       case _ => throw new IllegalArgumentException(s"Unsupported STAC version: $stacVersion")
     }
+    val extensions = StacExtension.getStacExtensionDefinitions()
+    val schemaWithExtensions = addExtensionFieldsToSchema(coreSchema, extensions)
+    schemaWithExtensions
   }
 
+  /**
+   * Adds STAC extension fields to the properties field in the schema
+   *
+   * @param schema
+   *   The base STAC schema to enhance
+   * @param extensions
+   *   Array of STAC extension definitions
+   * @return
+   *   Enhanced schema with extension fields added to the properties struct
+   */
+  def addExtensionFieldsToSchema(
+      schema: StructType,
+      extensions: Array[StacExtension]): StructType = {
+    // Find the properties field in the schema
+    val propertiesFieldOpt = schema.fields.find(_.name == "properties")
+
+    if (propertiesFieldOpt.isEmpty) {
+      // If there's no properties field, return the original schema
+      return schema
+    }
+
+    // Get the properties field and its struct type
+    val propertiesField = propertiesFieldOpt.get
+    val propertiesStruct = propertiesField.dataType.asInstanceOf[StructType]
+
+    // Create extension fields with metadata indicating their source
+    val extensionFields = extensions.flatMap { extension =>
+      extension.schema.fields.map { field =>
+        StructField(
+          field.name,
+          field.dataType,
+          field.nullable,
+          new MetadataBuilder()
+            .withMetadata(field.metadata)
+            .putString("stac_extension", extension.name)
+            .build())
+      }
+    }
+
+    // Create a new properties struct that includes the extension fields
+    val updatedPropertiesStruct = StructType(propertiesStruct.fields ++ extensionFields)
+
+    // Create a new properties field with the updated struct
+    val updatedPropertiesField = StructField(
+      propertiesField.name,
+      updatedPropertiesStruct,
+      propertiesField.nullable,
+      propertiesField.metadata)
+
+    // Replace the properties field in the schema
+    val updatedFields = schema.fields.map {
+      case field if field.name == "properties" => updatedPropertiesField
+      case field => field
+    }
+
+    // Return the schema with the updated properties field
+    StructType(updatedFields)
+  }
+
+  /**
+   * Promote the properties field to the top level of the row.
+   */
   def promotePropertiesToTop(row: InternalRow, schema: StructType): InternalRow = {
     val propertiesIndex = schema.fieldIndex("properties")
     val propertiesStruct = schema("properties").dataType.asInstanceOf[StructType]
diff --git a/spark/common/src/test/resources/datasource_stac/extended-item.json b/spark/common/src/test/resources/datasource_stac/extended-item.json
index b5f3a0a9df..14857da6b1 100644
--- a/spark/common/src/test/resources/datasource_stac/extended-item.json
+++ b/spark/common/src/test/resources/datasource_stac/extended-item.json
@@ -58,6 +58,7 @@
       "cool_sensor_v2"
     ],
     "gsd": 0.66,
+    "grid:code": "MSIN-2506",
     "eo:cloud_cover": 1.2,
     "eo:snow_cover": 0,
     "statistics": {
@@ -207,4 +208,4 @@
       "title": "Satellite Ephemeris Metadata"
     }
   }
-}
\ No newline at end of file
+}
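
With the fixture above, the promoted field can also be queried through Spark SQL; the backticks are required because the column name contains a colon (path is a placeholder):

    val dfStac = sparkSession.read.format("stac").load(path)
    dfStac.createOrReplaceTempView("items")
    sparkSession
      .sql("SELECT id, `grid:code` FROM items WHERE `grid:code` = 'MSIN-2506'")
      .show()
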
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
index 5603b3e551..da9ff573e3 100644
--- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
@@ -19,7 +19,7 @@
 package org.apache.spark.sql.sedona_sql.io.stac
 
 import org.apache.sedona.sql.TestBaseScala
-import org.apache.spark.sql.sedona_sql.UDT.{GeometryUDT, RasterUDT}
+import org.apache.spark.sql.sedona_sql.UDT.GeometryUDT
 import org.apache.spark.sql.types._
 
 class StacDataSourceTest extends TestBaseScala {
@@ -41,6 +41,14 @@ class StacDataSourceTest extends TestBaseScala {
     assert(rowCount > 0)
   }
 
+  it("basic df load from local file with extensions should work") {
+    val dfStac = sparkSession.read.format("stac").load(STAC_COLLECTION_LOCAL)
+    // Filter rows where grid:code equals "MSIN-2506"
+    val filteredDf = dfStac.filter(dfStac.col("grid:code") === "MSIN-2506")
+    val rowCount = filteredDf.count()
+    assert(rowCount > 0)
+  }
+
   it("basic df load from remote service endpoints should work") {
     STAC_COLLECTION_REMOTE.foreach { endpoint =>
       val dfStac = sparkSession.read.format("stac").load(endpoint)
@@ -156,54 +164,84 @@ class StacDataSourceTest extends TestBaseScala {
   }
 
   def assertSchema(actualSchema: StructType): Unit = {
-    val expectedSchema = StructType(
-      Seq(
-        StructField("stac_version", StringType, nullable = false),
-        StructField(
-          "stac_extensions",
-          ArrayType(StringType, containsNull = true),
-          nullable = true),
-        StructField("type", StringType, nullable = false),
-        StructField("id", StringType, nullable = false),
-        StructField("bbox", ArrayType(DoubleType, containsNull = true), 
nullable = true),
-        StructField("geometry", new GeometryUDT(), nullable = true),
-        StructField("title", StringType, nullable = true),
-        StructField("description", StringType, nullable = true),
-        StructField("datetime", TimestampType, nullable = true),
-        StructField("start_datetime", TimestampType, nullable = true),
-        StructField("end_datetime", TimestampType, nullable = true),
-        StructField("created", TimestampType, nullable = true),
-        StructField("updated", TimestampType, nullable = true),
-        StructField("platform", StringType, nullable = true),
-        StructField("instruments", ArrayType(StringType, containsNull = true), 
nullable = true),
-        StructField("constellation", StringType, nullable = true),
-        StructField("mission", StringType, nullable = true),
-        StructField("gsd", DoubleType, nullable = true),
-        StructField("collection", StringType, nullable = true),
-        StructField(
-          "links",
-          ArrayType(
-            StructType(Seq(
+    // Base STAC fields that should always be present
+    val baseFields = Seq(
+      StructField("stac_version", StringType, nullable = false),
+      StructField("stac_extensions", ArrayType(StringType, containsNull = 
true), nullable = true),
+      StructField("type", StringType, nullable = false),
+      StructField("id", StringType, nullable = false),
+      StructField("bbox", ArrayType(DoubleType, containsNull = true), nullable 
= true),
+      StructField("geometry", new GeometryUDT(), nullable = true),
+      StructField("title", StringType, nullable = true),
+      StructField("description", StringType, nullable = true),
+      StructField("datetime", TimestampType, nullable = true),
+      StructField("start_datetime", TimestampType, nullable = true),
+      StructField("end_datetime", TimestampType, nullable = true),
+      StructField("created", TimestampType, nullable = true),
+      StructField("updated", TimestampType, nullable = true),
+      StructField("platform", StringType, nullable = true),
+      StructField("instruments", ArrayType(StringType, containsNull = true), 
nullable = true),
+      StructField("constellation", StringType, nullable = true),
+      StructField("mission", StringType, nullable = true),
+      StructField("gsd", DoubleType, nullable = true),
+      StructField("collection", StringType, nullable = true),
+      StructField(
+        "links",
+        ArrayType(
+          StructType(
+            Seq(
               StructField("rel", StringType, nullable = true),
               StructField("href", StringType, nullable = true),
               StructField("type", StringType, nullable = true),
               StructField("title", StringType, nullable = true))),
-            containsNull = true),
-          nullable = true),
-        StructField(
-          "assets",
-          MapType(
-            StringType,
-            StructType(Seq(
+          containsNull = true),
+        nullable = true),
+      StructField(
+        "assets",
+        MapType(
+          StringType,
+          StructType(
+            Seq(
               StructField("href", StringType, nullable = true),
               StructField("type", StringType, nullable = true),
               StructField("title", StringType, nullable = true),
               StructField("roles", ArrayType(StringType, containsNull = true), 
nullable = true))),
-            valueContainsNull = true),
-          nullable = true)))
+          valueContainsNull = true),
+        nullable = true))
+
+    // Extension fields that may be present
+    val extensionFields = Seq(
+      // Grid extension fields
+      StructField("grid:code", StringType, nullable = true)
+      // Add other extension fields as needed
+    )
+
+    // Check that all base fields are present with correct types
+    baseFields.foreach { expectedField =>
+      val actualField = actualSchema.fields.find(_.name == expectedField.name)
+      assert(actualField.isDefined, s"Required field ${expectedField.name} not found in schema")
+      assert(
+        actualField.get.dataType == expectedField.dataType,
+        s"Field ${expectedField.name} has wrong type. Expected: 
${expectedField.dataType}, Actual: ${actualField.get.dataType}")
+      assert(
+        actualField.get.nullable == expectedField.nullable,
+        s"Field ${expectedField.name} has wrong nullability. Expected: 
${expectedField.nullable}, Actual: ${actualField.get.nullable}")
+    }
+
+    // Check extension fields if they are present
+    val actualFieldNames = actualSchema.fields.map(_.name).toSet
+    extensionFields.foreach { extensionField =>
+      if (actualFieldNames.contains(extensionField.name)) {
+        val actualField = actualSchema.fields.find(_.name == extensionField.name).get
+        assert(
+          actualField.dataType == extensionField.dataType,
+          s"Extension field ${extensionField.name} has wrong type. Expected: 
${extensionField.dataType}, Actual: ${actualField.dataType}")
+      }
+    }
 
-    assert(
-      actualSchema == expectedSchema,
-      s"Schema does not match. Expected: $expectedSchema, Actual: 
$actualSchema")
+    // Check that there are no unexpected fields
+    val expectedFieldNames = (baseFields ++ extensionFields).map(_.name).toSet
+    val unexpectedFields = actualFieldNames.diff(expectedFieldNames)
+    assert(unexpectedFields.isEmpty, s"Schema contains unexpected fields: $unexpectedFields")
   }
 }
