This is an automated email from the ASF dual-hosted git repository.
vavila pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new cc972cad5ac fix: DB OAuth2 fixes (#37350)
cc972cad5ac is described below
commit cc972cad5ac19328d77662a4e377a10d5a8b2eac
Author: Vitor Avila <[email protected]>
AuthorDate: Thu Jan 22 01:51:48 2026 -0300
fix: DB OAuth2 fixes (#37350)
---
superset/db_engine_specs/base.py | 20 ++++++---
superset/models/core.py | 34 ++++++++-------
superset/utils/oauth2.py | 5 ++-
tests/unit_tests/models/core_test.py | 81 ++++++++++++++++++++++++++++++++++++
4 files changed, 118 insertions(+), 22 deletions(-)
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 4d72be41d8a..4113a3e8fe5 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -717,9 +717,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
- if config["request_content_type"] == "data":
- return requests.post(uri, data=req_body, timeout=timeout).json()
- return requests.post(uri, json=req_body, timeout=timeout).json()
+ response = (
+ requests.post(uri, data=req_body, timeout=timeout)
+ if config["request_content_type"] == "data"
+ else requests.post(uri, json=req_body, timeout=timeout)
+ )
+ response.raise_for_status()
+ return response.json()
@classmethod
def get_oauth2_fresh_token(
@@ -738,9 +742,13 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods
"refresh_token": refresh_token,
"grant_type": "refresh_token",
}
- if config["request_content_type"] == "data":
- return requests.post(uri, data=req_body, timeout=timeout).json()
- return requests.post(uri, json=req_body, timeout=timeout).json()
+ response = (
+ requests.post(uri, data=req_body, timeout=timeout)
+ if config["request_content_type"] == "data"
+ else requests.post(uri, json=req_body, timeout=timeout)
+ )
+ response.raise_for_status()
+ return response.json()
@classmethod
def get_allows_alias_in_select(
diff --git a/superset/models/core.py b/superset/models/core.py
index cb7bdf2d352..d13c14b65ab 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -896,9 +896,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
)
}
except Exception as ex:
- if self.is_oauth2_enabled() and
self.db_engine_spec.needs_oauth2(ex):
- self.start_oauth2_dance()
-
+ self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@cache_util.memoized_func(
@@ -933,9 +931,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
)
}
except Exception as ex:
- if self.is_oauth2_enabled() and
self.db_engine_spec.needs_oauth2(ex):
- self.start_oauth2_dance()
-
+ self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@cache_util.memoized_func(
@@ -972,9 +968,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
)
}
except Exception as ex:
- if self.is_oauth2_enabled() and
self.db_engine_spec.needs_oauth2(ex):
- self.start_oauth2_dance()
-
+ self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
return set()
@@ -1003,9 +997,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
with self.get_inspector(catalog=catalog) as inspector:
return self.db_engine_spec.get_schema_names(inspector)
except Exception as ex:
- if self.is_oauth2_enabled() and
self.db_engine_spec.needs_oauth2(ex):
- self.start_oauth2_dance()
-
+ self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@cache_util.memoized_func(
@@ -1022,9 +1014,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
with self.get_inspector() as inspector:
return self.db_engine_spec.get_catalog_names(self, inspector)
except Exception as ex:
- if self.is_oauth2_enabled() and
self.db_engine_spec.needs_oauth2(ex):
- self.start_oauth2_dance()
-
+ self._handle_oauth2_error(ex)
raise self.db_engine_spec.get_dbapi_mapped_exception(ex) from ex
@property
@@ -1261,6 +1251,10 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
if oauth2_client_info := encrypted_extra.get("oauth2_client_info"):
schema = OAuth2ClientConfigSchema()
client_config = schema.load(oauth2_client_info)
+ if "request_content_type" not in oauth2_client_info:
+ client_config["request_content_type"] = (
+ self.db_engine_spec.oauth2_token_request_type
+ )
return cast(OAuth2ClientConfig, client_config)
return self.db_engine_spec.get_oauth2_config()
@@ -1275,6 +1269,16 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint:
"""
return self.db_engine_spec.start_oauth2_dance(self)
+ def _handle_oauth2_error(self, ex: Exception) -> None:
+ """
+ Handle exceptions that may require OAuth2 authentication.
+
+ If OAuth2 is enabled and the exception indicates that OAuth2 is needed,
+ starts the OAuth2 dance.
+ """
+ if self.is_oauth2_enabled() and self.db_engine_spec.needs_oauth2(ex):
+ self.start_oauth2_dance()
+
def purge_oauth2_tokens(self) -> None:
"""
Delete all OAuth2 tokens associated with this database.
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index ebe1f4012eb..cd1a2a14d9e 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -189,7 +189,10 @@ class OAuth2ClientConfigSchema(Schema):
scope = fields.String(required=True)
redirect_uri = fields.String(
required=False,
- load_default=lambda: url_for("DatabaseRestApi.oauth2", _external=True),
+ load_default=lambda: app.config.get(
+ "DATABASE_OAUTH2_REDIRECT_URI",
+ url_for("DatabaseRestApi.oauth2", _external=True),
+ ),
)
authorization_request_uri = fields.String(required=True)
token_request_uri = fields.String(required=True)
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 998a1033bb0..7d7aa96ea19 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -660,6 +660,34 @@ def test_get_oauth2_config(app_context: None) -> None:
assert database.get_oauth2_config() is None
+ database.encrypted_extra = json.dumps(oauth2_client_info)
+ assert database.get_oauth2_config() == {
+ "id": "my_client_id",
+ "secret": "my_client_secret",
+ "authorization_request_uri":
"https://abcd1234.snowflakecomputing.com/oauth/authorize",
+ "token_request_uri":
"https://abcd1234.snowflakecomputing.com/oauth/token-request",
+ "scope": "refresh_token session:role:USERADMIN",
+ "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+ "request_content_type": "data", # Default value from BaseEngineSpec
+ }
+
+
+def test_get_oauth2_config_token_request_type_from_db_engine_specs(
+ mocker: MockerFixture, app_context: None
+) -> None:
+ """
+ Test that DB Engine Spec overrides for ``oauth2_token_request_type`` are
respected.
+ """
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+ )
+ mocker.patch.object(
+ database.db_engine_spec,
+ "oauth2_token_request_type",
+ "json",
+ )
+
database.encrypted_extra = json.dumps(oauth2_client_info)
assert database.get_oauth2_config() == {
"id": "my_client_id",
@@ -672,6 +700,59 @@ def test_get_oauth2_config(app_context: None) -> None:
}
+def test_get_oauth2_config_custom_token_request_type_extra(app_context: None)
-> None:
+ """
+ Test passing a custom ``token_request_type`` via ``encrypted_extra``
+ takes precedence.
+ """
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+ )
+ custom_oauth2_client_info = {
+ "oauth2_client_info": {
+ **oauth2_client_info["oauth2_client_info"],
+ "request_content_type": "json",
+ }
+ }
+
+ database.encrypted_extra = json.dumps(custom_oauth2_client_info)
+ assert database.get_oauth2_config() == {
+ "id": "my_client_id",
+ "secret": "my_client_secret",
+ "authorization_request_uri":
"https://abcd1234.snowflakecomputing.com/oauth/authorize",
+ "token_request_uri":
"https://abcd1234.snowflakecomputing.com/oauth/token-request",
+ "scope": "refresh_token session:role:USERADMIN",
+ "redirect_uri": "http://example.com/api/v1/database/oauth2/",
+ "request_content_type": "json",
+ }
+
+
+def test_get_oauth2_config_redirect_uri_from_config(
+ mocker: MockerFixture,
+ app_context: None,
+) -> None:
+ """
+ Test that ``DATABASE_OAUTH2_REDIRECT_URI`` config takes precedence over
+ url_for default.
+ """
+ custom_redirect_uri = "https://custom.example.com/oauth/callback"
+ mocker.patch.dict(
+ "superset.utils.oauth2.app.config",
+ {"DATABASE_OAUTH2_REDIRECT_URI": custom_redirect_uri},
+ )
+ database = Database(
+ database_name="db",
+ sqlalchemy_uri="postgresql://user:password@host:5432/examples",
+ )
+ database.encrypted_extra = json.dumps(oauth2_client_info)
+
+ config = database.get_oauth2_config()
+
+ assert config is not None
+ assert config["redirect_uri"] == custom_redirect_uri
+
+
def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
"""
Test that we can start OAuth2 from `raw_connection()` errors.