This is an automated email from the ASF dual-hosted git repository.
jiayu pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/sedona.git
The following commit(s) were added to refs/heads/master by this push:
new 47ed183357 [GH-2522] Implemented STAC authentication for both python and scala APIs (#2523)
47ed183357 is described below
commit 47ed183357adc13148e6f66be334f91cffa922fd
Author: Feng Zhang <[email protected]>
AuthorDate: Tue Nov 25 12:30:00 2025 -0800
[GH-2522] Implemented STAC authentication for both python and scala APIs (#2523)
---
docs/tutorial/files/stac-sedona-spark.md | 150 +++++++++++++++-
python/sedona/spark/stac/client.py | 87 +++++++++-
python/sedona/spark/stac/collection_client.py | 39 ++++-
python/tests/stac/test_auth.py | 179 +++++++++++++++++++
python/tests/stac/test_auth_integration.py | 193 +++++++++++++++++++++
.../spark/sql/sedona_sql/io/stac/StacBatch.scala | 7 +-
.../sedona_sql/io/stac/StacPartitionReader.scala | 35 +++-
.../spark/sql/sedona_sql/io/stac/StacUtils.scala | 85 ++++++++-
.../sedona_sql/io/stac/StacDataSourceTest.scala | 135 ++++++++++++++
.../sql/sedona_sql/io/stac/StacUtilsTest.scala | 65 +++++++
10 files changed, 954 insertions(+), 21 deletions(-)
diff --git a/docs/tutorial/files/stac-sedona-spark.md b/docs/tutorial/files/stac-sedona-spark.md
index 78b8616f83..d7154db17e 100644
--- a/docs/tutorial/files/stac-sedona-spark.md
+++ b/docs/tutorial/files/stac-sedona-spark.md
@@ -169,6 +169,8 @@ Below are reader options that can be set to control the behavior of the STAC rea
- **itemsLimitPerRequest**: This option specifies the maximum number of items to be requested in a single API call. It helps in controlling the size of each request. The default value is set to 10.
+- **headers**: This option specifies HTTP headers to include in STAC API requests. It should be a JSON-encoded string containing a dictionary of header key-value pairs. This is useful for authentication and custom headers. Example: `{"Authorization": "Basic <base64_credentials>"}`
+
These configurations can be combined into a single `Map[String, String]` and passed to the STAC reader as shown below:
```scala
@@ -269,14 +271,129 @@ client.get_collection("aster-l1t").save_to_geoparquet(
These examples demonstrate how to use the Client class to search for items in a STAC collection with various filters and return the results as either an iterator of PyStacItem objects or a Spark DataFrame.
+### Authentication
+
+Many STAC services require authentication to access their data. The STAC client supports multiple authentication methods, including HTTP Basic Authentication, Bearer Token Authentication, and custom headers.
+
+#### Basic Authentication
+
+Basic authentication is commonly used with API keys or username/password combinations. Many services (like Planet Labs) use an API key as the username with an empty password.
+
+```python
+from sedona.spark.stac import Client
+
+# Example 1: Using an API key (common pattern)
+client = Client.open("https://api.example.com/stac/v1")
+client.with_basic_auth("your_api_key_here", "")
+
+# Search for items with authentication
+df = client.search(collection_id="example-collection", max_items=10)
+df.show()
+
+# Example 2: Using username and password
+client = Client.open("https://api.example.com/stac/v1")
+client.with_basic_auth("username", "password")
+
+df = client.search(collection_id="example-collection", max_items=10)
+df.show()
+
+# Example 3: Method chaining
+df = (
+ Client.open("https://api.example.com/stac/v1")
+ .with_basic_auth("your_api_key", "")
+ .search(collection_id="example-collection", max_items=10)
+)
+df.show()
+```
+
+#### Bearer Token Authentication
+
+Bearer token authentication is used with OAuth2 tokens and JWT tokens. Note that some services may only support specific authentication methods.
+
+```python
+from sedona.spark.stac import Client
+
+# Using a bearer token
+client = Client.open("https://api.example.com/stac/v1")
+client.with_bearer_token("your_access_token_here")
+
+df = client.search(collection_id="example-collection", max_items=10)
+df.show()
+
+# Method chaining
+df = (
+ Client.open("https://api.example.com/stac/v1")
+ .with_bearer_token("your_token")
+ .search(collection_id="example-collection", max_items=10)
+)
+df.show()
+```
+
+#### Custom Headers
+
+You can also pass custom headers directly when creating the client, which is useful for services with non-standard authentication requirements.
+
+```python
+from sedona.spark.stac import Client
+
+# Using custom headers
+headers = {"Authorization": "Bearer your_token_here", "X-Custom-Header": "custom_value"}
+client = Client.open("https://api.example.com/stac/v1", headers=headers)
+
+df = client.search(collection_id="example-collection", max_items=10)
+df.show()
+```
+
+#### Authentication with Scala DataSource
+
+When using the STAC data source directly in Scala or through Spark SQL, you can pass authentication headers as a JSON-encoded option (a Spark SQL sketch follows the Scala example below):
+
+```python
+import json
+from pyspark.sql import SparkSession
+
+# Prepare authentication headers
+headers = {"Authorization": "Basic <base64_encoded_credentials>"}
+headers_json = json.dumps(headers)
+
+# Load STAC data with authentication
+df = (
+ spark.read.format("stac")
+ .option("headers", headers_json)
+ .load("https://api.example.com/stac/v1/collections/example-collection")
+)
+
+df.show()
+```
+
+```scala
+// Scala example
+val headersJson = """{"Authorization":"Basic <base64_encoded_credentials>"}"""
+
+val df = sparkSession.read
+ .format("stac")
+ .option("headers", headersJson)
+ .load("https://api.example.com/stac/v1/collections/example-collection")
+
+df.show()
+```
+
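+The paragraph above mentions Spark SQL; a minimal sketch of that route follows, assuming the reader accepts the standard `path` option (the URL, view name, and credentials are placeholders, not part of this commit):
+
+```python
+import base64
+
+# Hypothetical API key encoded for Basic auth
+credentials = base64.b64encode(b"your_api_key:").decode()
+
+# Register the STAC source as a temporary view, then query it with SQL
+spark.sql(f"""
+    CREATE TEMPORARY VIEW stac_items
+    USING stac
+    OPTIONS (
+        path 'https://api.example.com/stac/v1/collections/example-collection',
+        headers '{{"Authorization": "Basic {credentials}"}}'
+    )
+""")
+
+spark.sql("SELECT * FROM stac_items LIMIT 10").show()
+```
+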
+#### Important Notes
+
+- **Authentication methods are mutually exclusive**: Setting a new authentication method will overwrite any previously set Authorization header, but other custom headers remain unchanged (see the sketch after this list).
+- **Headers are propagated**: Headers set on the Client are automatically passed to all collection and item requests.
+- **Service-specific requirements**: Different STAC services may require different authentication methods. For example, Planet Labs requires Basic Authentication rather than Bearer tokens for collection access.
+- **Backward compatibility**: All authentication parameters are optional. Existing code that accesses public STAC services without authentication will continue to work unchanged.
+
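+The first two notes can be illustrated with a short sketch (the URL and credentials are placeholders):
+
+```python
+from sedona.spark.stac import Client
+
+client = Client.open(
+    "https://api.example.com/stac/v1",
+    headers={"X-API-Key": "key123"},  # custom header, preserved throughout
+)
+
+client.with_bearer_token("first_token")
+# client.headers == {"X-API-Key": "key123", "Authorization": "Bearer first_token"}
+
+client.with_basic_auth("user", "pass")
+# Authorization is overwritten; X-API-Key is untouched:
+# client.headers == {"X-API-Key": "key123", "Authorization": "Basic dXNlcjpwYXNz"}
+
+# Both headers travel with every collection and item request from this client
+collection = client.get_collection("example-collection")
+```
+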
### Methods
-**`open(url: str) -> Client`**
+**`open(url: str, headers: Optional[dict] = None) -> Client`**
Opens a connection to the specified STAC API URL.
Parameters:
* `url` (*str*): The URL of the STAC API to connect to. Example: `"https://planetarycomputer.microsoft.com/api/stac/v1"`
+* `headers` (*Optional[dict]*): Optional dictionary of HTTP headers for authentication or custom headers. Example: `{"Authorization": "Bearer token123"}`
Returns:
@@ -284,6 +401,37 @@ Returns:
---
+**`with_basic_auth(username: str, password: str) -> Client`**
+Adds HTTP Basic Authentication to the client.
+
+This method encodes the username and password using Base64 and adds the appropriate Authorization header for HTTP Basic Authentication.
+
+Parameters:
+
+* `username` (*str*): The username for authentication. For API keys, this is typically the API key itself. Example: `"your_api_key"`
+* `password` (*str*): The password for authentication. For API keys, this is often left empty. Example: `""`
+
+Returns:
+
+* `Client`: Returns self for method chaining.
+
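+For example, the API-key pattern (key as username, empty password) produces the following header value; the key is a placeholder:
+
+```python
+import base64
+
+# Equivalent to what client.with_basic_auth("your_api_key", "") sets
+encoded = base64.b64encode(b"your_api_key:").decode()
+print(f"Basic {encoded}")  # Basic eW91cl9hcGlfa2V5Og==
+```
+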
+---
+
+**`with_bearer_token(token: str) -> Client`**
+Adds Bearer Token Authentication to the client.
+
+This method adds the appropriate Authorization header for Bearer Token authentication, commonly used with OAuth2 and API tokens.
+
+Parameters:
+
+* `token` (*str*): The bearer token for authentication. Example: `"your_access_token_here"`
+
+Returns:
+
+* `Client`: Returns self for method chaining.
+
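+For example (placeholder URL and token):
+
+```python
+from sedona.spark.stac import Client
+
+client = Client.open("https://api.example.com/stac/v1")
+client.with_bearer_token("abc123")
+# client.headers == {"Authorization": "Bearer abc123"}
+```
+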
+---
+
**`get_collection(collection_id: str) -> CollectionClient`**
Retrieves a collection client for the specified collection ID.
diff --git a/python/sedona/spark/stac/client.py b/python/sedona/spark/stac/client.py
index 303994bfa4..d8b1a1a623 100644
--- a/python/sedona/spark/stac/client.py
+++ b/python/sedona/spark/stac/client.py
@@ -25,20 +25,95 @@ from pyspark.sql import DataFrame
class Client:
- def __init__(self, url: str):
+ def __init__(self, url: str, headers: Optional[dict] = None):
+ """
+ Initializes a STAC client with optional authentication headers.
+
+ :param url: The URL of the STAC API to connect to.
+        :param headers: Optional dictionary of HTTP headers to include in requests.
+ Can be used for authentication or custom headers.
+ """
self.url = url
+ self.headers = headers if headers is not None else {}
@classmethod
- def open(cls, url: str):
+ def open(cls, url: str, headers: Optional[dict] = None):
"""
Opens a connection to the specified STAC API URL.
-        This class method creates an instance of the Client class with the given URL.
+        This class method creates an instance of the Client class with the given URL
+ and optional authentication headers.
        :param url: The URL of the STAC API to connect to. Example: "https://planetarycomputer.microsoft.com/api/stac/v1"
+ :param headers: Optional dictionary of HTTP headers for authentication.
+ Example: {"Authorization": "Bearer token123"}
        :return: An instance of the Client class connected to the specified URL.
+
+ Example usage:
+ # Without authentication
+            client = Client.open("https://planetarycomputer.microsoft.com/api/stac/v1")
+
+ # With custom headers
+ client = Client.open(
+ "https://example.com/stac/v1",
+ headers={"Authorization": "Bearer token123"}
+ )
+
+ # Using convenience methods
+ client = Client.open("https://example.com/stac/v1")
+ client.with_basic_auth("username", "password")
+ """
+ return cls(url, headers)
+
+ def with_basic_auth(self, username: str, password: str):
+ """
+ Adds HTTP Basic Authentication to the client.
+
+ This method encodes the username and password using Base64 and adds
+ the appropriate Authorization header for HTTP Basic Authentication.
+
+        :param username: The username for authentication. For API keys, this is typically the API key itself.
+        :param password: The password for authentication. For API keys, this is often left empty.
+ :return: Self for method chaining.
+
+ Example usage:
+ # Standard basic auth
+ client = Client.open("https://example.com/stac/v1")
+ client.with_basic_auth("user", "pass")
+
+ # API key as username (common pattern)
+ client.with_basic_auth("api_key_xyz", "")
+
+ # Method chaining
+            df = Client.open(url).with_basic_auth(api_key, "").search(collection_id="test")
+ """
+ import base64
+
+ userpass = f"{username}:{password}"
+ b64_userpass = base64.b64encode(userpass.encode()).decode()
+ self.headers["Authorization"] = f"Basic {b64_userpass}"
+ return self
+
+ def with_bearer_token(self, token: str):
+ """
+ Adds Bearer Token Authentication to the client.
+
+ This method adds the appropriate Authorization header for Bearer Token
+ authentication, commonly used with OAuth2 and API tokens.
+
+ :param token: The bearer token for authentication.
+ :return: Self for method chaining.
+
+ Example usage:
+ # Bearer token auth
+ client = Client.open("https://example.com/stac/v1")
+ client.with_bearer_token("your_access_token_here")
+
+ # Method chaining
+            df = Client.open(url).with_bearer_token(token).search(collection_id="test")
"""
- return cls(url)
+ self.headers["Authorization"] = f"Bearer {token}"
+ return self
def get_collection(self, collection_id: str):
"""
@@ -50,7 +125,7 @@ class Client:
        :param collection_id: The ID of the collection to retrieve. Example: "aster-l1t"
        :return: An instance of the CollectionClient class for the specified collection.
"""
- return CollectionClient(self.url, collection_id)
+ return CollectionClient(self.url, collection_id, headers=self.headers)
def get_collection_from_catalog(self):
"""
@@ -62,7 +137,7 @@ class Client:
dict: The root catalog of the STAC API.
"""
# Implement logic to fetch and return the root catalog
- return CollectionClient(self.url, None)
+ return CollectionClient(self.url, None, headers=self.headers)
def search(
self,
diff --git a/python/sedona/spark/stac/collection_client.py b/python/sedona/spark/stac/collection_client.py
index 972dd0c140..3650f0431d 100644
--- a/python/sedona/spark/stac/collection_client.py
+++ b/python/sedona/spark/stac/collection_client.py
@@ -53,10 +53,23 @@ def get_collection_url(url: str, collection_id: Optional[str] = None) -> str:
class CollectionClient:
- def __init__(self, url: str, collection_id: Optional[str] = None):
+ def __init__(
+ self,
+ url: str,
+ collection_id: Optional[str] = None,
+ headers: Optional[dict] = None,
+ ):
+ """
+ Initializes a collection client for a STAC collection.
+
+ :param url: The base URL of the STAC API.
+    :param collection_id: The ID of the collection to access. If None, accesses the catalog root.
+ :param headers: Optional dictionary of HTTP headers for authentication.
+ """
self.url = url
self.collection_id = collection_id
self.collection_url = get_collection_url(url, collection_id)
+ self.headers = headers if headers is not None else {}
self.spark = SparkSession.getActiveSession()
@staticmethod
@@ -483,6 +496,22 @@ class CollectionClient:
return df
def load_items_df(self, bbox, geometry, datetime, ids, max_items):
+ """
+ Loads items from the STAC collection as a Spark DataFrame.
+
+ This method handles the conversion of headers to Spark options and
+ applies various filters to the data.
+ """
+ import json
+
+ # Prepare Spark DataFrameReader with headers if present
+ reader = self.spark.read.format("stac")
+
+ # Encode headers as JSON string for passing to Spark
+ if self.headers:
+ headers_json = json.dumps(self.headers)
+ reader = reader.option("headers", headers_json)
+
# Load the collection data from the specified collection URL
if (
not ids
@@ -491,13 +520,9 @@ class CollectionClient:
and not datetime
and max_items is not None
):
- df = (
- self.spark.read.format("stac")
- .option("itemsLimitMax", max_items)
- .load(self.collection_url)
- )
+        df = reader.option("itemsLimitMax", max_items).load(self.collection_url)
else:
- df = self.spark.read.format("stac").load(self.collection_url)
+ df = reader.load(self.collection_url)
# Apply ID filters if provided
if ids:
if isinstance(ids, tuple):
diff --git a/python/tests/stac/test_auth.py b/python/tests/stac/test_auth.py
new file mode 100644
index 0000000000..eec663f2ee
--- /dev/null
+++ b/python/tests/stac/test_auth.py
@@ -0,0 +1,179 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+import base64
+import json
+from unittest.mock import patch, MagicMock
+from sedona.spark.stac.client import Client
+from sedona.spark.stac.collection_client import CollectionClient
+
+from tests.test_base import TestBase
+
+
+class TestStacAuthentication(TestBase):
+ """Tests for STAC authentication functionality."""
+
+ def test_client_with_headers(self):
+ """Test that Client can be initialized with custom headers."""
+ headers = {"Authorization": "Bearer test_token"}
+ client = Client.open("https://example.com/stac/v1", headers=headers)
+
+ assert client.headers == headers
+ assert client.url == "https://example.com/stac/v1"
+
+ def test_client_without_headers(self):
+ """Test that Client works without headers (backward compatibility)."""
+ client = Client.open("https://example.com/stac/v1")
+
+ assert client.headers == {}
+ assert client.url == "https://example.com/stac/v1"
+
+ def test_with_basic_auth(self):
+ """Test basic authentication header encoding."""
+ client = Client.open("https://example.com/stac/v1")
+ client.with_basic_auth("testuser", "testpass")
+
+ # Verify the header was set correctly
+ assert "Authorization" in client.headers
+ auth_header = client.headers["Authorization"]
+ assert auth_header.startswith("Basic ")
+
+ # Verify the encoding is correct
+ encoded_part = auth_header.replace("Basic ", "")
+ decoded = base64.b64decode(encoded_part).decode()
+ assert decoded == "testuser:testpass"
+
+ def test_with_basic_auth_api_key(self):
+ """Test basic auth with API key pattern (common in STAC APIs)."""
+ client = Client.open("https://example.com/stac/v1")
+ client.with_basic_auth("api_key_12345", "")
+
+ # Verify the header was set correctly
+ assert "Authorization" in client.headers
+ auth_header = client.headers["Authorization"]
+ encoded_part = auth_header.replace("Basic ", "")
+ decoded = base64.b64decode(encoded_part).decode()
+ assert decoded == "api_key_12345:"
+
+ def test_with_bearer_token(self):
+ """Test bearer token authentication."""
+ client = Client.open("https://example.com/stac/v1")
+ client.with_bearer_token("test_token_abc123")
+
+ # Verify the header was set correctly
+ assert "Authorization" in client.headers
+ assert client.headers["Authorization"] == "Bearer test_token_abc123"
+
+ def test_method_chaining(self):
+ """Test that authentication methods support chaining."""
+ client = Client.open("https://example.com/stac/v1").with_bearer_token(
+ "token123"
+ )
+
+ assert client.headers["Authorization"] == "Bearer token123"
+
+ def test_headers_passed_to_collection_client(self):
+ """Test that headers are passed to CollectionClient."""
+ headers = {"Authorization": "Bearer test_token"}
+ client = Client.open("https://example.com/stac/v1", headers=headers)
+
+ collection_client = client.get_collection("test-collection")
+
+ assert isinstance(collection_client, CollectionClient)
+ assert collection_client.headers == headers
+
+ def test_headers_passed_to_catalog_client(self):
+ """Test that headers are passed to catalog client."""
+ headers = {"Authorization": "Bearer test_token"}
+ client = Client.open("https://example.com/stac/v1", headers=headers)
+
+ catalog_client = client.get_collection_from_catalog()
+
+ assert isinstance(catalog_client, CollectionClient)
+ assert catalog_client.headers == headers
+
+
+    @patch("sedona.spark.stac.collection_client.CollectionClient.load_items_df")
+ def test_headers_encoded_as_json_option(self, mock_load_items):
+ """Test that headers are JSON-encoded when passed to Spark."""
+ # Create a mock DataFrame
+ mock_df = MagicMock()
+ mock_load_items.return_value = mock_df
+
+ headers = {"Authorization": "Bearer test_token", "X-Custom": "value"}
+ client = Client.open("https://example.com/stac/v1", headers=headers)
+
+ # Trigger a search that calls load_items_df
+ collection_client = client.get_collection("test-collection")
+
+ # Verify headers are stored correctly
+ assert collection_client.headers == headers
+
+ def test_custom_headers(self):
+ """Test that custom headers (beyond auth) can be set."""
+ headers = {
+ "Authorization": "Bearer token",
+ "X-API-Key": "key123",
+ "User-Agent": "CustomClient/1.0",
+ }
+ client = Client.open("https://example.com/stac/v1", headers=headers)
+
+ assert client.headers == headers
+
+ def test_overwrite_auth_header(self):
+ """Test that auth methods can overwrite existing auth headers."""
+ client = Client.open("https://example.com/stac/v1")
+ client.with_bearer_token("first_token")
+ assert client.headers["Authorization"] == "Bearer first_token"
+
+ # Overwrite with basic auth
+ client.with_basic_auth("user", "pass")
+ assert client.headers["Authorization"].startswith("Basic ")
+
+ def test_collection_client_initialization_with_headers(self):
+ """Test CollectionClient can be initialized with headers directly."""
+ headers = {"Authorization": "Bearer test_token"}
+ collection_client = CollectionClient(
+ "https://example.com/stac/v1", "test-collection", headers=headers
+ )
+
+ assert collection_client.headers == headers
+ assert collection_client.collection_id == "test-collection"
+
+ def test_collection_client_without_headers(self):
+ """Test CollectionClient backward compatibility without headers."""
+ collection_client = CollectionClient(
+ "https://example.com/stac/v1", "test-collection"
+ )
+
+ assert collection_client.headers == {}
+
+ def test_empty_headers_dict(self):
+ """Test that empty headers dict works correctly."""
+ client = Client.open("https://example.com/stac/v1", headers={})
+
+ assert client.headers == {}
+
+ def test_headers_with_special_characters(self):
+ """Test that headers with special characters are handled correctly."""
+ # Base64 encoding should handle special characters
+ client = Client.open("https://example.com/stac/v1")
+ client.with_basic_auth("[email protected]", "p@ss!word#123")
+
+ auth_header = client.headers["Authorization"]
+ encoded_part = auth_header.replace("Basic ", "")
+ decoded = base64.b64decode(encoded_part).decode()
+ assert decoded == "[email protected]:p@ss!word#123"
diff --git a/python/tests/stac/test_auth_integration.py b/python/tests/stac/test_auth_integration.py
new file mode 100644
index 0000000000..0a94b2de06
--- /dev/null
+++ b/python/tests/stac/test_auth_integration.py
@@ -0,0 +1,193 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Integration tests for STAC authentication with real services.
+
+These tests use environment variables to configure authenticated services.
+Tests are skipped when environment variables are not set.
+
+Environment variables:
+    STAC_PUBLIC_URL - Public STAC service URL (e.g., Microsoft Planetary Computer)
+ STAC_AUTH_URL - Authenticated STAC service URL
+ STAC_USERNAME - Username or API key for basic authentication
+ STAC_PASSWORD - Password (can be empty for API keys)
+ STAC_BEARER_TOKEN - Bearer token for token-based authentication
+    STAC_AUTH_URL_REQUIRE_AUTH - URL requiring authentication (for failure testing)
+"""
+
+import os
+import pytest
+from sedona.spark.stac.client import Client
+
+from tests.test_base import TestBase
+
+
+class TestStacAuthIntegration(TestBase):
+ """
+ Integration tests for authenticated STAC services.
+
+ Tests are skipped when required environment variables are not set.
+ """
+
+ @pytest.mark.skipif(
+ not os.getenv("STAC_PUBLIC_URL"),
+ reason="STAC_PUBLIC_URL not set - skip public service test",
+ )
+ def test_public_service_without_authentication(self):
+ """
+ Test that public STAC services work without authentication.
+
+        This verifies backward compatibility - existing code should work unchanged.
+ Set STAC_PUBLIC_URL environment variable to test.
+        Example: https://planetarycomputer.microsoft.com/api/stac/v1/collections/naip
+ """
+ public_url = os.getenv("STAC_PUBLIC_URL")
+
+ client = Client.open(public_url)
+
+ # Try to load data without authentication
+ df = client.search(max_items=10)
+
+ assert df is not None
+ assert df.count() >= 0
+
+ @pytest.mark.skipif(
+ not os.getenv("STAC_AUTH_URL") or not os.getenv("STAC_BEARER_TOKEN"),
+        reason="STAC_AUTH_URL or STAC_BEARER_TOKEN not set - skip bearer token test",
+ )
+ def test_bearer_token_authentication(self):
+ """
+ Test authentication with bearer token.
+
+        NOTE: Planet Labs API requires Basic Auth (not Bearer token) for collections.
+ Bearer token authentication works with other STAC services.
+
+ Requires: STAC_AUTH_URL and STAC_BEARER_TOKEN environment variables
+ """
+ auth_url = os.getenv("STAC_AUTH_URL")
+ bearer_token = os.getenv("STAC_BEARER_TOKEN")
+
+ client = Client.open(auth_url)
+ client.with_bearer_token(bearer_token)
+
+ # Try to load data with bearer token
+ df = client.search(max_items=10)
+
+ assert df is not None
+ assert df.count() >= 0
+
+ @pytest.mark.skipif(
+ not os.getenv("STAC_AUTH_URL") or not os.getenv("STAC_USERNAME"),
+ reason="STAC_AUTH_URL or STAC_USERNAME not set - skip basic auth test",
+ )
+ def test_basic_authentication(self):
+ """
+ Test authentication with username/password or API key.
+
+        This test works with Planet Labs API (username=API_key, password=empty).
+
+ Requires: STAC_AUTH_URL and STAC_USERNAME environment variables
+ STAC_PASSWORD is optional (defaults to empty string)
+ """
+ auth_url = os.getenv("STAC_AUTH_URL")
+ username = os.getenv("STAC_USERNAME")
+ password = os.getenv("STAC_PASSWORD", "")
+
+ client = Client.open(auth_url)
+ client.with_basic_auth(username, password)
+
+ # Try to load data with basic auth
+ df = client.search(max_items=10)
+
+ assert df is not None
+ assert df.count() >= 0
+
+ @pytest.mark.skipif(
+ not os.getenv("STAC_AUTH_URL_REQUIRE_AUTH"),
+ reason="STAC_AUTH_URL_REQUIRE_AUTH not set - skip auth failure test",
+ )
+ def test_authentication_failure(self):
+ """
+ Test that authentication errors are properly raised.
+
+        This verifies that accessing an authenticated endpoint without credentials
+ results in a proper authentication error.
+
+ Requires: STAC_AUTH_URL_REQUIRE_AUTH environment variable
+ """
+ auth_url = os.getenv("STAC_AUTH_URL_REQUIRE_AUTH")
+
+ client = Client.open(auth_url)
+
+ # Try to access authenticated endpoint without credentials
+ with pytest.raises(Exception) as exc_info:
+ df = client.search(max_items=1)
+ df.count() # Force execution
+
+ # Verify we get an authentication-related error
+ error_message = str(exc_info.value).lower()
+ assert (
+ "401" in error_message
+ or "unauthorized" in error_message
+ or "403" in error_message
+ or "forbidden" in error_message
+ ), f"Expected authentication error, but got: {exc_info.value}"
+
+
+class TestStacAuthUnit(TestBase):
+    """Unit tests for authentication methods that don't require real services."""
+
+ def test_bearer_token_format(self):
+ """Test that bearer token is formatted correctly."""
+ client = Client.open("https://example.com/stac/v1")
+ client.with_bearer_token("test_token_12345")
+
+ assert "Authorization" in client.headers
+ assert client.headers["Authorization"] == "Bearer test_token_12345"
+
+ def test_basic_auth_encoding(self):
+ """Test that basic auth is encoded correctly."""
+ import base64
+
+ client = Client.open("https://example.com/stac/v1")
+ client.with_basic_auth("testuser", "testpass")
+
+ auth_header = client.headers["Authorization"]
+ assert auth_header.startswith("Basic ")
+
+ # Decode and verify
+ encoded_part = auth_header.replace("Basic ", "")
+ decoded = base64.b64decode(encoded_part).decode()
+ assert decoded == "testuser:testpass"
+
+ def test_method_chaining(self):
+ """Test that authentication methods support chaining."""
+ client = Client.open("https://example.com/stac/v1")
+ result = client.with_bearer_token("test_token")
+
+ # Method should return self for chaining
+ assert result is client
+
+ def test_headers_propagation(self):
+ """Test that headers are propagated to collection clients."""
+ client = Client.open("https://example.com/stac/v1")
+ client.with_bearer_token("test_token")
+
+ # Headers should be set on client
+ assert "Authorization" in client.headers
+ assert client.headers["Authorization"] == "Bearer test_token"
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
index e691a74232..e97bbdc227 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacBatch.scala
@@ -65,6 +65,9 @@ case class StacBatch(
val mapper = new ObjectMapper()
+ // Parse headers from options for authenticated requests
+ private val headers: Map[String, String] = StacUtils.parseHeaders(opts)
+
/**
* Sets the maximum number of items left to process.
*
@@ -168,7 +171,7 @@ case class StacBatch(
var nextUrl: Option[String] = Some(itemUrl)
breakable {
while (nextUrl.isDefined) {
- val itemJson = StacUtils.loadStacCollectionToJson(nextUrl.get)
+          val itemJson = StacUtils.loadStacCollectionToJson(nextUrl.get, headers)
val itemRootNode = mapper.readTree(itemJson)
// Check if there exists a "next" link
val itemLinksNode = itemRootNode.get("links")
@@ -252,7 +255,7 @@ case class StacBatch(
collectionBasePath + href
}
// Recursively process the linked collection
- val linkedCollectionJson = StacUtils.loadStacCollectionToJson(childUrl)
+      val linkedCollectionJson = StacUtils.loadStacCollectionToJson(childUrl, headers)
      val nestedCollectionBasePath = StacUtils.getStacCollectionBasePath(childUrl)
      val collectionFiltered = filterCollection(linkedCollectionJson, spatialFilter, temporalFilter)
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala
index cb5f46f5a2..64960a4ed3 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacPartitionReader.scala
@@ -51,6 +51,7 @@ class StacPartitionReader(
private var currentFile: File = _
private var featureIterator: Iterator[InternalRow] = Iterator.empty
private val mapper = new ObjectMapper()
+ private val headers = StacUtils.parseHeaders(opts)
override def next(): Boolean = {
if (featureIterator.hasNext) {
@@ -159,7 +160,39 @@ class StacPartitionReader(
while (attempt < maxRetries && !success) {
try {
- fileContent = Source.fromURL(url).mkString
+ if (headers.isEmpty) {
+ fileContent = Source.fromURL(url).mkString
+ } else {
+ val connection = url.openConnection()
+ var inputStream: java.io.InputStream = null
+ var source: Source = null
+
+ try {
+ headers.foreach { case (key, value) =>
+ connection.setRequestProperty(key, value)
+ }
+ inputStream = connection.getInputStream
+ source = Source.fromInputStream(inputStream)
+ fileContent = source.mkString
+ } finally {
+ // Close resources in reverse order
+ if (source != null) {
+ try source.close()
+ catch { case _: Throwable => }
+ }
+ if (inputStream != null) {
+ try inputStream.close()
+ catch { case _: Throwable => }
+ }
+ // Disconnect HTTP connection if applicable
+ connection match {
+ case httpConn: java.net.HttpURLConnection =>
+ try httpConn.disconnect()
+ catch { case _: Throwable => }
+ case _ =>
+ }
+ }
+ }
success = true
} catch {
case e: Exception =>
diff --git a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
index 7b88b3c9ad..515d4f0667 100644
--- a/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
+++ b/spark/common/src/main/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtils.scala
@@ -33,11 +33,44 @@ import scala.io.Source
object StacUtils {
+  // Reusable ObjectMapper instance to avoid expensive creation on each parseHeaders call
+ private val objectMapper: ObjectMapper = {
+ val mapper = new ObjectMapper()
+ mapper.registerModule(DefaultScalaModule)
+ mapper
+ }
+
// Function to load JSON from URL or service
def loadStacCollectionToJson(opts: Map[String, String]): String = {
val urlFull: String = getFullCollectionUrl(opts)
+ val headers: Map[String, String] = parseHeaders(opts)
- loadStacCollectionToJson(urlFull)
+ loadStacCollectionToJson(urlFull, headers)
+ }
+
+ /**
+ * Parse headers from the options map.
+ *
+ * Headers can be provided as a JSON string in the "headers" option.
+ *
+ * @param opts
+   *   The options map that may contain a "headers" key with JSON-encoded headers
+ * @return
+ * Map of header names to values
+ */
+ def parseHeaders(opts: Map[String, String]): Map[String, String] = {
+ opts.get("headers") match {
+ case Some(headersJson) =>
+ try {
+ objectMapper.readValue(headersJson, classOf[Map[String, String]])
+ } catch {
+ case e: Exception =>
+ throw new IllegalArgumentException(
+ s"Failed to parse headers JSON: ${e.getMessage}",
+ e)
+ }
+ case None => Map.empty[String, String]
+ }
}
def getFullCollectionUrl(opts: Map[String, String]) = {
@@ -50,8 +83,11 @@ object StacUtils {
urlFinal
}
- // Function to load JSON from URL or service
- def loadStacCollectionToJson(url: String, maxRetries: Int = 3): String = {
+ // Function to load JSON from URL or service with optional headers
+ def loadStacCollectionToJson(
+ url: String,
+ headers: Map[String, String] = Map.empty,
+ maxRetries: Int = 3): String = {
var retries = 0
var success = false
var result: String = ""
@@ -59,9 +95,45 @@ object StacUtils {
while (retries < maxRetries && !success) {
try {
result = if (url.startsWith("s3://") || url.startsWith("s3a://")) {
+ // S3 URLs are handled by Spark
SparkSession.active.read.textFile(url).collect().mkString("\n")
- } else {
+ } else if (headers.isEmpty) {
+        // No headers - use the simple Source.fromURL approach for backward compatibility
Source.fromURL(url).mkString
+ } else {
+ // Headers provided - use URLConnection to set custom headers
+ val connection = new java.net.URL(url).openConnection()
+ var inputStream: java.io.InputStream = null
+ var source: Source = null
+
+ try {
+ // Set all custom headers
+ headers.foreach { case (key, value) =>
+ connection.setRequestProperty(key, value)
+ }
+
+ // Read the response
+ inputStream = connection.getInputStream
+ source = Source.fromInputStream(inputStream)
+ source.mkString
+ } finally {
+ // Close resources in reverse order
+ if (source != null) {
+ try source.close()
+ catch { case _: Throwable => }
+ }
+ if (inputStream != null) {
+ try inputStream.close()
+ catch { case _: Throwable => }
+ }
+ // Disconnect HTTP connection if applicable
+ connection match {
+ case httpConn: java.net.HttpURLConnection =>
+ try httpConn.disconnect()
+ catch { case _: Throwable => }
+ case _ =>
+ }
+ }
}
success = true
} catch {
@@ -78,6 +150,11 @@ object StacUtils {
result
}
+ // Overloaded version for backward compatibility
+ def loadStacCollectionToJson(url: String, maxRetries: Int): String = {
+ loadStacCollectionToJson(url, Map.empty[String, String], maxRetries)
+ }
+
// Function to get the base URL from the collection URL or service
def getStacCollectionBasePath(opts: Map[String, String]): String = {
val ref = opts.getOrElse(
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
index 50051e8258..b26cc6b1f6 100644
--- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacDataSourceTest.scala
@@ -163,6 +163,141 @@ class StacDataSourceTest extends TestBaseScala {
assert(rowCount == 0)
}
+ it("should load STAC data with public service without authentication") {
+ // This test verifies backward compatibility with public STAC services
+    // Set STAC_PUBLIC_URL environment variable to test (e.g., Microsoft Planetary Computer)
+    // Example: https://planetarycomputer.microsoft.com/api/stac/v1/collections/naip
+ val publicUrl = sys.env.get("STAC_PUBLIC_URL")
+
+ if (publicUrl.isDefined) {
+ val dfStac = sparkSession.read
+ .format("stac")
+ .load(publicUrl.get)
+ .limit(10)
+
+ // Verify we can load data
+      assert(dfStac.count() >= 0, "Failed to load data from public STAC service")
+ } else {
+ // Skip test if environment variable is not set
+ cancel(
+        "Skipping public STAC service test - set STAC_PUBLIC_URL to run (e.g., https://planetarycomputer.microsoft.com/api/stac/v1/collections/naip)")
+ }
+ }
+
+ // Authentication tests for remote services
+ it("should load STAC data with bearer token authentication") {
+    // This test verifies loading STAC data from services that support Bearer token authentication.
+    // Set STAC_AUTH_URL and STAC_BEARER_TOKEN environment variables to test with a compatible service.
+    // Note: Not all STAC services support Bearer tokens; some may require other authentication methods.
+ val authUrl = sys.env.get("STAC_AUTH_URL")
+ val bearerToken = sys.env.get("STAC_BEARER_TOKEN")
+
+ if (authUrl.isDefined && bearerToken.isDefined) {
+ val headersJson = s"""{"Authorization":"Bearer ${bearerToken.get}"}"""
+
+ val dfStac = sparkSession.read
+ .format("stac")
+ .option("headers", headersJson)
+ .load(authUrl.get)
+ .limit(10)
+
+ // Verify we can load data
+      assert(dfStac.count() >= 0, "Failed to load data with bearer token authentication")
+ } else {
+ // Skip test if environment variables are not set
+ cancel(
+        "Skipping bearer token authentication test - set STAC_AUTH_URL and STAC_BEARER_TOKEN to run")
+ }
+ }
+
+ it("should load STAC data with basic authentication") {
+ // This test works with services that support HTTP Basic Authentication.
+    // For example, Planet Labs API uses API key as username with an empty password.
+ // Environment variables required:
+ // STAC_AUTH_URL - The URL of the authenticated STAC service
+ // STAC_USERNAME - Username or API key
+ // STAC_PASSWORD - Password (can be empty for API keys)
+ val authUrl = sys.env.get("STAC_AUTH_URL")
+ val username = sys.env.get("STAC_USERNAME")
+ val password = sys.env.get("STAC_PASSWORD").getOrElse("")
+
+ if (authUrl.isDefined && username.isDefined) {
+ // Encode credentials as Base64
+ val credentials = s"${username.get}:$password"
+ val base64Credentials =
+        java.util.Base64.getEncoder.encodeToString(credentials.getBytes("UTF-8"))
+ val headersJson = s"""{"Authorization":"Basic $base64Credentials"}"""
+
+ val dfStac = sparkSession.read
+ .format("stac")
+ .option("headers", headersJson)
+ .load(authUrl.get)
+ .limit(10)
+
+ // Verify we can load data
+      assert(dfStac.count() >= 0, "Failed to load data with basic authentication")
+ assertSchema(dfStac.schema)
+ } else {
+ // Skip test if environment variables are not set
+      cancel("Skipping basic authentication test - set STAC_AUTH_URL and STAC_USERNAME to run")
+ }
+ }
+
+  it("should fail gracefully when authentication is required but not provided") {
+ // This test verifies that we get a proper error when accessing
+ // an authenticated endpoint without credentials
+ val authUrl = sys.env.get("STAC_AUTH_URL_REQUIRE_AUTH")
+
+ if (authUrl.isDefined) {
+ val exception = intercept[Exception] {
+ val dfStac = sparkSession.read
+ .format("stac")
+ .load(authUrl.get)
+ .limit(10)
+ dfStac.count()
+ }
+
+ // Verify we get an authentication-related error
+ val errorMessage = exception.getMessage.toLowerCase
+ assert(
+ errorMessage.contains("401") ||
+ errorMessage.contains("unauthorized") ||
+ errorMessage.contains("403") ||
+ errorMessage.contains("forbidden"),
+ s"Expected authentication error, but got: ${exception.getMessage}")
+ } else {
+      cancel("Skipping authentication failure test - set STAC_AUTH_URL_REQUIRE_AUTH to run")
+ }
+ }
+
+ it("should parse headers JSON correctly") {
+ // Test that headers are correctly parsed from JSON
+    val headersJson = """{"Authorization":"Bearer test_token","X-Custom":"value"}"""
+ val headers = StacUtils.parseHeaders(Map("headers" -> headersJson))
+
+ assert(headers.size == 2, "Headers should contain 2 entries")
+    assert(headers("Authorization") == "Bearer test_token", "Authorization header should match")
+ assert(headers("X-Custom") == "value", "Custom header should match")
+ }
+
+ it("should handle empty headers JSON") {
+ val headersJson = """{}"""
+ val headers = StacUtils.parseHeaders(Map("headers" -> headersJson))
+ assert(headers.isEmpty, "Empty JSON should result in empty headers map")
+ }
+
+ it("should handle missing headers option") {
+ val headers = StacUtils.parseHeaders(Map.empty[String, String])
+    assert(headers.isEmpty, "Missing headers option should result in empty headers map")
+ }
+
+ it("should throw error for invalid headers JSON") {
+ val invalidJson = """{"Authorization":"Bearer token", invalid}"""
+ assertThrows[IllegalArgumentException] {
+ StacUtils.parseHeaders(Map("headers" -> invalidJson))
+ }
+ }
+
def assertSchema(actualSchema: StructType): Unit = {
// Base STAC fields that should always be present
val baseFields = Seq(
diff --git a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala
index e635235cd5..da6bb9ba91 100644
--- a/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala
+++ b/spark/common/src/test/scala/org/apache/spark/sql/sedona_sql/io/stac/StacUtilsTest.scala
@@ -717,4 +717,69 @@ class StacUtilsTest extends AnyFunSuite {
    val expectedUrl = s"$baseUrl&bbox=1.0%2C3.0%2C2.0%2C4.0&datetime=2025-03-06T00:00:00.000Z/.."
assert(result == expectedUrl)
}
+
+ // Tests for authentication headers
+
+  test("parseHeaders should return empty map when no headers option provided") {
+ val opts = Map.empty[String, String]
+ val result = StacUtils.parseHeaders(opts)
+ assert(result.isEmpty)
+ }
+
+ test("parseHeaders should parse JSON-encoded headers correctly") {
+    val headersJson = """{"Authorization":"Bearer token123","X-Custom":"value"}"""
+ val opts = Map("headers" -> headersJson)
+ val result = StacUtils.parseHeaders(opts)
+
+ assert(result.size == 2)
+ assert(result("Authorization") == "Bearer token123")
+ assert(result("X-Custom") == "value")
+ }
+
+ test("parseHeaders should parse basic authentication header") {
+ val headersJson = """{"Authorization":"Basic dGVzdHVzZXI6dGVzdHBhc3M="}"""
+ val opts = Map("headers" -> headersJson)
+ val result = StacUtils.parseHeaders(opts)
+
+ assert(result.size == 1)
+ assert(result("Authorization") == "Basic dGVzdHVzZXI6dGVzdHBhc3M=")
+ }
+
+ test("parseHeaders should handle empty headers JSON") {
+ val headersJson = """{}"""
+ val opts = Map("headers" -> headersJson)
+ val result = StacUtils.parseHeaders(opts)
+
+ assert(result.isEmpty)
+ }
+
+ test("parseHeaders should throw IllegalArgumentException for invalid JSON") {
+ val headersJson = """invalid json"""
+ val opts = Map("headers" -> headersJson)
+
+ assertThrows[IllegalArgumentException] {
+ StacUtils.parseHeaders(opts)
+ }
+ }
+
+ test("parseHeaders should handle multiple custom headers") {
+ val headersJson =
+      """{"Authorization":"Bearer token","X-API-Key":"key123","User-Agent":"TestClient/1.0"}"""
+ val opts = Map("headers" -> headersJson)
+ val result = StacUtils.parseHeaders(opts)
+
+ assert(result.size == 3)
+ assert(result("Authorization") == "Bearer token")
+ assert(result("X-API-Key") == "key123")
+ assert(result("User-Agent") == "TestClient/1.0")
+ }
+
+  test("loadStacCollectionToJson with headers should pass headers to HTTP request") {
+ // This test verifies the method signature accepts headers
+    // Actual HTTP behavior would require a mock server, which is tested in integration tests
+    val opts = Map("path" -> "file:///tmp/collection.json", "headers" -> """{"X-Test":"value"}""")
+ val headers = StacUtils.parseHeaders(opts)
+
+ assert(headers("X-Test") == "value")
+ }
}