This is an automated email from the ASF dual-hosted git repository. vavila pushed a commit to branch fix/oauth-fixes in repository https://gitbox.apache.org/repos/asf/superset.git
commit 0878883249cb74597ea3923eb7f162bde0dc4719 Author: Vitor Avila <[email protected]> AuthorDate: Fri Jan 23 13:13:13 2026 -0300 fix: more DB OAuth2 fixes --- superset/db_engine_specs/base.py | 5 +- superset/db_engine_specs/gsheets.py | 88 +++++- superset/utils/oauth2.py | 21 +- tests/unit_tests/db_engine_specs/test_base.py | 157 +++++----- tests/unit_tests/db_engine_specs/test_gsheets.py | 351 +++++++++++++++++++---- tests/unit_tests/utils/oauth2_tests.py | 48 +++- 6 files changed, 545 insertions(+), 125 deletions(-) diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py index 4113a3e8fe..c82a456cec 100644 --- a/superset/db_engine_specs/base.py +++ b/superset/db_engine_specs/base.py @@ -610,7 +610,10 @@ class BaseEngineSpec: # pylint: disable=too-many-public-methods re-run the query after authorization. """ tab_id = str(uuid4()) - default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True) + default_redirect_uri = app.config.get( + "DATABASE_OAUTH2_REDIRECT_URI", + url_for("DatabaseRestApi.oauth2", _external=True), + ) # The state is passed to the OAuth2 provider, and sent back to Superset after # the user authorizes the access. 
The redirect endpoint in Superset can then diff --git a/superset/db_engine_specs/gsheets.py b/superset/db_engine_specs/gsheets.py index a36e95b92d..fee64d75f5 100644 --- a/superset/db_engine_specs/gsheets.py +++ b/superset/db_engine_specs/gsheets.py @@ -30,9 +30,11 @@ from flask_babel import gettext as __ from marshmallow import fields, Schema from marshmallow.exceptions import ValidationError from requests import Session +from requests.exceptions import HTTPError from shillelagh.adapters.api.gsheets.lib import SCOPES from shillelagh.exceptions import UnauthenticatedError from sqlalchemy.engine import create_engine +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from superset import db, security_manager @@ -41,7 +43,9 @@ from superset.db_engine_specs.base import DatabaseCategory from superset.db_engine_specs.shillelagh import ShillelaghEngineSpec from superset.errors import ErrorLevel, SupersetError, SupersetErrorType from superset.exceptions import SupersetException +from superset.superset_typing import OAuth2TokenResponse from superset.utils import json +from superset.utils.oauth2 import get_oauth2_access_token if TYPE_CHECKING: from superset.models.core import Database @@ -83,14 +87,16 @@ class GSheetsParametersSchema(Schema): ) -class GSheetsParametersType(TypedDict): +class GSheetsParametersType(TypedDict, total=False): service_account_info: str catalog: dict[str, str] | None + oauth2_client_info: dict[str, str] | None -class GSheetsPropertiesType(TypedDict): +class GSheetsPropertiesType(TypedDict, total=False): parameters: GSheetsParametersType catalog: dict[str, str] + masked_encrypted_extra: str class GSheetsEngineSpec(ShillelaghEngineSpec): @@ -123,7 +129,10 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): # when editing the database, mask this field in `encrypted_extra` # pylint: disable=invalid-name - encrypted_extra_sensitive_fields = {"$.service_account_info.private_key"} + encrypted_extra_sensitive_fields = { 
+ "$.service_account_info.private_key", + "$.oauth2_client_info.secret", + } custom_errors: dict[Pattern[str], tuple[str, SupersetErrorType, dict[str, Any]]] = { SYNTAX_ERROR_REGEX: ( @@ -179,6 +188,47 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): } return urljoin(uri, "?" + urlencode(params)) + @classmethod + def needs_oauth2(cls, ex: Exception) -> bool: + """ + Check if the exception is one that indicates OAuth2 is needed. + + In case the token was manually revoked on Google side, `google-auth` will + try to automatically refresh credentials, but it fails since it only has the + access token. This override catches this scenario as well. + """ + return ( + g + and hasattr(g, "user") + and ( + isinstance(ex, cls.oauth2_exception) + or "credentials do not contain the necessary fields" in str(ex) + ) + ) + + @classmethod + def get_oauth2_fresh_token( + cls, + config: OAuth2ClientConfig, + refresh_token: str, + ) -> OAuth2TokenResponse: + """ + Refresh an OAuth2 access token that has expired. + + When trying to refresh an expired token that was revoked on Google side, + the request fails with 400 status code. + """ + try: + return super().get_oauth2_fresh_token(config, refresh_token) + except HTTPError as ex: + if ex.response is not None and ex.response.status_code == 400: + error_data = ex.response.json() + if error_data.get("error") == "invalid_grant": + raise UnauthenticatedError( + error_data.get("error_description", "Token has been revoked") + ) from ex + raise + @classmethod def impersonate_user( cls, @@ -198,6 +248,28 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): return url, engine_kwargs + @classmethod + def get_table_names( + cls, + database: Database, + inspector: Inspector, + schema: str | None, + ) -> set[str]: + """ + Get all sheets added to the connection. + + For OAuth2 connections, force the OAuth2 dance in case the user + doesn't have a token yet to avoid showing table names before auth. 
+ """ + if database.is_oauth2_enabled() and not get_oauth2_access_token( + database.get_oauth2_config(), + database.id, + g.user.id, + database.db_engine_spec, + ): + database.start_oauth2_dance() + return super().get_table_names(database, inspector, schema) + @classmethod def get_extra_table_metadata( cls, @@ -335,9 +407,19 @@ class GSheetsEngineSpec(ShillelaghEngineSpec): return errors try: + url = url.replace('"', '""') results = conn.execute(f'SELECT * FROM "{url}" LIMIT 1') # noqa: S608 results.fetchall() except Exception: # pylint: disable=broad-except + # OAuth2 connection check + # Check `parameters` first (used by frontend during form validation) + if parameters.get("oauth2_client_info"): + continue + # Check `masked_encrypted_extra` (for create/update events) + encrypted = json.loads(properties.get("masked_encrypted_extra", "{}")) + if encrypted.get("oauth2_client_info"): + continue + errors.append( SupersetError( message=( diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py index cd1a2a14d9..4f04d0a334 100644 --- a/superset/utils/oauth2.py +++ b/superset/utils/oauth2.py @@ -17,6 +17,7 @@ from __future__ import annotations +import logging from contextlib import contextmanager from datetime import datetime, timedelta, timezone from typing import Any, Iterator, TYPE_CHECKING @@ -37,6 +38,8 @@ if TYPE_CHECKING: JWT_EXPIRATION = timedelta(minutes=5) +logger = logging.getLogger(__name__) + @backoff.on_exception( backoff.expo, @@ -96,10 +99,20 @@ def refresh_oauth2_token( user_id=user_id, database_id=database_id, ): - token_response = db_engine_spec.get_oauth2_fresh_token( - config, - token.refresh_token, - ) + try: + token_response = db_engine_spec.get_oauth2_fresh_token( + config, + token.refresh_token, + ) + except Exception: + # If token refresh failed, delete the invalid token to prevent retry loops + logger.warning( + "OAuth2 token refresh failed for user=%s db=%s, deleting invalid token", + user_id, + database_id, + ) + 
db.session.delete(token) + return None # store new access token; note that the refresh token might be revoked, in which # case there would be no access token in the response diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py index bff4c93117..ccfe2b337f 100644 --- a/tests/unit_tests/db_engine_specs/test_base.py +++ b/tests/unit_tests/db_engine_specs/test_base.py @@ -23,18 +23,27 @@ import json # noqa: TID251 import re from textwrap import dedent from typing import Any +from urllib.parse import parse_qs, urlparse import pytest from pytest_mock import MockerFixture -from sqlalchemy import types +from sqlalchemy import Boolean, Column, Integer, types from sqlalchemy.dialects import sqlite from sqlalchemy.engine.url import make_url, URL from sqlalchemy.sql import sqltypes +from superset.db_engine_specs.base import BaseEngineSpec, convert_inspector_columns from superset.errors import ErrorLevel, SupersetError, SupersetErrorType +from superset.exceptions import OAuth2RedirectError from superset.sql.parse import Table -from superset.superset_typing import ResultSetColumnType, SQLAColumnType -from superset.utils.core import GenericDataType +from superset.superset_typing import ( + OAuth2ClientConfig, + OAuth2State, + ResultSetColumnType, + SQLAColumnType, +) +from superset.utils.core import FilterOperator, GenericDataType +from superset.utils.oauth2 import decode_oauth2_state from tests.unit_tests.db_engine_specs.utils import assert_column_spec @@ -68,9 +77,6 @@ def test_get_text_clause_with_colon() -> None: """ Make sure text clauses are correctly escaped """ - - from superset.db_engine_specs.base import BaseEngineSpec - text_clause = BaseEngineSpec.get_text_clause( "SELECT foo FROM tbl WHERE foo = '123:456')" ) @@ -90,8 +96,6 @@ def test_validate_db_uri(mocker: MockerFixture) -> None: {"DB_SQLA_URI_VALIDATOR": mock_validate}, ) - from superset.db_engine_specs.base import BaseEngineSpec - with 
pytest.raises(ValueError): # noqa: PT011 BaseEngineSpec.validate_database_uri(URL.create("sqlite")) @@ -130,8 +134,6 @@ select 'USD' as cur ], ) def test_cte_query_parsing(original: types.TypeEngine, expected: str) -> None: - from superset.db_engine_specs.base import BaseEngineSpec - actual = BaseEngineSpec.get_cte_query(original) assert actual == expected @@ -197,8 +199,6 @@ def test_get_column_spec( def test_convert_inspector_columns( cols: list[SQLAColumnType], expected_result: list[ResultSetColumnType] ): - from superset.db_engine_specs.base import convert_inspector_columns - assert convert_inspector_columns(cols) == expected_result @@ -206,8 +206,6 @@ def test_select_star(mocker: MockerFixture) -> None: """ Test the ``select_star`` method. """ - from superset.db_engine_specs.base import BaseEngineSpec - cols: list[ResultSetColumnType] = [ { "column_name": "a", @@ -249,7 +247,6 @@ def test_extra_table_metadata(mocker: MockerFixture) -> None: """ Test the deprecated `extra_table_metadata` method. """ - from superset.db_engine_specs.base import BaseEngineSpec from superset.models.core import Database class ThirdPartyDBEngineSpec(BaseEngineSpec): @@ -285,8 +282,6 @@ def test_get_default_catalog(mocker: MockerFixture) -> None: """ Test the `get_default_catalog` method. """ - from superset.db_engine_specs.base import BaseEngineSpec - database = mocker.MagicMock() assert BaseEngineSpec.get_default_catalog(database) is None @@ -295,7 +290,6 @@ def test_quote_table() -> None: """ Test the `quote_table` function. """ - from superset.db_engine_specs.base import BaseEngineSpec dialect = sqlite.dialect() @@ -318,8 +312,6 @@ def test_mask_encrypted_extra() -> None: """ Test that the private key is masked when the database is edited. """ - from superset.db_engine_specs.base import BaseEngineSpec - config = json.dumps( { "foo": "bar", @@ -342,8 +334,6 @@ def test_unmask_encrypted_extra() -> None: """ Test that the private key can be reused from the previous `encrypted_extra`. 
""" - from superset.db_engine_specs.base import BaseEngineSpec - old = json.dumps( { "foo": "bar", @@ -375,8 +365,6 @@ def test_impersonate_user_backwards_compatible(mocker: MockerFixture) -> None: """ Test that the `impersonate_user` method calls the original methods it replaced. """ - from superset.db_engine_specs.base import BaseEngineSpec - database = mocker.MagicMock() url = make_url("sqlite://foo.db") new_url = make_url("sqlite://bar.db") @@ -417,8 +405,6 @@ def test_impersonate_user_no_database(mocker: MockerFixture) -> None: """ Test `impersonate_user` when `update_impersonation_config` has an old signature. """ - from superset.db_engine_specs.base import BaseEngineSpec - database = mocker.MagicMock() url = make_url("sqlite://foo.db") new_url = make_url("sqlite://bar.db") @@ -457,10 +443,6 @@ def test_handle_boolean_filter_default_behavior() -> None: """ Test that BaseEngineSpec uses IS operators for boolean filters by default. """ - from sqlalchemy import Boolean, Column - - from superset.db_engine_specs.base import BaseEngineSpec - # Create a mock SQLAlchemy column bool_col = Column("test_col", Boolean) @@ -479,9 +461,6 @@ def test_handle_boolean_filter_with_equality() -> None: """ Test that BaseEngineSpec can use equality operators when configured. """ - from sqlalchemy import Boolean, Column - - from superset.db_engine_specs.base import BaseEngineSpec # Create a test engine spec that uses equality class TestEngineSpec(BaseEngineSpec): @@ -502,15 +481,9 @@ def test_handle_null_filter() -> None: """ Test null/not null filter handling. 
""" - from sqlalchemy import Boolean, Column - - from superset.db_engine_specs.base import BaseEngineSpec - bool_col = Column("test_col", Boolean) # Test IS_NULL - use actual FilterOperator values - from superset.utils.core import FilterOperator - result_null = BaseEngineSpec.handle_null_filter(bool_col, FilterOperator.IS_NULL) assert hasattr(result_null, "left") assert hasattr(result_null, "right") @@ -531,15 +504,9 @@ def test_handle_comparison_filter() -> None: """ Test comparison filter handling for all operators. """ - from sqlalchemy import Column, Integer - - from superset.db_engine_specs.base import BaseEngineSpec - int_col = Column("test_col", Integer) # Test all comparison operators - use actual FilterOperator values - from superset.utils.core import FilterOperator - operators_and_values = [ (FilterOperator.EQUALS, 5), (FilterOperator.NOT_EQUALS, 5), @@ -563,8 +530,6 @@ def test_use_equality_for_boolean_filters_property() -> None: """ Test that BaseEngineSpec has the correct default value for boolean filter property. """ - from superset.db_engine_specs.base import BaseEngineSpec - # Default should be False (use IS operators) assert BaseEngineSpec.use_equality_for_boolean_filters is False @@ -573,9 +538,6 @@ def test_extract_errors(mocker: MockerFixture) -> None: """ Test that error is extracted correctly when no custom error message is provided. """ - - from superset.db_engine_specs.base import BaseEngineSpec - mocker.patch( "flask.current_app.config", {}, @@ -597,8 +559,6 @@ def test_extract_errors_from_config(mocker: MockerFixture) -> None: using database_name. """ - from superset.db_engine_specs.base import BaseEngineSpec - class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -632,8 +592,6 @@ def test_extract_errors_only_to_specified_database(mocker: MockerFixture) -> Non Test that custom error messages are only applied to the specified database_name. 
""" - from superset.db_engine_specs.base import BaseEngineSpec - class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -669,8 +627,6 @@ def test_extract_errors_from_config_with_regex(mocker: MockerFixture) -> None: and show_issue_info are extracted correctly from config. """ - from superset.db_engine_specs.base import BaseEngineSpec - class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -740,7 +696,6 @@ def test_extract_errors_with_non_dict_custom_errors(mocker: MockerFixture): Test that extract_errors doesn't fail when custom database errors are in wrong format. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -765,7 +720,6 @@ def test_extract_errors_with_non_dict_engine_custom_errors(mocker: MockerFixture Test that extract_errors doesn't fail when database-specific custom errors are in wrong format. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -790,7 +744,6 @@ def test_extract_errors_with_empty_custom_error_message(mocker: MockerFixture): Test that when the custom error message is empty, the original error message is preserved. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -824,7 +777,6 @@ def test_extract_errors_matches_database_name_selection(mocker: MockerFixture) - """ Test that custom error messages are matched by database_name. """ - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -866,7 +818,6 @@ def test_extract_errors_no_match_falls_back(mocker: MockerFixture) -> None: """ Test that when database_name has no match, the original error message is preserved. 
""" - from superset.db_engine_specs.base import BaseEngineSpec class TestEngineSpec(BaseEngineSpec): engine_name = "ExampleEngine" @@ -901,12 +852,6 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) -> Test that BaseEngineSpec.get_oauth2_authorization_uri uses standard OAuth 2.0 parameters only and does not include provider-specific params like prompt=consent. """ - from urllib.parse import parse_qs, urlparse - - from superset.db_engine_specs.base import BaseEngineSpec - from superset.superset_typing import OAuth2ClientConfig, OAuth2State - from superset.utils.oauth2 import decode_oauth2_state - config: OAuth2ClientConfig = { "id": "client-id", "secret": "client-secret", @@ -943,3 +888,81 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) -> assert "prompt" not in query assert "access_type" not in query assert "include_granted_scopes" not in query + + +def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> None: + """ + Test that start_oauth2_dance uses DATABASE_OAUTH2_REDIRECT_URI config if set. 
+ """ + custom_redirect_uri = "https://proxy.example.com/oauth2/" + + mocker.patch( + "flask.current_app.config", + { + "DATABASE_OAUTH2_REDIRECT_URI": custom_redirect_uri, + "SECRET_KEY": "test-secret-key", + "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", + }, + ) + + g = mocker.patch("superset.db_engine_specs.base.g") + g.user.id = 1 + + database = mocker.MagicMock() + database.id = 1 + database.get_oauth2_config.return_value = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "https://another-link.com", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + } + + with pytest.raises(OAuth2RedirectError) as exc_info: + BaseEngineSpec.start_oauth2_dance(database) + + error = exc_info.value.error + + assert error.extra["redirect_uri"] == custom_redirect_uri + + +def test_start_oauth2_dance_falls_back_to_url_for(mocker: MockerFixture) -> None: + """ + Test that start_oauth2_dance falls back to url_for when no config is set. 
+ """ + fallback_uri = "http://localhost:8088/api/v1/database/oauth2/" + + mocker.patch( + "flask.current_app.config", + { + "SECRET_KEY": "test-secret-key", + "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256", + }, + ) + mocker.patch( + "superset.db_engine_specs.base.url_for", + return_value=fallback_uri, + ) + + g = mocker.patch("superset.db_engine_specs.base.g") + g.user.id = 1 + + database = mocker.MagicMock() + database.id = 1 + database.get_oauth2_config.return_value = { + "id": "client-id", + "secret": "client-secret", + "scope": "read write", + "redirect_uri": "https://another-link.com", + "authorization_request_uri": "https://oauth.example.com/authorize", + "token_request_uri": "https://oauth.example.com/token", + "request_content_type": "json", + } + + with pytest.raises(OAuth2RedirectError) as exc_info: + BaseEngineSpec.start_oauth2_dance(database) + + error = exc_info.value.error + + assert error.extra["redirect_uri"] == fallback_uri diff --git a/tests/unit_tests/db_engine_specs/test_gsheets.py b/tests/unit_tests/db_engine_specs/test_gsheets.py index 2ed796c32d..c7ec6d6bb9 100644 --- a/tests/unit_tests/db_engine_specs/test_gsheets.py +++ b/tests/unit_tests/db_engine_specs/test_gsheets.py @@ -23,6 +23,8 @@ from urllib.parse import parse_qs, urlparse import pandas as pd import pytest from pytest_mock import MockerFixture +from requests.exceptions import HTTPError +from shillelagh.exceptions import UnauthenticatedError from sqlalchemy.engine.url import make_url from superset.errors import ErrorLevel, SupersetError, SupersetErrorType @@ -160,58 +162,61 @@ def test_validate_parameters_catalog( } errors = GSheetsEngineSpec.validate_parameters(properties) # ignore: type - assert errors == [ - SupersetError( - message=( - "The URL could not be identified. Please check for typos " - "and make sure that ‘Type of Google Sheets allowed’ " - "selection matches the input." 
- ), - error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, - level=ErrorLevel.WARNING, - extra={ - "catalog": { - "idx": 0, - "url": True, - }, - "issue_codes": [ - { - "code": 1003, - "message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.", # noqa: E501 - }, - { - "code": 1005, - "message": "Issue 1005 - The table was deleted or renamed in the database.", # noqa: E501 + assert ( + errors + == [ + SupersetError( + message=( + "The URL could not be identified. Please check for typos " + "and make sure that ‘Type of Google Sheets allowed’ " + "selection matches the input." + ), + error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, + level=ErrorLevel.WARNING, + extra={ + "catalog": { + "idx": 0, + "url": True, }, - ], - }, - ), - SupersetError( - message=( - "The URL could not be identified. Please check for typos " - "and make sure that ‘Type of Google Sheets allowed’ " - "selection matches the input." - ), - error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, - level=ErrorLevel.WARNING, - extra={ - "catalog": { - "idx": 2, - "url": True, + "issue_codes": [ + { + "code": 1003, + "message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.", # noqa: E501 + }, + { + "code": 1005, + "message": "Issue 1005 - The table was deleted or renamed in the database.", # noqa: E501 + }, + ], }, - "issue_codes": [ - { - "code": 1003, - "message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.", # noqa: E501 - }, - { - "code": 1005, - "message": "Issue 1005 - The table was deleted or renamed in the database.", # noqa: E501 + ), + SupersetError( + message=( + "The URL could not be identified. Please check for typos " + "and make sure that ‘Type of Google Sheets allowed’ " + "selection matches the input." 
+ ), + error_type=SupersetErrorType.TABLE_DOES_NOT_EXIST_ERROR, + level=ErrorLevel.WARNING, + extra={ + "catalog": { + "idx": 2, + "url": True, }, - ], - }, - ), - ] + "issue_codes": [ + { + "code": 1003, + "message": "Issue 1003 - There is a syntax error in the SQL query. Perhaps there was a misspelling or a typo.", # noqa: E501 + }, + { + "code": 1005, + "message": "Issue 1005 - The table was deleted or renamed in the database.", # noqa: E501 + }, + ], + }, + ), + ] + ) create_engine.assert_called_with( "gsheets://", @@ -737,3 +742,251 @@ def test_update_params_from_encrypted_extra(mocker: MockerFixture) -> None: GSheetsEngineSpec.update_params_from_encrypted_extra(database, params) assert params == {"foo": "bar"} + + +def test_needs_oauth2_with_credentials_error(mocker: MockerFixture) -> None: + """ + Test that needs_oauth2 returns True for google-auth credentials error. + + When a token is manually revoked on Google side, google-auth tries to + refresh credentials but fails with this message. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user = mocker.MagicMock() + + ex = Exception("credentials do not contain the necessary fields") + assert GSheetsEngineSpec.needs_oauth2(ex) is True + + +def test_needs_oauth2_with_other_error(mocker: MockerFixture) -> None: + """ + Test that needs_oauth2 returns False for other errors. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user = mocker.MagicMock() + + ex = Exception("Some other error") + assert GSheetsEngineSpec.needs_oauth2(ex) is False + + +def test_get_oauth2_fresh_token_success( + mocker: MockerFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: + """ + Test that get_oauth2_fresh_token returns token on success. 
+ """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().json.return_value = { + "access_token": "new-access-token", + "expires_in": 3600, + } + + result = GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token") + assert result == { + "access_token": "new-access-token", + "expires_in": 3600, + } + + +def test_get_oauth2_fresh_token_invalid_grant( + mocker: MockerFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: + """ + Test that get_oauth2_fresh_token raises UnauthenticatedError for invalid_grant. + + When a token is revoked on Google side, the refresh request returns 400 + with error=invalid_grant. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + mock_response = mocker.MagicMock() + mock_response.status_code = 400 + mock_response.json.return_value = { + "error": "invalid_grant", + "error_description": "Token has been expired or revoked.", + } + http_error = HTTPError() + http_error.response = mock_response + + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().raise_for_status.side_effect = http_error + + with pytest.raises(UnauthenticatedError): + GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token") + + +def test_get_oauth2_fresh_token_other_http_error( + mocker: MockerFixture, + oauth2_config: OAuth2ClientConfig, +) -> None: + """ + Test that get_oauth2_fresh_token re-raises non-invalid_grant HTTP errors. 
+ """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + mock_response = mocker.MagicMock() + mock_response.status_code = 500 + mock_response.json.return_value = {"error": "server_error"} + + http_error = HTTPError() + http_error.response = mock_response + + requests = mocker.patch("superset.db_engine_specs.base.requests") + requests.post().raise_for_status.side_effect = http_error + + with pytest.raises(HTTPError): + GSheetsEngineSpec.get_oauth2_fresh_token(oauth2_config, "refresh-token") + + +def test_get_table_names_triggers_oauth2_dance(mocker: MockerFixture) -> None: + """ + Test that get_table_names triggers OAuth2 dance when no token exists. + """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user.id = 1 + + get_oauth2_access_token = mocker.patch( + "superset.db_engine_specs.gsheets.get_oauth2_access_token", + return_value=None, + ) + + database = mocker.MagicMock() + database.id = 1 + database.is_oauth2_enabled.return_value = True + database.get_oauth2_config.return_value = {"id": "client-id"} + database.db_engine_spec = GSheetsEngineSpec + + inspector = mocker.MagicMock() + + GSheetsEngineSpec.get_table_names(database, inspector, None) + + database.start_oauth2_dance.assert_called_once() + get_oauth2_access_token.assert_called_once() + + +def test_get_table_names_does_not_trigger_oauth2_when_token_exists( + mocker: MockerFixture, +) -> None: + """ + Test that get_table_names does not trigger OAuth2 dance when token exists. 
+ """ + from superset.db_engine_specs.gsheets import GSheetsEngineSpec + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user.id = 1 + + get_oauth2_access_token = mocker.patch( + "superset.db_engine_specs.gsheets.get_oauth2_access_token", + return_value="valid-token", + ) + + mocker.patch( + "superset.db_engine_specs.shillelagh.ShillelaghEngineSpec.get_table_names", + return_value={"sheet1", "sheet2"}, + ) + + database = mocker.MagicMock() + database.id = 1 + database.is_oauth2_enabled.return_value = True + database.get_oauth2_config.return_value = {"id": "client-id"} + database.db_engine_spec = GSheetsEngineSpec + + inspector = mocker.MagicMock() + + result = GSheetsEngineSpec.get_table_names(database, inspector, None) + + database.start_oauth2_dance.assert_not_called() + get_oauth2_access_token.assert_called_once() + assert result == {"sheet1", "sheet2"} + + +def test_validate_parameters_skips_oauth2_connections_with_parameters( + mocker: MockerFixture, +) -> None: + """ + Test that validate_parameters skips validation for OAuth2 connections. + + When oauth2_client_info is present in parameters, the validation should + skip URL checks since the user will authenticate via OAuth2. 
+ """ + from superset.db_engine_specs.gsheets import ( + GSheetsEngineSpec, + GSheetsPropertiesType, + ) + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user.email = "[email protected]" + + create_engine = mocker.patch("superset.db_engine_specs.gsheets.create_engine") + conn = create_engine.return_value.connect.return_value + results = conn.execute.return_value + results.fetchall.side_effect = ProgrammingError( + "The caller does not have permission" + ) + + properties: GSheetsPropertiesType = { + "parameters": { + "service_account_info": "", + "catalog": {}, + "oauth2_client_info": {"id": "client-id", "secret": "client-secret"}, + }, + "catalog": { + "sheet1": "https://docs.google.com/spreadsheets/d/1/edit", + }, + } + errors = GSheetsEngineSpec.validate_parameters(properties) + + assert errors == [] + + +def test_validate_parameters_skips_oauth2_connections_with_masked_encrypted_extra( + mocker: MockerFixture, +) -> None: + """ + Test validate_parameters skips validation for OAuth2 via masked_encrypted_extra. + + When oauth2_client_info is present in masked_encrypted_extra (used during + create/update), the validation should skip URL checks. 
+ """ + from superset.db_engine_specs.gsheets import ( + GSheetsEngineSpec, + GSheetsPropertiesType, + ) + + g = mocker.patch("superset.db_engine_specs.gsheets.g") + g.user.email = "[email protected]" + + create_engine = mocker.patch("superset.db_engine_specs.gsheets.create_engine") + conn = create_engine.return_value.connect.return_value + results = conn.execute.return_value + results.fetchall.side_effect = ProgrammingError( + "The caller does not have permission" + ) + + properties: GSheetsPropertiesType = { + "parameters": { + "service_account_info": "", + "catalog": {}, + }, + "catalog": { + "sheet1": "https://docs.google.com/spreadsheets/d/1/edit", + }, + "masked_encrypted_extra": json.dumps( + { + "oauth2_client_info": {"id": "client-id", "secret": "XXXXXXXXXX"}, + } + ), + } + errors = GSheetsEngineSpec.validate_parameters(properties) + + assert errors == [] diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py index e9aa283b1a..7fbe4a636e 100644 --- a/tests/unit_tests/utils/oauth2_tests.py +++ b/tests/unit_tests/utils/oauth2_tests.py @@ -18,11 +18,15 @@ # pylint: disable=invalid-name, disallowed-name from datetime import datetime +from typing import cast from freezegun import freeze_time from pytest_mock import MockerFixture -from superset.utils.oauth2 import get_oauth2_access_token +from superset.superset_typing import OAuth2ClientConfig +from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token + +DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {}) def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None: @@ -93,3 +97,45 @@ def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None: # check that token was deleted db.session.delete.assert_called_with(token) + + +def test_refresh_oauth2_token_deletes_token_on_exception(mocker: MockerFixture) -> None: + """ + Test that refresh_oauth2_token deletes the token when refresh fails. 
+ + When the token refresh fails (e.g., token was revoked on provider side), + the invalid token should be deleted to prevent retry loops. + """ + db = mocker.patch("superset.utils.oauth2.db") + mocker.patch("superset.utils.oauth2.KeyValueDistributedLock") + db_engine_spec = mocker.MagicMock() + db_engine_spec.get_oauth2_fresh_token.side_effect = Exception("Token revoked") + token = mocker.MagicMock() + token.refresh_token = "refresh-token" # noqa: S105 + + result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) + + assert result is None + db.session.delete.assert_called_with(token) + + +def test_refresh_oauth2_token_no_access_token_in_response( + mocker: MockerFixture, +) -> None: + """ + Test that refresh_oauth2_token returns None when no access_token in response. + + This can happen when the refresh token was revoked. + """ + mocker.patch("superset.utils.oauth2.db") + mocker.patch("superset.utils.oauth2.KeyValueDistributedLock") + db_engine_spec = mocker.MagicMock() + db_engine_spec.get_oauth2_fresh_token.return_value = { + "error": "invalid_grant", + } + token = mocker.MagicMock() + token.refresh_token = "refresh-token" # noqa: S105 + + result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, token) + + assert result is None
