This is an automated email from the ASF dual-hosted git repository. vavila pushed a commit to branch fix/oauth-fixes in repository https://gitbox.apache.org/repos/asf/superset.git
commit 8e776d0833858c9ddc26845e87f0bae60454662f Author: Vitor Avila <[email protected]> AuthorDate: Wed Jan 21 23:43:57 2026 -0300 fix: DB OAuth2 fixes --- superset/db_engine_specs/base.py | 20 +++++++++---- superset/models/core.py | 4 +++ superset/utils/oauth2.py | 5 +++- tests/unit_tests/models/core_test.py | 56 ++++++++++++++++++++++++++++++++++++ 4 files changed, 78 insertions(+), 7 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 4d72be41d8a..4113a3e8fe5 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -717,9 +717,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "redirect_uri": config["redirect_uri"], "grant_type": "authorization_code", } - if config["request_content_type"] == "data": - return requests.post(uri, data=req_body, timeout=timeout).json() - return requests.post(uri, json=req_body, timeout=timeout).json() + response = ( + requests.post(uri, data=req_body, timeout=timeout) + if config["request_content_type"] == "data" + else requests.post(uri, json=req_body, timeout=timeout) + ) + response.raise_for_status() + return response.json() @classmethod def get_oauth2_fresh_token( @@ -738,9 +742,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods "refresh_token": refresh_token, "grant_type": "refresh_token", } - if config["request_content_type"] == "data": - return requests.post(uri, data=req_body, timeout=timeout).json() - return requests.post(uri, json=req_body, timeout=timeout).json() + response = ( + requests.post(uri, data=req_body, timeout=timeout) + if config["request_content_type"] == "data" + else requests.post(uri, json=req_body, timeout=timeout) + ) + response.raise_for_status() + return response.json() @classmethod def get_allows_alias_in_select( diff --git a/superset/models/core.py b/superset/models/core.py index cb7bdf2d352..fd813ebe0c5 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -1261,6 +1261,10 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: if oauth2_client_info := encrypted_extra.get("oauth2_client_info"): schema = OAuth2ClientConfigSchema() client_config = schema.load(oauth2_client_info) + if "request_content_type" not in oauth2_client_info: + client_config["request_content_type"] = ( + self.db_engine_spec.oauth2_token_request_type + ) return cast(OAuth2ClientConfig, client_config) return self.db_engine_spec.get_oauth2_config() diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index ebe1f4012eb..cd1a2a14d9e 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -189,7 +189,10 @@ class OAuth2ClientConfigSchema(Schema): scope = fields.String(required=True) redirect_uri = fields.String( required=False, - load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True), + load_default=lambda: app.config.get( + "DATABASE_OAUTH2_REDIRECT_URI", + url_for("DatabaseRestApi.oauth2", _external=True), + ), ) authorization_request_uri = fields.String(required=True) token_request_uri = fields.String(required=True) diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py index 998a1033bb0..0e0f3d8ae2f 100644 --- a/tests/unit_tests/models/core_test.py +++ b/tests/unit_tests/models/core_test.py @@ -661,6 +661,62 @@ def test_get_oauth2_config(app_context: None) -> None: assert database.get_oauth2_config() is None database.encrypted_extra = json.dumps(oauth2_client_info) + assert database.get_oauth2_config() == { + "id": "my_client_id", + "secret": "my_client_secret", + "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize", + "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", + "scope": "refresh_token session:role:USERADMIN", + "redirect_uri": "http://example.com/api/v1/database/oauth2/", + "request_content_type": "data", # Default value from BaseEngineSpec + } + + +def test_get_oauth2_config_token_request_type_from_db_engine_specs( + mocker: MockerFixture, app_context: None +) -> None: + """ + Test that DB Engine Spec overrides for ``oauth2_token_request_type`` are respected. + """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + mocker.patch.object( + database.db_engine_spec, + "oauth2_token_request_type", + "json", + ) + + database.encrypted_extra = json.dumps(oauth2_client_info) + assert database.get_oauth2_config() == { + "id": "my_client_id", + "secret": "my_client_secret", + "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize", + "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request", + "scope": "refresh_token session:role:USERADMIN", + "redirect_uri": "http://example.com/api/v1/database/oauth2/", + "request_content_type": "json", + } + + +def test_get_oauth2_config_custom_token_request_type_extra(app_context: None) -> None: + """ + Test passing a custom ``token_request_type`` via ``encrypted_extra`` + takes precedence. + """ + database = Database( + database_name="db", + sqlalchemy_uri="postgresql://user:password@host:5432/examples", + ) + custom_oauth2_client_info = { + "oauth2_client_info": { + **oauth2_client_info["oauth2_client_info"], + "request_content_type": "json", + } + } + + database.encrypted_extra = json.dumps(custom_oauth2_client_info) assert database.get_oauth2_config() == { "id": "my_client_id", "secret": "my_client_secret",
