This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch awm-iam in repository https://gitbox.apache.org/repos/asf/superset.git
commit 1c0429a0de4eef7e74fc53f9f7b43ffb95a5f30f Author: Beto Dealmeida <[email protected]> AuthorDate: Thu Jan 22 19:06:20 2026 -0500 feat(AWS IAM): phase 1 --- docker-compose-light.yml | 2 + superset/db_engine_specs/aurora.py | 13 + superset/db_engine_specs/aws_iam.py | 343 ++++++++++++++++++++ superset/db_engine_specs/postgres.py | 37 +++ tests/unit_tests/db_engine_specs/test_aurora.py | 239 ++++++++++++++ tests/unit_tests/db_engine_specs/test_aws_iam.py | 383 +++++++++++++++++++++++ 6 files changed, 1017 insertions(+) diff --git a/docker-compose-light.yml b/docker-compose-light.yml index b06be681af..f4dd30487e 100644 --- a/docker-compose-light.yml +++ b/docker-compose-light.yml @@ -108,6 +108,8 @@ services: extra_hosts: - "host.docker.internal:host-gateway" user: *superset-user + ports: + - "${SUPERSET_PORT:-8088}:8088" depends_on: superset-init-light: condition: service_completed_successfully diff --git a/superset/db_engine_specs/aurora.py b/superset/db_engine_specs/aurora.py index 0baaf1e9b1..c7bcbc77a4 100644 --- a/superset/db_engine_specs/aurora.py +++ b/superset/db_engine_specs/aurora.py @@ -42,3 +42,16 @@ class AuroraPostgresDataAPI(PostgresEngineSpec): "secret_arn={secret_arn}&" "region_name={region_name}" ) + + +class AuroraPostgresEngineSpec(PostgresEngineSpec): + """ + Aurora PostgreSQL engine spec. + + IAM authentication is handled by the parent PostgresEngineSpec via + the aws_iam config in encrypted_extra. + """ + + engine = "postgresql" + engine_name = "Aurora PostgreSQL" + default_driver = "psycopg2" diff --git a/superset/db_engine_specs/aws_iam.py b/superset/db_engine_specs/aws_iam.py new file mode 100644 index 0000000000..74ee58a9d9 --- /dev/null +++ b/superset/db_engine_specs/aws_iam.py @@ -0,0 +1,343 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +""" +AWS IAM Authentication Mixin for database engine specs. + +This mixin provides cross-account IAM authentication support for AWS databases +(Aurora PostgreSQL, Aurora MySQL, Redshift). It handles: +- Assuming IAM roles via STS AssumeRole +- Generating RDS IAM auth tokens +- Configuring SSL (required for IAM auth) +""" + +from __future__ import annotations + +import logging +from typing import Any, TYPE_CHECKING, TypedDict + +from superset.databases.utils import make_url_safe +from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import SupersetSecurityException + +if TYPE_CHECKING: + from superset.models.core import Database + +logger = logging.getLogger(__name__) + +# Default session duration for STS AssumeRole (1 hour) +DEFAULT_SESSION_DURATION = 3600 + +# Default port for PostgreSQL +DEFAULT_POSTGRES_PORT = 5432 + + +class AWSIAMConfig(TypedDict, total=False): + """Configuration for AWS IAM authentication.""" + + enabled: bool + role_arn: str + external_id: str + region: str + db_username: str + session_duration: int + + +class AWSIAMAuthMixin: + """ + Mixin that provides AWS IAM authentication for database connections. + + This mixin can be used with database engine specs that support IAM + authentication (Aurora PostgreSQL, Aurora MySQL, Redshift). + + Configuration is provided via the database's encrypted_extra JSON: + + { + "aws_iam": { + "enabled": true, + "role_arn": "arn:aws:iam::222222222222:role/SupersetDatabaseAccess", + "external_id": "superset-prod-12345", # optional + "region": "us-east-1", + "db_username": "superset_iam_user", + "session_duration": 3600 # optional, defaults to 3600 + } + } + """ + + supports_iam_authentication = True + + # AWS error patterns for actionable error messages + aws_iam_custom_errors: dict[str, tuple[SupersetErrorType, str]] = { + "AccessDenied": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "Unable to assume IAM role. Verify the role ARN and trust policy " + "allow access from Superset's IAM role.", + ), + "InvalidIdentityToken": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "Invalid IAM credentials. Ensure Superset has a valid IAM role " + "with permissions to assume the target role.", + ), + "MalformedPolicyDocument": ( + SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + "Invalid IAM role ARN format. Please verify the role ARN.", + ), + "ExpiredTokenException": ( + SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + "AWS credentials have expired. Please refresh the connection.", + ), + } + + @classmethod + def get_iam_credentials( + cls, + role_arn: str, + region: str, + external_id: str | None = None, + session_duration: int = DEFAULT_SESSION_DURATION, + ) -> dict[str, Any]: + """ + Assume cross-account IAM role via STS AssumeRole. + + :param role_arn: The ARN of the IAM role to assume + :param region: AWS region for the STS client + :param external_id: External ID for the role assumption (optional) + :param session_duration: Duration of the session in seconds + :returns: Dictionary with AccessKeyId, SecretAccessKey, SessionToken + :raises SupersetSecurityException: If role assumption fails + """ + try: + # Lazy import to avoid errors when boto3 is not installed + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication. " + "Install it with: pip install boto3", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + sts_client = boto3.client("sts", region_name=region) + + assume_role_kwargs: dict[str, Any] = { + "RoleArn": role_arn, + "RoleSessionName": "superset-iam-session", + "DurationSeconds": session_duration, + } + if external_id: + assume_role_kwargs["ExternalId"] = external_id + + response = sts_client.assume_role(**assume_role_kwargs) + return response["Credentials"] + + except ClientError as ex: + error_code = ex.response.get("Error", {}).get("Code", "") + error_message = ex.response.get("Error", {}).get("Message", "") + + # Handle ExternalId mismatch (shows as AccessDenied with specific message) + # Check this first before generic AccessDenied handling + if "external id" in error_message.lower(): + raise SupersetSecurityException( + SupersetError( + message="External ID mismatch. Verify the external_id " + "configuration matches the trust policy.", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + if error_code in cls.aws_iam_custom_errors: + error_type, message = cls.aws_iam_custom_errors[error_code] + raise SupersetSecurityException( + SupersetError( + message=message, + error_type=error_type, + level=ErrorLevel.ERROR, + ) + ) from ex + + raise SupersetSecurityException( + SupersetError( + message=f"Failed to assume IAM role: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def generate_rds_auth_token( + cls, + credentials: dict[str, Any], + hostname: str, + port: int, + username: str, + region: str, + ) -> str: + """ + Generate RDS IAM auth token using temporary credentials. + + :param credentials: STS credentials from assume_role + :param hostname: RDS/Aurora endpoint hostname + :param port: Database port + :param username: Database username configured for IAM auth + :param region: AWS region + :returns: IAM auth token to use as database password + :raises SupersetSecurityException: If token generation fails + """ + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + rds_client = boto3.client( + "rds", + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + token = rds_client.generate_db_auth_token( + DBHostname=hostname, + Port=port, + DBUsername=username, + ) + return token + + except ClientError as ex: + raise SupersetSecurityException( + SupersetError( + message=f"Failed to generate RDS auth token: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + @classmethod + def _apply_iam_authentication( + cls, + database: Database, + params: dict[str, Any], + iam_config: AWSIAMConfig, + ) -> None: + """ + Apply IAM authentication to the connection parameters. + + Full flow: assume role -> generate token -> update connect_args -> enable SSL. + + :param database: Database model instance + :param params: Engine parameters dict to modify + :param iam_config: IAM configuration from encrypted_extra + :raises SupersetSecurityException: If any step fails + """ + # Extract configuration + role_arn = iam_config.get("role_arn") + region = iam_config.get("region") + db_username = iam_config.get("db_username") + external_id = iam_config.get("external_id") + session_duration = iam_config.get("session_duration", DEFAULT_SESSION_DURATION) + + # Validate required fields + missing_fields = [] + if not role_arn: + missing_fields.append("role_arn") + if not region: + missing_fields.append("region") + if not db_username: + missing_fields.append("db_username") + + if missing_fields: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration missing required fields: " + f"{', '.join(missing_fields)}", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Type assertions after validation (mypy doesn't narrow types from list check) + assert role_arn is not None + assert region is not None + assert db_username is not None + + # Get hostname and port from the database URI + uri = make_url_safe(database.sqlalchemy_uri_decrypted) + hostname = uri.host + port = uri.port or DEFAULT_POSTGRES_PORT + + if not hostname: + raise SupersetSecurityException( + SupersetError( + message=( + "Database URI must include a hostname for IAM authentication" + ), + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + logger.debug( + "Applying IAM authentication for %s:%d as user %s", + hostname, + port, + db_username, + ) + + # Step 1: Assume the IAM role + credentials = cls.get_iam_credentials( + role_arn=role_arn, + region=region, + external_id=external_id, + session_duration=session_duration, + ) + + # Step 2: Generate the RDS auth token + token = cls.generate_rds_auth_token( + credentials=credentials, + hostname=hostname, + port=port, + username=db_username, + region=region, + ) + + # Step 3: Update connection parameters + connect_args = params.setdefault("connect_args", {}) + + # Set the IAM token as the password + connect_args["password"] = token + + # Override username if different from URI + connect_args["user"] = db_username + + # Step 4: Enable SSL (required for IAM authentication) + # sslmode=require ensures encrypted connection without cert verification + # For production, consider sslmode=verify-full with RDS CA bundle + connect_args["sslmode"] = "require" + + logger.debug("IAM authentication configured successfully") diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index b259164e4f..28e08b8f2e 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -218,6 +218,12 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): max_column_name_length = 63 try_remove_schema_from_table_name = False # pylint: disable=invalid-name + # Sensitive fields that should be masked in encrypted_extra + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + column_type_mappings = ( ( re.compile(r"^double precision", re.IGNORECASE), @@ -320,6 +326,37 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): return uri, connect_args + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra. + + Handles AWS IAM authentication if configured, then merges any + remaining encrypted_extra keys into params (standard behavior). + """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + AWSIAMAuthMixin._apply_iam_authentication(database, params, iam_config) + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def get_default_catalog(cls, database: Database) -> str: """ diff --git a/tests/unit_tests/db_engine_specs/test_aurora.py b/tests/unit_tests/db_engine_specs/test_aurora.py new file mode 100644 index 0000000000..a225f3641e --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_aurora.py @@ -0,0 +1,239 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.utils import json + + +def test_aurora_postgres_engine_spec_properties() -> None: + from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec + + assert AuroraPostgresEngineSpec.engine == "postgresql" + assert AuroraPostgresEngineSpec.engine_name == "Aurora PostgreSQL" + assert AuroraPostgresEngineSpec.default_driver == "psycopg2" + + +def test_update_params_from_encrypted_extra_without_iam() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps({}) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:[email protected]:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made + assert params == {} + + +def test_update_params_from_encrypted_extra_iam_disabled() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": False, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:[email protected]:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made when IAM is disabled + assert params == {} + + +def test_update_params_from_encrypted_extra_with_iam() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://[email protected]:5432/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + assert params["connect_args"]["sslmode"] == "require" + + +def test_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 10, + } + ) + database.sqlalchemy_uri_decrypted = ( + "postgresql://user:[email protected]:5432/mydb" + ) + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # aws_iam should be consumed, pool_size should be merged + assert "aws_iam" not in params + assert params["pool_size"] == 10 + + +def test_update_params_from_encrypted_extra_no_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + # No modifications should be made + assert params == {} + + +def test_update_params_from_encrypted_extra_invalid_json() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + database = MagicMock() + database.encrypted_extra = "not-valid-json" + + params: dict[str, Any] = {} + + with pytest.raises(json.JSONDecodeError): + PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + + +def test_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + # Verify sensitive fields are properly defined + assert ( + "$.aws_iam.external_id" in PostgresEngineSpec.encrypted_extra_sensitive_fields + ) + assert "$.aws_iam.role_arn" in PostgresEngineSpec.encrypted_extra_sensitive_fields + + +def test_mask_encrypted_extra() -> None: + from superset.db_engine_specs.postgres import PostgresEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "db_username": "superset_user", + } + } + ) + + masked = PostgresEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["db_username"] == "superset_user" + + +def test_aurora_postgres_inherits_from_postgres() -> None: + from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec + from superset.db_engine_specs.postgres import PostgresEngineSpec + + # Verify inheritance + assert issubclass(AuroraPostgresEngineSpec, PostgresEngineSpec) + + # Verify it inherits PostgreSQL capabilities + assert AuroraPostgresEngineSpec.supports_dynamic_schema is True + assert AuroraPostgresEngineSpec.supports_catalog is True + + +def test_aurora_data_api_classes_unchanged() -> None: + from superset.db_engine_specs.aurora import ( + AuroraMySQLDataAPI, + AuroraPostgresDataAPI, + ) + + # Verify Data API classes are still available and unchanged + assert AuroraMySQLDataAPI.engine == "mysql" + assert AuroraMySQLDataAPI.default_driver == "auroradataapi" + assert AuroraMySQLDataAPI.engine_name == "Aurora MySQL (Data API)" + + assert AuroraPostgresDataAPI.engine == "postgresql" + assert AuroraPostgresDataAPI.default_driver == "auroradataapi" + assert AuroraPostgresDataAPI.engine_name == "Aurora PostgreSQL (Data API)" diff --git a/tests/unit_tests/db_engine_specs/test_aws_iam.py b/tests/unit_tests/db_engine_specs/test_aws_iam.py new file mode 100644 index 0000000000..f7bd9dd379 --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_aws_iam.py @@ -0,0 +1,383 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +# pylint: disable=import-outside-toplevel, protected-access + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.exceptions import SupersetSecurityException + + +def test_get_iam_credentials_success() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + "Expiration": "2025-01-01T00:00:00Z", + } + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + credentials = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert credentials == mock_credentials + mock_boto3_client.assert_called_once_with("sts", region_name="us-east-1") + mock_sts.assume_role.assert_called_once_with( + RoleArn="arn:aws:iam::123456789012:role/TestRole", + RoleSessionName="superset-iam-session", + DurationSeconds=3600, + ) + + +def test_get_iam_credentials_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + credentials = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-west-2", + external_id="external-id-12345", + session_duration=900, + ) + + assert credentials == mock_credentials + mock_sts.assume_role.assert_called_once_with( + RoleArn="arn:aws:iam::123456789012:role/TestRole", + RoleSessionName="superset-iam-session", + DurationSeconds=900, + ExternalId="external-id-12345", + ) + + +def test_get_iam_credentials_access_denied() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": "Access Denied"}}, + "AssumeRole", + ) + mock_boto3_client.return_value = mock_sts + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert "Unable to assume IAM role" in str(exc_info.value) + + +def test_get_iam_credentials_external_id_mismatch() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = ClientError( + { + "Error": { + "Code": "AccessDenied", + "Message": "The external id does not match", + } + }, + "AssumeRole", + ) + mock_boto3_client.return_value = mock_sts + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + external_id="wrong-id", + ) + + assert "External ID mismatch" in str(exc_info.value) + + +def test_generate_rds_auth_token() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_rds = MagicMock() + mock_rds.generate_db_auth_token.return_value = "iam-token-12345" + mock_boto3_client.return_value = mock_rds + + token = AWSIAMAuthMixin.generate_rds_auth_token( + credentials=credentials, + hostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com", + port=5432, + username="superset_user", + region="us-east-1", + ) + + assert token == "iam-token-12345" # noqa: S105 + mock_boto3_client.assert_called_once_with( + "rds", + region_name="us-east-1", + aws_access_key_id="ASIA...", + aws_secret_access_key="secret...", # noqa: S106 + aws_session_token="token...", # noqa: S106 + ) + mock_rds.generate_db_auth_token.assert_called_once_with( + DBHostname="mydb.cluster-xyz.us-east-1.rds.amazonaws.com", + Port=5432, + DBUsername="superset_user", + ) + + +def test_apply_iam_authentication() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://[email protected]:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + external_id=None, + session_duration=3600, + ) + + mock_gen_token.assert_called_once() + token_call_kwargs = mock_gen_token.call_args[1] + assert ( + token_call_kwargs["hostname"] == "mydb.cluster-xyz.us-east-1.rds.amazonaws.com" + ) + assert token_call_kwargs["port"] == 5432 + assert token_call_kwargs["username"] == "superset_iam_user" + + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + assert params["connect_args"]["sslmode"] == "require" + + +def test_apply_iam_authentication_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://[email protected]:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::222222222222:role/CrossAccountRole", + "external_id": "superset-prod-12345", + "region": "us-west-2", + "db_username": "iam_user", + "session_duration": 1800, + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::222222222222:role/CrossAccountRole", + region="us-west-2", + external_id="superset-prod-12345", + session_duration=1800, + ) + + +def test_apply_iam_authentication_missing_role_arn() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://[email protected]:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + assert "role_arn" in str(exc_info.value) + + +def test_apply_iam_authentication_missing_db_username() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://[email protected]:5432/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + assert "db_username" in str(exc_info.value) + + +def test_apply_iam_authentication_default_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + # URI without explicit port + mock_database.sqlalchemy_uri_decrypted = ( + "postgresql://[email protected]/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + AWSIAMAuthMixin._apply_iam_authentication(mock_database, params, iam_config) + + # Should use default port 5432 + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 5432 + + +def test_get_iam_credentials_boto3_not_installed() -> None: + import sys + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + # Temporarily hide boto3 + boto3_module = sys.modules.get("boto3") + sys.modules["boto3"] = None # type: ignore + + try: + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/TestRole", + region="us-east-1", + ) + + assert "boto3 is required" in str(exc_info.value) + finally: + # Restore boto3 + if boto3_module is not None: + sys.modules["boto3"] = boto3_module + else: + del sys.modules["boto3"]
