This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch aws-iam
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 64656baf356d780fb243f3521d0eb8028d71d709 Author: Beto Dealmeida <[email protected]> AuthorDate: Thu Jan 22 19:34:43 2026 -0500 Phase 2 --- .pre-commit-config.yaml | 1 + superset/db_engine_specs/aurora.py | 13 + superset/db_engine_specs/aws_iam.py | 188 +++++++++- superset/db_engine_specs/mysql.py | 54 ++- superset/db_engine_specs/postgres.py | 8 +- superset/db_engine_specs/redshift.py | 40 +++ tests/unit_tests/db_engine_specs/test_aurora.py | 74 ++++ tests/unit_tests/db_engine_specs/test_aws_iam.py | 378 ++++++++++++++++++++- .../{test_aurora.py => test_mysql_iam.py} | 159 +++++---- .../db_engine_specs/test_redshift_iam.py | 245 +++++++++++++ 10 files changed, 1066 insertions(+), 94 deletions(-) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 0ca94fca54f..d44a63abfdd 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -27,6 +27,7 @@ repos: args: [--check-untyped-defs] exclude: ^superset-extensions-cli/ additional_dependencies: [ + types-cachetools, types-simplejson, types-python-dateutil, types-requests, diff --git a/superset/db_engine_specs/aurora.py b/superset/db_engine_specs/aurora.py index c7bcbc77a45..afc8652103e 100644 --- a/superset/db_engine_specs/aurora.py +++ b/superset/db_engine_specs/aurora.py @@ -44,6 +44,19 @@ class AuroraPostgresDataAPI(PostgresEngineSpec): ) +class AuroraMySQLEngineSpec(MySQLEngineSpec): + """ + Aurora MySQL engine spec. + + IAM authentication is handled by the parent MySQLEngineSpec via + the aws_iam config in encrypted_extra. + """ + + engine = "mysql" + engine_name = "Aurora MySQL" + default_driver = "mysqldb" + + class AuroraPostgresEngineSpec(PostgresEngineSpec): """ Aurora PostgreSQL engine spec. diff --git a/superset/db_engine_specs/aws_iam.py b/superset/db_engine_specs/aws_iam.py index 74ee58a9d92..29145b811a5 100644 --- a/superset/db_engine_specs/aws_iam.py +++ b/superset/db_engine_specs/aws_iam.py @@ -21,14 +21,19 @@ This mixin provides cross-account IAM authentication support for AWS databases (Aurora PostgreSQL, Aurora MySQL, Redshift). It handles: - Assuming IAM roles via STS AssumeRole - Generating RDS IAM auth tokens +- Generating Redshift Serverless credentials - Configuring SSL (required for IAM auth) +- Caching STS credentials to reduce API calls """ from __future__ import annotations import logging +import threading from typing import Any, TYPE_CHECKING, TypedDict +from cachetools import TTLCache + from superset.databases.utils import make_url_safe from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetSecurityException @@ -41,8 +46,16 @@ logger = logging.getLogger(__name__) # Default session duration for STS AssumeRole (1 hour) DEFAULT_SESSION_DURATION = 3600 -# Default port for PostgreSQL +# Default ports DEFAULT_POSTGRES_PORT = 5432 +DEFAULT_MYSQL_PORT = 3306 +DEFAULT_REDSHIFT_PORT = 5439 + +# Cache STS credentials: key = (role_arn, region, external_id), TTL = 50 min +_credentials_cache: TTLCache[tuple[str, str, str | None], dict[str, Any]] = TTLCache( + maxsize=100, ttl=3000 +) +_credentials_lock = threading.RLock() class AWSIAMConfig(TypedDict, total=False): @@ -54,6 +67,9 @@ class AWSIAMConfig(TypedDict, total=False): region: str db_username: str session_duration: int + # Redshift Serverless fields + workgroup_name: str + db_name: str class AWSIAMAuthMixin: @@ -110,7 +126,11 @@ class AWSIAMAuthMixin: session_duration: int = DEFAULT_SESSION_DURATION, ) -> dict[str, Any]: """ - Assume cross-account IAM role via STS AssumeRole. 
+ Assume cross-account IAM role via STS AssumeRole with credential caching. + + Credentials are cached by (role_arn, region, external_id) with a 50-minute + TTL to reduce STS API calls while ensuring tokens are refreshed before the + default 1-hour expiration. :param role_arn: The ARN of the IAM role to assume :param region: AWS region for the STS client @@ -119,6 +139,13 @@ class AWSIAMAuthMixin: :returns: Dictionary with AccessKeyId, SecretAccessKey, SessionToken :raises SupersetSecurityException: If role assumption fails """ + cache_key = (role_arn, region, external_id) + + with _credentials_lock: + cached = _credentials_cache.get(cache_key) + if cached is not None: + return cached + try: # Lazy import to avoid errors when boto3 is not installed import boto3 @@ -145,7 +172,12 @@ class AWSIAMAuthMixin: assume_role_kwargs["ExternalId"] = external_id response = sts_client.assume_role(**assume_role_kwargs) - return response["Credentials"] + credentials = response["Credentials"] + + with _credentials_lock: + _credentials_cache[cache_key] = credentials + + return credentials except ClientError as ex: error_code = ex.response.get("Error", {}).get("Code", "") @@ -238,12 +270,68 @@ class AWSIAMAuthMixin: ) ) from ex + @classmethod + def generate_redshift_credentials( + cls, + credentials: dict[str, Any], + workgroup_name: str, + db_name: str, + region: str, + ) -> tuple[str, str]: + """ + Generate Redshift Serverless credentials using temporary STS credentials. + + :param credentials: STS credentials from assume_role + :param workgroup_name: Redshift Serverless workgroup name + :param db_name: Redshift database name + :param region: AWS region + :returns: Tuple of (username, password) for Redshift connection + :raises SupersetSecurityException: If credential generation fails + """ + try: + import boto3 + from botocore.exceptions import ClientError + except ImportError as ex: + raise SupersetSecurityException( + SupersetError( + message="boto3 is required for AWS IAM authentication.", + error_type=SupersetErrorType.GENERIC_DB_ENGINE_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + + try: + client = boto3.client( + "redshift-serverless", + region_name=region, + aws_access_key_id=credentials["AccessKeyId"], + aws_secret_access_key=credentials["SecretAccessKey"], + aws_session_token=credentials["SessionToken"], + ) + + response = client.get_credentials( + workgroupName=workgroup_name, + dbName=db_name, + ) + return response["dbUser"], response["dbPassword"] + + except ClientError as ex: + raise SupersetSecurityException( + SupersetError( + message=f"Failed to get Redshift credentials: {ex}", + error_type=SupersetErrorType.CONNECTION_ACCESS_DENIED_ERROR, + level=ErrorLevel.ERROR, + ) + ) from ex + @classmethod def _apply_iam_authentication( cls, database: Database, params: dict[str, Any], iam_config: AWSIAMConfig, + ssl_args: dict[str, str] | None = None, + default_port: int = DEFAULT_POSTGRES_PORT, ) -> None: """ Apply IAM authentication to the connection parameters. 
@@ -253,8 +341,13 @@ class AWSIAMAuthMixin: :param database: Database model instance :param params: Engine parameters dict to modify :param iam_config: IAM configuration from encrypted_extra + :param ssl_args: SSL args to apply (defaults to sslmode=require) + :param default_port: Default port if not specified in URI :raises SupersetSecurityException: If any step fails """ + if ssl_args is None: + ssl_args = {"sslmode": "require"} + # Extract configuration role_arn = iam_config.get("role_arn") region = iam_config.get("region") @@ -289,7 +382,7 @@ class AWSIAMAuthMixin: # Get hostname and port from the database URI uri = make_url_safe(database.sqlalchemy_uri_decrypted) hostname = uri.host - port = uri.port or DEFAULT_POSTGRES_PORT + port = uri.port or default_port if not hostname: raise SupersetSecurityException( @@ -336,8 +429,89 @@ class AWSIAMAuthMixin: connect_args["user"] = db_username # Step 4: Enable SSL (required for IAM authentication) - # sslmode=require ensures encrypted connection without cert verification - # For production, consider sslmode=verify-full with RDS CA bundle - connect_args["sslmode"] = "require" + connect_args.update(ssl_args) logger.debug("IAM authentication configured successfully") + + @classmethod + def _apply_redshift_iam_authentication( + cls, + database: Database, + params: dict[str, Any], + iam_config: AWSIAMConfig, + ) -> None: + """ + Apply Redshift Serverless IAM authentication to connection parameters. + + Flow: assume role -> get Redshift credentials -> update connect_args -> SSL. + + :param database: Database model instance + :param params: Engine parameters dict to modify + :param iam_config: IAM configuration from encrypted_extra + :raises SupersetSecurityException: If any step fails + """ + # Extract configuration + role_arn = iam_config.get("role_arn") + region = iam_config.get("region") + external_id = iam_config.get("external_id") + session_duration = iam_config.get("session_duration", DEFAULT_SESSION_DURATION) + workgroup_name = iam_config.get("workgroup_name") + db_name = iam_config.get("db_name") + + # Validate required fields + missing_fields = [] + if not role_arn: + missing_fields.append("role_arn") + if not region: + missing_fields.append("region") + if not workgroup_name: + missing_fields.append("workgroup_name") + if not db_name: + missing_fields.append("db_name") + + if missing_fields: + raise SupersetSecurityException( + SupersetError( + message="AWS IAM configuration missing required fields: " + f"{', '.join(missing_fields)}", + error_type=SupersetErrorType.CONNECTION_MISSING_PARAMETERS_ERROR, + level=ErrorLevel.ERROR, + ) + ) + + # Type assertions after validation + assert role_arn is not None + assert region is not None + assert workgroup_name is not None + assert db_name is not None + + logger.debug( + "Applying Redshift IAM authentication for workgroup %s", + workgroup_name, + ) + + # Step 1: Assume the IAM role + credentials = cls.get_iam_credentials( + role_arn=role_arn, + region=region, + external_id=external_id, + session_duration=session_duration, + ) + + # Step 2: Get Redshift Serverless credentials + db_user, db_password = cls.generate_redshift_credentials( + credentials=credentials, + workgroup_name=workgroup_name, + db_name=db_name, + region=region, + ) + + # Step 3: Update connection parameters + connect_args = params.setdefault("connect_args", {}) + connect_args["password"] = db_password + connect_args["user"] = db_user + + # Step 4: Enable SSL (required for Redshift IAM authentication) + connect_args["sslmode"] = 
"verify-ca" + + logger.debug("Redshift IAM authentication configured successfully") diff --git a/superset/db_engine_specs/mysql.py b/superset/db_engine_specs/mysql.py index 8c071f8dab6..0996ee920d0 100644 --- a/superset/db_engine_specs/mysql.py +++ b/superset/db_engine_specs/mysql.py @@ -14,12 +14,15 @@ # KIND, either express or implied. See the License for the # specific language governing permissions and limitations # under the License. +from __future__ import annotations + import contextlib +import logging import re from datetime import datetime from decimal import Decimal from re import Pattern -from typing import Any, Callable, Optional +from typing import Any, Callable, Optional, TYPE_CHECKING from urllib import parse from flask_babel import gettext as __ @@ -42,8 +45,14 @@ from superset.constants import TimeGrain from superset.db_engine_specs.base import BaseEngineSpec, BasicParametersMixin from superset.errors import SupersetErrorType from superset.models.sql_lab import Query +from superset.utils import json from superset.utils.core import GenericDataType +if TYPE_CHECKING: + from superset.models.core import Database + +logger = logging.getLogger(__name__) + # Regular expressions to catch custom errors CONNECTION_ACCESS_DENIED_REGEX = re.compile( "Access denied for user '(?P<username>.*?)'@'(?P<hostname>.*?)'" @@ -192,6 +201,49 @@ class MySQLEngineSpec(BasicParametersMixin, BaseEngineSpec): "mysqlconnector": {"allow_local_infile": 0}, } + # Sensitive fields that should be masked in encrypted_extra + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra. + + Handles AWS IAM authentication if configured, then merges any + remaining encrypted_extra keys into params. 
+ """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + AWSIAMAuthMixin._apply_iam_authentication( + database, + params, + iam_config, + ssl_args={"ssl_mode": "REQUIRED"}, + default_port=3306, + ) + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def convert_dttm( cls, target_type: str, dttm: datetime, db_extra: Optional[dict[str, Any]] = None diff --git a/superset/db_engine_specs/postgres.py b/superset/db_engine_specs/postgres.py index 28e08b8f2e8..5dfdb6ce7fe 100644 --- a/superset/db_engine_specs/postgres.py +++ b/superset/db_engine_specs/postgres.py @@ -351,7 +351,13 @@ class PostgresEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): if iam_config and iam_config.get("enabled"): from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin - AWSIAMAuthMixin._apply_iam_authentication(database, params, iam_config) + AWSIAMAuthMixin._apply_iam_authentication( + database, + params, + iam_config, + ssl_args={"sslmode": "require"}, + default_port=5432, + ) # Standard behavior: merge remaining keys into params if encrypted_extra: diff --git a/superset/db_engine_specs/redshift.py b/superset/db_engine_specs/redshift.py index 49dcd4e983c..3c251e93849 100644 --- a/superset/db_engine_specs/redshift.py +++ b/superset/db_engine_specs/redshift.py @@ -31,6 +31,7 @@ from superset.errors import SupersetErrorType from superset.models.core import Database from superset.models.sql_lab import Query from superset.sql.parse import Table +from superset.utils import json logger = logging.getLogger() @@ -103,6 +104,45 @@ class RedshiftEngineSpec(BasicParametersMixin, PostgresBaseEngineSpec): ), } + # Sensitive fields that should be masked in encrypted_extra + encrypted_extra_sensitive_fields = { + "$.aws_iam.external_id", + "$.aws_iam.role_arn", + } + + @staticmethod + def update_params_from_encrypted_extra( + database: Database, + params: dict[str, Any], + ) -> None: + """ + Extract sensitive parameters from encrypted_extra. + + Handles AWS IAM authentication for Redshift Serverless if configured, + then merges any remaining encrypted_extra keys into params. 
+ """ + if not database.encrypted_extra: + return + + try: + encrypted_extra = json.loads(database.encrypted_extra) + except json.JSONDecodeError as ex: + logger.error(ex, exc_info=True) + raise + + # Handle AWS IAM auth: pop the key so it doesn't reach create_engine() + iam_config = encrypted_extra.pop("aws_iam", None) + if iam_config and iam_config.get("enabled"): + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + AWSIAMAuthMixin._apply_redshift_iam_authentication( + database, params, iam_config + ) + + # Standard behavior: merge remaining keys into params + if encrypted_extra: + params.update(encrypted_extra) + @classmethod def df_to_sql( cls, diff --git a/tests/unit_tests/db_engine_specs/test_aurora.py b/tests/unit_tests/db_engine_specs/test_aurora.py index a225f3641e4..b3dfc6940f7 100644 --- a/tests/unit_tests/db_engine_specs/test_aurora.py +++ b/tests/unit_tests/db_engine_specs/test_aurora.py @@ -223,6 +223,80 @@ def test_aurora_postgres_inherits_from_postgres() -> None: assert AuroraPostgresEngineSpec.supports_catalog is True +def test_aurora_mysql_engine_spec_properties() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + + assert AuroraMySQLEngineSpec.engine == "mysql" + assert AuroraMySQLEngineSpec.engine_name == "Aurora MySQL" + assert AuroraMySQLEngineSpec.default_driver == "mysqldb" + + +def test_aurora_mysql_inherits_from_mysql() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + from superset.db_engine_specs.mysql import MySQLEngineSpec + + assert issubclass(AuroraMySQLEngineSpec, MySQLEngineSpec) + assert AuroraMySQLEngineSpec.supports_dynamic_schema is True + + +def test_aurora_mysql_has_iam_support() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + + # Verify it inherits encrypted_extra_sensitive_fields + assert ( + "$.aws_iam.external_id" + in AuroraMySQLEngineSpec.encrypted_extra_sensitive_fields + ) + assert ( + "$.aws_iam.role_arn" in AuroraMySQLEngineSpec.encrypted_extra_sensitive_fields + ) + + +def test_aurora_mysql_update_params_from_encrypted_extra_with_iam() -> None: + from superset.db_engine_specs.aurora import AuroraMySQLEngineSpec + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "mysql://[email protected]:3306/mydb" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AuroraMySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "iam-auth-token" # noqa: S105 + assert params["connect_args"]["user"] == "superset_iam_user" + assert params["connect_args"]["ssl_mode"] == "REQUIRED" + + def test_aurora_data_api_classes_unchanged() -> None: from superset.db_engine_specs.aurora import ( AuroraMySQLDataAPI, diff --git a/tests/unit_tests/db_engine_specs/test_aws_iam.py b/tests/unit_tests/db_engine_specs/test_aws_iam.py index f7bd9dd3793..607b5a0fbf1 100644 --- a/tests/unit_tests/db_engine_specs/test_aws_iam.py 
+++ b/tests/unit_tests/db_engine_specs/test_aws_iam.py @@ -89,7 +89,14 @@ def test_get_iam_credentials_with_external_id() -> None: def test_get_iam_credentials_access_denied() -> None: from botocore.exceptions import ClientError - from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + with _credentials_lock: + _credentials_cache.clear() with patch("boto3.client") as mock_boto3_client: mock_sts = MagicMock() @@ -361,7 +368,14 @@ def test_apply_iam_authentication_default_port() -> None: def test_get_iam_credentials_boto3_not_installed() -> None: import sys - from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + with _credentials_lock: + _credentials_cache.clear() # Temporarily hide boto3 boto3_module = sys.modules.get("boto3") @@ -381,3 +395,363 @@ def test_get_iam_credentials_boto3_not_installed() -> None: sys.modules["boto3"] = boto3_module else: del sys.modules["boto3"] + + +def test_get_iam_credentials_caching() -> None: + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + mock_credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + # Clear cache before test + with _credentials_lock: + _credentials_cache.clear() + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.return_value = {"Credentials": mock_credentials} + mock_boto3_client.return_value = mock_sts + + # First call should hit STS + result1 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/CachedRole", + region="us-east-1", + ) + + # Second call should use cache + result2 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::123456789012:role/CachedRole", + region="us-east-1", + ) + + assert result1 == mock_credentials + assert result2 == mock_credentials + # STS should only be called once + mock_sts.assume_role.assert_called_once() + + # Clean up + with _credentials_lock: + _credentials_cache.clear() + + +def test_get_iam_credentials_cache_different_keys() -> None: + from superset.db_engine_specs.aws_iam import ( + _credentials_cache, + _credentials_lock, + AWSIAMAuthMixin, + ) + + creds_role1 = { + "AccessKeyId": "ASIA_ROLE1", + "SecretAccessKey": "secret1", + "SessionToken": "token1", + } + creds_role2 = { + "AccessKeyId": "ASIA_ROLE2", + "SecretAccessKey": "secret2", + "SessionToken": "token2", + } + + # Clear cache before test + with _credentials_lock: + _credentials_cache.clear() + + with patch("boto3.client") as mock_boto3_client: + mock_sts = MagicMock() + mock_sts.assume_role.side_effect = [ + {"Credentials": creds_role1}, + {"Credentials": creds_role2}, + ] + mock_boto3_client.return_value = mock_sts + + result1 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::111111111111:role/Role1", + region="us-east-1", + ) + result2 = AWSIAMAuthMixin.get_iam_credentials( + role_arn="arn:aws:iam::222222222222:role/Role2", + region="us-east-1", + ) + + assert result1 == creds_role1 + assert result2 == creds_role2 + # Both calls should hit STS (different cache keys) + assert mock_sts.assume_role.call_count == 2 + + # Clean up + with _credentials_lock: + _credentials_cache.clear() + + +def test_apply_iam_authentication_custom_ssl_args() -> None: + from 
superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "mysql://[email protected]:3306/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ), + ): + AWSIAMAuthMixin._apply_iam_authentication( + mock_database, + params, + iam_config, + ssl_args={"ssl_mode": "REQUIRED"}, + default_port=3306, + ) + + assert params["connect_args"]["ssl_mode"] == "REQUIRED" + assert "sslmode" not in params["connect_args"] + + +def test_apply_iam_authentication_custom_default_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + # URI without explicit port + mock_database.sqlalchemy_uri_decrypted = ( + "mysql://[email protected]/mydb" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + AWSIAMAuthMixin._apply_iam_authentication( + mock_database, + params, + iam_config, + default_port=3306, + ) + + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 3306 + + +def test_generate_redshift_credentials() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_credentials.return_value = { + "dbUser": "IAM:admin", + "dbPassword": "redshift-temp-password", + } + mock_boto3_client.return_value = mock_redshift + + db_user, db_password = AWSIAMAuthMixin.generate_redshift_credentials( + credentials=credentials, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert db_user == "IAM:admin" + assert db_password == "redshift-temp-password" # noqa: S105 + mock_boto3_client.assert_called_once_with( + "redshift-serverless", + region_name="us-east-1", + aws_access_key_id="ASIA...", + aws_secret_access_key="secret...", # noqa: S106 + aws_session_token="token...", # noqa: S106 + ) + mock_redshift.get_credentials.assert_called_once_with( + workgroupName="my-workgroup", + dbName="dev", + ) + + +def test_generate_redshift_credentials_client_error() -> None: + from botocore.exceptions import ClientError + + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + + credentials = { + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + } + + with patch("boto3.client") as mock_boto3_client: + mock_redshift = MagicMock() + mock_redshift.get_credentials.side_effect = ClientError( + {"Error": {"Code": "AccessDenied", "Message": 
"Access Denied"}}, + "GetCredentials", + ) + mock_boto3_client.return_value = mock_redshift + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin.generate_redshift_credentials( + credentials=credentials, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert "Failed to get Redshift credentials" in str(exc_info.value) + + +def test_apply_redshift_iam_authentication() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://[email protected]" + ".redshift-serverless.amazonaws.com:5439/dev" + ) + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ) as mock_gen_creds, + ): + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::123456789012:role/RedshiftRole", + region="us-east-1", + external_id=None, + session_duration=3600, + ) + + mock_gen_creds.assert_called_once_with( + credentials={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + workgroup_name="my-workgroup", + db_name="dev", + region="us-east-1", + ) + + assert params["connect_args"]["password"] == "redshift-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:admin" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +def test_apply_redshift_iam_authentication_missing_workgroup() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "db_name": "dev", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "workgroup_name" in str(exc_info.value) + + +def test_apply_redshift_iam_authentication_missing_db_name() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin, AWSIAMConfig + + mock_database = MagicMock() + mock_database.sqlalchemy_uri_decrypted = "redshift+psycopg2://user@host:5439/dev" + + iam_config: AWSIAMConfig = { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + } + + params: dict[str, Any] = {} + + with pytest.raises(SupersetSecurityException) as exc_info: + AWSIAMAuthMixin._apply_redshift_iam_authentication( + mock_database, params, iam_config + ) + + assert "db_name" in str(exc_info.value) diff --git a/tests/unit_tests/db_engine_specs/test_aurora.py b/tests/unit_tests/db_engine_specs/test_mysql_iam.py similarity index 51% copy from tests/unit_tests/db_engine_specs/test_aurora.py copy to 
tests/unit_tests/db_engine_specs/test_mysql_iam.py index a225f3641e4..b6231dee3f1 100644 --- a/tests/unit_tests/db_engine_specs/test_aurora.py +++ b/tests/unit_tests/db_engine_specs/test_mysql_iam.py @@ -27,32 +27,39 @@ import pytest from superset.utils import json -def test_aurora_postgres_engine_spec_properties() -> None: - from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec +def test_mysql_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec - assert AuroraPostgresEngineSpec.engine == "postgresql" - assert AuroraPostgresEngineSpec.engine_name == "Aurora PostgreSQL" - assert AuroraPostgresEngineSpec.default_driver == "psycopg2" + assert "$.aws_iam.external_id" in MySQLEngineSpec.encrypted_extra_sensitive_fields + assert "$.aws_iam.role_arn" in MySQLEngineSpec.encrypted_extra_sensitive_fields -def test_update_params_from_encrypted_extra_without_iam() -> None: - from superset.db_engine_specs.postgres import PostgresEngineSpec +def test_mysql_update_params_no_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_mysql_update_params_empty_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec database = MagicMock() database.encrypted_extra = json.dumps({}) - database.sqlalchemy_uri_decrypted = ( - "postgresql://user:[email protected]:5432/mydb" - ) params: dict[str, Any] = {} - PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) - # No modifications should be made assert params == {} -def test_update_params_from_encrypted_extra_iam_disabled() -> None: - from superset.db_engine_specs.postgres import PostgresEngineSpec +def test_mysql_update_params_iam_disabled() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec database = MagicMock() database.encrypted_extra = json.dumps( @@ -65,20 +72,16 @@ def test_update_params_from_encrypted_extra_iam_disabled() -> None: } } ) - database.sqlalchemy_uri_decrypted = ( - "postgresql://user:[email protected]:5432/mydb" - ) params: dict[str, Any] = {} - PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) - # No modifications should be made when IAM is disabled assert params == {} -def test_update_params_from_encrypted_extra_with_iam() -> None: +def test_mysql_update_params_with_iam() -> None: from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin - from superset.db_engine_specs.postgres import PostgresEngineSpec + from superset.db_engine_specs.mysql import MySQLEngineSpec database = MagicMock() database.encrypted_extra = json.dumps( @@ -92,7 +95,7 @@ def test_update_params_from_encrypted_extra_with_iam() -> None: } ) database.sqlalchemy_uri_decrypted = ( - "postgresql://[email protected]:5432/mydb" + "mysql://[email protected]:3306/mydb" ) params: dict[str, Any] = {} @@ -113,51 +116,79 @@ def test_update_params_from_encrypted_extra_with_iam() -> None: return_value="iam-auth-token", ), ): - PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) assert "connect_args" in params assert params["connect_args"]["password"] == "iam-auth-token" # 
noqa: S105 assert params["connect_args"]["user"] == "superset_iam_user" - assert params["connect_args"]["sslmode"] == "require" + assert params["connect_args"]["ssl_mode"] == "REQUIRED" -def test_update_params_merges_remaining_encrypted_extra() -> None: - from superset.db_engine_specs.postgres import PostgresEngineSpec +def test_mysql_update_params_iam_uses_mysql_port() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.mysql import MySQLEngineSpec database = MagicMock() database.encrypted_extra = json.dumps( { - "aws_iam": {"enabled": False}, - "pool_size": 10, + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "db_username": "superset_iam_user", + } } ) + # URI without explicit port database.sqlalchemy_uri_decrypted = ( - "postgresql://user:[email protected]:5432/mydb" + "mysql://[email protected]/mydb" ) params: dict[str, Any] = {} - PostgresEngineSpec.update_params_from_encrypted_extra(database, params) - # aws_iam should be consumed, pool_size should be merged - assert "aws_iam" not in params - assert params["pool_size"] == 10 + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_rds_auth_token", + return_value="iam-auth-token", + ) as mock_gen_token, + ): + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) + + # Should use default MySQL port 3306 + token_call_kwargs = mock_gen_token.call_args[1] + assert token_call_kwargs["port"] == 3306 -def test_update_params_from_encrypted_extra_no_encrypted_extra() -> None: - from superset.db_engine_specs.postgres import PostgresEngineSpec +def test_mysql_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec database = MagicMock() - database.encrypted_extra = None + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 10, + } + ) params: dict[str, Any] = {} - PostgresEngineSpec.update_params_from_encrypted_extra(database, params) + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) - # No modifications should be made - assert params == {} + assert "aws_iam" not in params + assert params["pool_size"] == 10 -def test_update_params_from_encrypted_extra_invalid_json() -> None: - from superset.db_engine_specs.postgres import PostgresEngineSpec +def test_mysql_update_params_invalid_json() -> None: + from superset.db_engine_specs.mysql import MySQLEngineSpec database = MagicMock() database.encrypted_extra = "not-valid-json" @@ -165,21 +196,11 @@ def test_update_params_from_encrypted_extra_invalid_json() -> None: params: dict[str, Any] = {} with pytest.raises(json.JSONDecodeError): - PostgresEngineSpec.update_params_from_encrypted_extra(database, params) - - -def test_encrypted_extra_sensitive_fields() -> None: - from superset.db_engine_specs.postgres import PostgresEngineSpec + MySQLEngineSpec.update_params_from_encrypted_extra(database, params) - # Verify sensitive fields are properly defined - assert ( - "$.aws_iam.external_id" in PostgresEngineSpec.encrypted_extra_sensitive_fields - ) - assert "$.aws_iam.role_arn" in PostgresEngineSpec.encrypted_extra_sensitive_fields - -def test_mask_encrypted_extra() -> None: - from superset.db_engine_specs.postgres import PostgresEngineSpec +def test_mysql_mask_encrypted_extra() -> 
None: + from superset.db_engine_specs.mysql import MySQLEngineSpec encrypted_extra = json.dumps( { @@ -193,7 +214,7 @@ def test_mask_encrypted_extra() -> None: } ) - masked = PostgresEngineSpec.mask_encrypted_extra(encrypted_extra) + masked = MySQLEngineSpec.mask_encrypted_extra(encrypted_extra) assert masked is not None masked_config = json.loads(masked) @@ -209,31 +230,3 @@ def test_mask_encrypted_extra() -> None: assert masked_config["aws_iam"]["enabled"] is True assert masked_config["aws_iam"]["region"] == "us-east-1" assert masked_config["aws_iam"]["db_username"] == "superset_user" - - -def test_aurora_postgres_inherits_from_postgres() -> None: - from superset.db_engine_specs.aurora import AuroraPostgresEngineSpec - from superset.db_engine_specs.postgres import PostgresEngineSpec - - # Verify inheritance - assert issubclass(AuroraPostgresEngineSpec, PostgresEngineSpec) - - # Verify it inherits PostgreSQL capabilities - assert AuroraPostgresEngineSpec.supports_dynamic_schema is True - assert AuroraPostgresEngineSpec.supports_catalog is True - - -def test_aurora_data_api_classes_unchanged() -> None: - from superset.db_engine_specs.aurora import ( - AuroraMySQLDataAPI, - AuroraPostgresDataAPI, - ) - - # Verify Data API classes are still available and unchanged - assert AuroraMySQLDataAPI.engine == "mysql" - assert AuroraMySQLDataAPI.default_driver == "auroradataapi" - assert AuroraMySQLDataAPI.engine_name == "Aurora MySQL (Data API)" - - assert AuroraPostgresDataAPI.engine == "postgresql" - assert AuroraPostgresDataAPI.default_driver == "auroradataapi" - assert AuroraPostgresDataAPI.engine_name == "Aurora PostgreSQL (Data API)" diff --git a/tests/unit_tests/db_engine_specs/test_redshift_iam.py b/tests/unit_tests/db_engine_specs/test_redshift_iam.py new file mode 100644 index 00000000000..3f0da7e734e --- /dev/null +++ b/tests/unit_tests/db_engine_specs/test_redshift_iam.py @@ -0,0 +1,245 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +# pylint: disable=import-outside-toplevel + +from __future__ import annotations + +from typing import Any +from unittest.mock import MagicMock, patch + +import pytest + +from superset.utils import json + + +def test_redshift_encrypted_extra_sensitive_fields() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + assert ( + "$.aws_iam.external_id" in RedshiftEngineSpec.encrypted_extra_sensitive_fields + ) + assert "$.aws_iam.role_arn" in RedshiftEngineSpec.encrypted_extra_sensitive_fields + + +def test_redshift_update_params_no_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = None + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_redshift_update_params_empty_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps({}) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_redshift_update_params_iam_disabled() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": False, + "role_arn": "arn:aws:iam::123456789012:role/TestRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert params == {} + + +def test_redshift_update_params_with_iam() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/RedshiftRole", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + database.sqlalchemy_uri_decrypted = ( + "redshift+psycopg2://[email protected]" + ".redshift-serverless.amazonaws.com:5439/dev" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ), + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "connect_args" in params + assert params["connect_args"]["password"] == "redshift-temp-password" # noqa: S105 + assert params["connect_args"]["user"] == "IAM:admin" + assert params["connect_args"]["sslmode"] == "verify-ca" + + +def test_redshift_update_params_with_external_id() -> None: + from superset.db_engine_specs.aws_iam import AWSIAMAuthMixin + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::222222222222:role/CrossAccountRedshift", + "external_id": "superset-prod-12345", + "region": "us-west-2", + "workgroup_name": "prod-workgroup", + "db_name": "analytics", + "session_duration": 1800, + } + } + ) + database.sqlalchemy_uri_decrypted = ( + 
"redshift+psycopg2://[email protected]" + ".redshift-serverless.amazonaws.com:5439/analytics" + ) + + params: dict[str, Any] = {} + + with ( + patch.object( + AWSIAMAuthMixin, + "get_iam_credentials", + return_value={ + "AccessKeyId": "ASIA...", + "SecretAccessKey": "secret...", + "SessionToken": "token...", + }, + ) as mock_get_creds, + patch.object( + AWSIAMAuthMixin, + "generate_redshift_credentials", + return_value=("IAM:admin", "redshift-temp-password"), + ), + ): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + mock_get_creds.assert_called_once_with( + role_arn="arn:aws:iam::222222222222:role/CrossAccountRedshift", + region="us-west-2", + external_id="superset-prod-12345", + session_duration=1800, + ) + + +def test_redshift_update_params_merges_remaining_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = json.dumps( + { + "aws_iam": {"enabled": False}, + "pool_size": 5, + } + ) + + params: dict[str, Any] = {} + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + assert "aws_iam" not in params + assert params["pool_size"] == 5 + + +def test_redshift_update_params_invalid_json() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + database = MagicMock() + database.encrypted_extra = "not-valid-json" + + params: dict[str, Any] = {} + + with pytest.raises(json.JSONDecodeError): + RedshiftEngineSpec.update_params_from_encrypted_extra(database, params) + + +def test_redshift_mask_encrypted_extra() -> None: + from superset.db_engine_specs.redshift import RedshiftEngineSpec + + encrypted_extra = json.dumps( + { + "aws_iam": { + "enabled": True, + "role_arn": "arn:aws:iam::123456789012:role/SecretRole", + "external_id": "secret-external-id-12345", + "region": "us-east-1", + "workgroup_name": "my-workgroup", + "db_name": "dev", + } + } + ) + + masked = RedshiftEngineSpec.mask_encrypted_extra(encrypted_extra) + assert masked is not None + + masked_config = json.loads(masked) + + # role_arn and external_id should be masked + assert ( + masked_config["aws_iam"]["role_arn"] + != "arn:aws:iam::123456789012:role/SecretRole" + ) + assert masked_config["aws_iam"]["external_id"] != "secret-external-id-12345" + + # Non-sensitive fields should remain unchanged + assert masked_config["aws_iam"]["enabled"] is True + assert masked_config["aws_iam"]["region"] == "us-east-1" + assert masked_config["aws_iam"]["workgroup_name"] == "my-workgroup" + assert masked_config["aws_iam"]["db_name"] == "dev"
