This is an automated email from the ASF dual-hosted git repository.

vavila pushed a commit to branch fix/oauth-fixes
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 8e776d0833858c9ddc26845e87f0bae60454662f
Author: Vitor Avila <[email protected]>
AuthorDate: Wed Jan 21 23:43:57 2026 -0300

    fix: DB OAuth2 fixes
---
 superset/db_engine_specs/base.py     | 20 +++++++++----
 superset/models/core.py              |  4 +++
 superset/utils/oauth2.py             |  5 +++-
 tests/unit_tests/models/core_test.py | 56 ++++++++++++++++++++++++++++++++++++
 4 files changed, 78 insertions(+), 7 deletions(-)

diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 4d72be41d8a..4113a3e8fe5 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -717,9 +717,13 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             "redirect_uri": config["redirect_uri"],
             "grant_type": "authorization_code",
         }
-        if config["request_content_type"] == "data":
-            return requests.post(uri, data=req_body, timeout=timeout).json()
-        return requests.post(uri, json=req_body, timeout=timeout).json()
+        response = (
+            requests.post(uri, data=req_body, timeout=timeout)
+            if config["request_content_type"] == "data"
+            else requests.post(uri, json=req_body, timeout=timeout)
+        )
+        response.raise_for_status()
+        return response.json()
 
     @classmethod
     def get_oauth2_fresh_token(
@@ -738,9 +742,13 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             "refresh_token": refresh_token,
             "grant_type": "refresh_token",
         }
-        if config["request_content_type"] == "data":
-            return requests.post(uri, data=req_body, timeout=timeout).json()
-        return requests.post(uri, json=req_body, timeout=timeout).json()
+        response = (
+            requests.post(uri, data=req_body, timeout=timeout)
+            if config["request_content_type"] == "data"
+            else requests.post(uri, json=req_body, timeout=timeout)
+        )
+        response.raise_for_status()
+        return response.json()
 
     @classmethod
     def get_allows_alias_in_select(
diff --git a/superset/models/core.py b/superset/models/core.py
index cb7bdf2d352..fd813ebe0c5 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -1261,6 +1261,10 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # pylint:
         if oauth2_client_info := encrypted_extra.get("oauth2_client_info"):
             schema = OAuth2ClientConfigSchema()
             client_config = schema.load(oauth2_client_info)
+            if "request_content_type" not in oauth2_client_info:
+                client_config["request_content_type"] = (
+                    self.db_engine_spec.oauth2_token_request_type
+                )
             return cast(OAuth2ClientConfig, client_config)
 
         return self.db_engine_spec.get_oauth2_config()
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index ebe1f4012eb..cd1a2a14d9e 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -189,7 +189,10 @@ class OAuth2ClientConfigSchema(Schema):
     scope = fields.String(required=True)
     redirect_uri = fields.String(
         required=False,
-        load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
+        load_default=lambda: app.config.get(
+            "DATABASE_OAUTH2_REDIRECT_URI",
+            url_for("DatabaseRestApi.oauth2", _external=True),
+        ),
     )
     authorization_request_uri = fields.String(required=True)
     token_request_uri = fields.String(required=True)
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 998a1033bb0..0e0f3d8ae2f 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -661,6 +661,62 @@ def test_get_oauth2_config(app_context: None) -> None:
     assert database.get_oauth2_config() is None
 
     database.encrypted_extra = json.dumps(oauth2_client_info)
+    assert database.get_oauth2_config() == {
+        "id": "my_client_id",
+        "secret": "my_client_secret",
+        "authorization_request_uri": 
"https://abcd1234.snowflakecomputing.com/oauth/authorize";,
+        "token_request_uri": 
"https://abcd1234.snowflakecomputing.com/oauth/token-request";,
+        "scope": "refresh_token session:role:USERADMIN",
+        "redirect_uri": "http://example.com/api/v1/database/oauth2/";,
+        "request_content_type": "data",  # Default value from BaseEngineSpec
+    }
+
+
+def test_get_oauth2_config_token_request_type_from_db_engine_specs(
+    mocker: MockerFixture, app_context: None
+) -> None:
+    """
+    Test that DB Engine Spec overrides for ``oauth2_token_request_type`` are respected.
+    """
+    database = Database(
+        database_name="db",
+        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+    )
+    mocker.patch.object(
+        database.db_engine_spec,
+        "oauth2_token_request_type",
+        "json",
+    )
+
+    database.encrypted_extra = json.dumps(oauth2_client_info)
+    assert database.get_oauth2_config() == {
+        "id": "my_client_id",
+        "secret": "my_client_secret",
+        "authorization_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/authorize",
+        "token_request_uri": "https://abcd1234.snowflakecomputing.com/oauth/token-request",
+        "scope": "refresh_token session:role:USERADMIN",
+        "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+        "request_content_type": "json",
+    }
+
+
+def test_get_oauth2_config_custom_token_request_type_extra(app_context: None) 
-> None:
+    """
+    Test passing a custom ``token_request_type`` via ``encrypted_extra``
+    takes precedence.
+    """
+    database = Database(
+        database_name="db",
+        sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+    )
+    custom_oauth2_client_info = {
+        "oauth2_client_info": {
+            **oauth2_client_info["oauth2_client_info"],
+            "request_content_type": "json",
+        }
+    }
+
+    database.encrypted_extra = json.dumps(custom_oauth2_client_info)
     assert database.get_oauth2_config() == {
         "id": "my_client_id",
         "secret": "my_client_secret",

Reply via email to