This is an automated email from the ASF dual-hosted git repository.
beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 5d20dc57d7 feat(oauth2): add PKCE support for database OAuth2
authentication (#37067)
5d20dc57d7 is described below
commit 5d20dc57d76467cb5511ced07433d0b0049b21bb
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Jan 30 23:28:10 2026 -0500
feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)
---
superset/commands/database/oauth2.py | 22 +++++
superset/db_engine_specs/base.py | 66 ++++++++++++--
superset/db_engine_specs/gsheets.py | 11 ++-
superset/key_value/types.py | 1 +
superset/superset_typing.py | 2 +-
superset/utils/oauth2.py | 49 +++++++++--
tests/unit_tests/databases/api_test.py | 26 +++++-
tests/unit_tests/db_engine_specs/test_base.py | 122 ++++++++++++++++++++++++++
tests/unit_tests/sql_lab_test.py | 57 ++++++++----
tests/unit_tests/utils/oauth2_tests.py | 104 +++++++++++++++++++++-
10 files changed, 422 insertions(+), 38 deletions(-)
diff --git a/superset/commands/database/oauth2.py
b/superset/commands/database/oauth2.py
index f7259077bc..8355bc0098 100644
--- a/superset/commands/database/oauth2.py
+++ b/superset/commands/database/oauth2.py
@@ -18,12 +18,15 @@
from datetime import datetime, timedelta
from functools import partial
from typing import cast
+from uuid import UUID
from superset.commands.base import BaseCommand
from superset.commands.database.exceptions import DatabaseNotFoundError
from superset.daos.database import DatabaseUserOAuth2TokensDAO
+from superset.daos.key_value import KeyValueDAO
from superset.databases.schemas import OAuth2ProviderResponseSchema
from superset.exceptions import OAuth2Error
+from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
from superset.models.core import Database, DatabaseUserOAuth2Tokens
from superset.superset_typing import OAuth2State
from superset.utils.decorators import on_error, transaction
@@ -50,9 +53,28 @@ class OAuth2StoreTokenCommand(BaseCommand):
if oauth2_config is None:
raise OAuth2Error("No configuration found for OAuth2")
+ # Look up PKCE code_verifier from KV store (RFC 7636)
+ code_verifier = None
+ tab_id = self._state["tab_id"]
+ try:
+ tab_uuid = UUID(tab_id)
+ except ValueError:
+ tab_uuid = None
+
+ if tab_uuid:
+ kv_value = KeyValueDAO.get_value(
+ resource=KeyValueResource.PKCE_CODE_VERIFIER,
+ key=tab_uuid,
+ codec=JsonKeyValueCodec(),
+ )
+ if kv_value:
+ code_verifier = kv_value.get("code_verifier")
+ KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER,
tab_uuid)
+
token_response = self._database.db_engine_spec.get_oauth2_token(
oauth2_config,
self._parameters["code"],
+ code_verifier=code_verifier,
)
# delete old tokens
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index ac355c9366..ac9c397e5b 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -21,7 +21,7 @@ from __future__ import annotations
import logging
import re
import warnings
-from datetime import datetime
+from datetime import datetime, timedelta
from inspect import signature
from re import Match, Pattern
from typing import (
@@ -36,7 +36,7 @@ from typing import (
Union,
)
from urllib.parse import urlencode, urljoin
-from uuid import uuid4
+from uuid import UUID, uuid4
import pandas as pd
import requests
@@ -63,6 +63,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as
TimeGrainConstants
from superset.databases.utils import get_table_metadata, make_url_safe
from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
from superset.exceptions import OAuth2Error, OAuth2RedirectError
+from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
from superset.sql.parse import (
BaseSQLStatement,
LimitMethod,
@@ -83,7 +84,11 @@ from superset.utils.core import ColumnSpec, GenericDataType,
QuerySource
from superset.utils.hashing import hash_from_str
from superset.utils.json import redact_sensitive, reveal_sensitive
from superset.utils.network import is_hostname_valid, is_port_open
-from superset.utils.oauth2 import encode_oauth2_state
+from superset.utils.oauth2 import (
+ encode_oauth2_state,
+ generate_code_challenge,
+ generate_code_verifier,
+)
if TYPE_CHECKING:
from superset.connectors.sqla.models import TableColumn
@@ -608,13 +613,38 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
tab sends a message to the original tab informing that authorization
was
successful (or not), and then closes. The original tab will
automatically
re-run the query after authorization.
+
+ PKCE (RFC 7636) is used to protect against authorization code
interception
+ attacks. A code_verifier is generated and stored server-side in the KV
store,
+ while the code_challenge (derived from the verifier) is sent to the
+ authorization server.
"""
+ # Prevent circular import.
+ from superset.daos.key_value import KeyValueDAO
+
tab_id = str(uuid4())
default_redirect_uri = app.config.get(
"DATABASE_OAUTH2_REDIRECT_URI",
url_for("DatabaseRestApi.oauth2", _external=True),
)
+ # Generate PKCE code verifier (RFC 7636)
+ code_verifier = generate_code_verifier()
+
+ # Store the code_verifier server-side in the KV store, keyed by tab_id.
+ # This avoids exposing it in the URL/browser history via the JWT state.
+ KeyValueDAO.delete_expired_entries(KeyValueResource.PKCE_CODE_VERIFIER)
+ KeyValueDAO.create_entry(
+ resource=KeyValueResource.PKCE_CODE_VERIFIER,
+ value={"code_verifier": code_verifier},
+ codec=JsonKeyValueCodec(),
+ key=UUID(tab_id),
+ expires_on=datetime.now() + timedelta(minutes=5),
+ )
+ # We need to commit here because we're going to raise an exception,
which will
+ # revert any non-committed changes.
+ db.session.commit()
+
# The state is passed to the OAuth2 provider, and sent back to
Superset after
# the user authorizes the access. The redirect endpoint in Superset
can then
# inspect the state to figure out to which user/database the access
token
@@ -641,7 +671,11 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
if oauth2_config is None:
raise OAuth2Error("No configuration found for OAuth2")
- oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state)
+ oauth_url = cls.get_oauth2_authorization_uri(
+ oauth2_config,
+ state,
+ code_verifier=code_verifier,
+ )
raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
@@ -685,21 +719,29 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
cls,
config: OAuth2ClientConfig,
state: OAuth2State,
+ code_verifier: str | None = None,
) -> str:
"""
Return URI for initial OAuth2 request.
- Uses standard OAuth 2.0 parameters only. Subclasses can override
- to add provider-specific parameters (e.g., Google's prompt=consent).
+ Uses standard OAuth 2.0 parameters plus PKCE (RFC 7636) parameters.
+ Subclasses can override to add provider-specific parameters
+ (e.g., Google's prompt=consent).
"""
uri = config["authorization_request_uri"]
- params = {
+ params: dict[str, str] = {
"scope": config["scope"],
"response_type": "code",
"state": encode_oauth2_state(state),
"redirect_uri": config["redirect_uri"],
"client_id": config["id"],
}
+
+ # Add PKCE parameters (RFC 7636) if code_verifier is provided
+ if code_verifier:
+ params["code_challenge"] = generate_code_challenge(code_verifier)
+ params["code_challenge_method"] = "S256"
+
return urljoin(uri, "?" + urlencode(params))
@classmethod
@@ -707,19 +749,27 @@ class BaseEngineSpec: # pylint:
disable=too-many-public-methods
cls,
config: OAuth2ClientConfig,
code: str,
+ code_verifier: str | None = None,
) -> OAuth2TokenResponse:
"""
Exchange authorization code for refresh/access tokens.
+
+ If code_verifier is provided (PKCE flow), it will be included in the
+ token request per RFC 7636.
"""
timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
uri = config["token_request_uri"]
- req_body = {
+ req_body: dict[str, str] = {
"code": code,
"client_id": config["id"],
"client_secret": config["secret"],
"redirect_uri": config["redirect_uri"],
"grant_type": "authorization_code",
}
+ # Add PKCE code_verifier if present (RFC 7636)
+ if code_verifier:
+ req_body["code_verifier"] = code_verifier
+
response = (
requests.post(uri, data=req_body, timeout=timeout)
if config["request_content_type"] == "data"
diff --git a/superset/db_engine_specs/gsheets.py
b/superset/db_engine_specs/gsheets.py
index 86eadf6c5a..780f92cc75 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -161,6 +161,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
cls,
config: "OAuth2ClientConfig",
state: "OAuth2State",
+ code_verifier: str | None = None,
) -> str:
"""
Return URI for initial OAuth2 request with Google-specific parameters.
@@ -172,10 +173,10 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
"""
from urllib.parse import urlencode, urljoin
- from superset.utils.oauth2 import encode_oauth2_state
+ from superset.utils.oauth2 import encode_oauth2_state,
generate_code_challenge
uri = config["authorization_request_uri"]
- params = {
+ params: dict[str, str] = {
"scope": config["scope"],
"response_type": "code",
"state": encode_oauth2_state(state),
@@ -186,6 +187,12 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
"include_granted_scopes": "false",
"prompt": "consent",
}
+
+ # Add PKCE parameters (RFC 7636) if code_verifier is provided
+ if code_verifier:
+ params["code_challenge"] = generate_code_challenge(code_verifier)
+ params["code_challenge_method"] = "S256"
+
return urljoin(uri, "?" + urlencode(params))
@classmethod
diff --git a/superset/key_value/types.py b/superset/key_value/types.py
index 3b2da06493..2cc025e426 100644
--- a/superset/key_value/types.py
+++ b/superset/key_value/types.py
@@ -45,6 +45,7 @@ class KeyValueResource(StrEnum):
EXPLORE_PERMALINK = "explore_permalink"
METASTORE_CACHE = "superset_metastore_cache"
LOCK = "lock"
+ PKCE_CODE_VERIFIER = "pkce_code_verifier"
SQLLAB_PERMALINK = "sqllab_permalink"
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 4d409398d1..02e294a08c 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -356,7 +356,7 @@ class OAuth2TokenResponse(TypedDict, total=False):
refresh_token: str
-class OAuth2State(TypedDict):
+class OAuth2State(TypedDict, total=False):
"""
Type for the state passed during OAuth2.
"""
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 0124f57308..9a24f0c095 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -17,7 +17,10 @@
from __future__ import annotations
+import base64
+import hashlib
import logging
+import secrets
from contextlib import contextmanager
from datetime import datetime, timedelta, timezone
from typing import Any, Iterator, TYPE_CHECKING
@@ -40,6 +43,37 @@ JWT_EXPIRATION = timedelta(minutes=5)
logger = logging.getLogger(__name__)
+# Random bytes used for the PKCE code verifier; base64url-encodes to 86 chars (RFC 7636 requires 43-128)
+PKCE_CODE_VERIFIER_LENGTH = 64
+
+
+def generate_code_verifier() -> str:
+ """
+ Generate a PKCE code verifier (RFC 7636).
+
+ The code verifier is a high-entropy cryptographic random string using
+ unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~",
+ with a minimum length of 43 characters and a maximum length of 128.
+ """
+ # Generate random bytes and encode as URL-safe base64
+ random_bytes = secrets.token_bytes(PKCE_CODE_VERIFIER_LENGTH)
+ # Use URL-safe base64 encoding without padding
+ code_verifier =
base64.urlsafe_b64encode(random_bytes).rstrip(b"=").decode("ascii")
+ return code_verifier
+
+
+def generate_code_challenge(code_verifier: str) -> str:
+ """
+ Generate a PKCE code challenge from a code verifier (RFC 7636).
+
+ Uses the S256 method: BASE64URL(SHA256(code_verifier))
+ """
+ # Compute SHA-256 hash of the code verifier
+ digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+ # Encode as URL-safe base64 without padding
+ code_challenge =
base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+ return code_challenge
+
@backoff.on_exception(
backoff.expo,
@@ -140,13 +174,14 @@ def encode_oauth2_state(state: OAuth2State) -> str:
"""
Encode the OAuth2 state.
"""
- payload = {
+ payload: dict[str, Any] = {
"exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
"database_id": state["database_id"],
"user_id": state["user_id"],
"default_redirect_uri": state["default_redirect_uri"],
"tab_id": state["tab_id"],
}
+
encoded_state = jwt.encode(
payload=payload,
key=app.config["SECRET_KEY"],
@@ -172,12 +207,12 @@ class OAuth2StateSchema(Schema):
data: dict[str, Any],
**kwargs: Any,
) -> OAuth2State:
- return OAuth2State(
- database_id=data["database_id"],
- user_id=data["user_id"],
- default_redirect_uri=data["default_redirect_uri"],
- tab_id=data["tab_id"],
- )
+ return {
+ "database_id": data["database_id"],
+ "user_id": data["user_id"],
+ "default_redirect_uri": data["default_redirect_uri"],
+ "tab_id": data["tab_id"],
+ }
class Meta: # pylint: disable=too-few-public-methods
# ignore `exp`
diff --git a/tests/unit_tests/databases/api_test.py
b/tests/unit_tests/databases/api_test.py
index 7b77f5099d..244d75f7e2 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -255,7 +255,7 @@ def test_database_connection(
"service_account_info": {
"type": "service_account",
"project_id": "black-sanctum-314419",
- "private_key_id":
"259b0d419a8f840056158763ff54d8b08f7b8173",
+ "private_key_id":
"259b0d419a8f840056158763ff54d8b08f7b8173", # noqa: E501
"private_key": "XXXXXXXXXX",
"client_email":
"google-spreadsheets-demo-se...@black-sanctum-314419.iam.gserviceaccount.com",
# noqa: E501
"client_id": "114567578578109757129",
@@ -621,6 +621,10 @@ def test_oauth2_happy_path(
"expires_in": 3600,
"refresh_token": "ZZZ",
}
+ mocker.patch(
+ "superset.commands.database.oauth2.KeyValueDAO.get_value",
+ return_value=None,
+ )
state: OAuth2State = {
"user_id": 1,
@@ -641,7 +645,11 @@ def test_oauth2_happy_path(
)
assert response.status_code == 200
- get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
+ get_oauth2_token.assert_called_with(
+ {"id": "one", "secret": "two"},
+ "XXX",
+ code_verifier=None,
+ )
token = db.session.query(DatabaseUserOAuth2Tokens).one()
assert token.user_id == 1
@@ -689,6 +697,10 @@ def test_oauth2_permissions(
"expires_in": 3600,
"refresh_token": "ZZZ",
}
+ mocker.patch(
+ "superset.commands.database.oauth2.KeyValueDAO.get_value",
+ return_value=None,
+ )
state: OAuth2State = {
"user_id": 1,
@@ -709,7 +721,11 @@ def test_oauth2_permissions(
)
assert response.status_code == 200
- get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
+ get_oauth2_token.assert_called_with(
+ {"id": "one", "secret": "two"},
+ "XXX",
+ code_verifier=None,
+ )
token = db.session.query(DatabaseUserOAuth2Tokens).one()
assert token.user_id == 1
@@ -762,6 +778,10 @@ def test_oauth2_multiple_tokens(
"refresh_token": "ZZZ2",
},
]
+ mocker.patch(
+ "superset.commands.database.oauth2.KeyValueDAO.get_value",
+ return_value=None,
+ )
state: OAuth2State = {
"user_id": 1,
diff --git a/tests/unit_tests/db_engine_specs/test_base.py
b/tests/unit_tests/db_engine_specs/test_base.py
index ccfe2b337f..5d8c330410 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -889,6 +889,124 @@ def
test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
assert "access_type" not in query
assert "include_granted_scopes" not in query
+ # Verify PKCE parameters are NOT included when code_verifier is not
provided
+ assert "code_challenge" not in query
+ assert "code_challenge_method" not in query
+
+
+def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None:
+ """
+ Test that BaseEngineSpec.get_oauth2_authorization_uri includes PKCE
parameters
+ when code_verifier is passed as a parameter (RFC 7636).
+ """
+ from urllib.parse import parse_qs, urlparse
+
+ from superset.db_engine_specs.base import BaseEngineSpec
+ from superset.utils.oauth2 import generate_code_challenge,
generate_code_verifier
+
+ config: OAuth2ClientConfig = {
+ "id": "client-id",
+ "secret": "client-secret",
+ "scope": "read write",
+ "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+ "authorization_request_uri": "https://oauth.example.com/authorize",
+ "token_request_uri": "https://oauth.example.com/token",
+ "request_content_type": "json",
+ }
+
+ code_verifier = generate_code_verifier()
+ state: OAuth2State = {
+ "database_id": 1,
+ "user_id": 1,
+ "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+ "tab_id": "1234",
+ }
+
+ url = BaseEngineSpec.get_oauth2_authorization_uri(
+ config, state, code_verifier=code_verifier
+ )
+ parsed = urlparse(url)
+ query = parse_qs(parsed.query)
+
+ # Verify PKCE parameters are included (RFC 7636)
+ assert "code_challenge" in query
+ assert query["code_challenge_method"][0] == "S256"
+ # Verify the code_challenge matches the expected value
+ expected_challenge = generate_code_challenge(code_verifier)
+ assert query["code_challenge"][0] == expected_challenge
+
+
+def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
+ """
+ Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier.
+ """
+ from superset.db_engine_specs.base import BaseEngineSpec
+
+ mocker.patch(
+ "flask.current_app.config",
+ {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda:
30)},
+ )
+ mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
+ mock_post.return_value.json.return_value = {
+ "access_token": "test-access-token", # noqa: S105
+ "expires_in": 3600,
+ }
+
+ config: OAuth2ClientConfig = {
+ "id": "client-id",
+ "secret": "client-secret",
+ "scope": "read write",
+ "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+ "authorization_request_uri": "https://oauth.example.com/authorize",
+ "token_request_uri": "https://oauth.example.com/token",
+ "request_content_type": "json",
+ }
+
+ result = BaseEngineSpec.get_oauth2_token(config, "auth-code")
+
+ assert result["access_token"] == "test-access-token" # noqa: S105
+ # Verify code_verifier is NOT in the request body
+ call_kwargs = mock_post.call_args
+ request_body = call_kwargs.kwargs.get("json") or
call_kwargs.kwargs.get("data")
+ assert "code_verifier" not in request_body
+
+
+def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
+ """
+ Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided.
+ """
+ from superset.db_engine_specs.base import BaseEngineSpec
+ from superset.utils.oauth2 import generate_code_verifier
+
+ mocker.patch(
+ "flask.current_app.config",
+ {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda:
30)},
+ )
+ mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
+ mock_post.return_value.json.return_value = {
+ "access_token": "test-access-token", # noqa: S105
+ "expires_in": 3600,
+ }
+
+ config: OAuth2ClientConfig = {
+ "id": "client-id",
+ "secret": "client-secret",
+ "scope": "read write",
+ "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+ "authorization_request_uri": "https://oauth.example.com/authorize",
+ "token_request_uri": "https://oauth.example.com/token",
+ "request_content_type": "json",
+ }
+
+ code_verifier = generate_code_verifier()
+ result = BaseEngineSpec.get_oauth2_token(config, "auth-code",
code_verifier)
+
+ assert result["access_token"] == "test-access-token" # noqa: S105
+ # Verify code_verifier IS in the request body (PKCE)
+ call_kwargs = mock_post.call_args
+ request_body = call_kwargs.kwargs.get("json") or
call_kwargs.kwargs.get("data")
+ assert request_body["code_verifier"] == code_verifier
+
def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) ->
None:
"""
@@ -904,6 +1022,8 @@ def
test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> N
"DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
},
)
+ mocker.patch("superset.daos.key_value.KeyValueDAO")
+ mocker.patch("superset.db_engine_specs.base.db")
g = mocker.patch("superset.db_engine_specs.base.g")
g.user.id = 1
@@ -944,6 +1064,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker:
MockerFixture) -> None
"superset.db_engine_specs.base.url_for",
return_value=fallback_uri,
)
+ mocker.patch("superset.daos.key_value.KeyValueDAO")
+ mocker.patch("superset.db_engine_specs.base.db")
g = mocker.patch("superset.db_engine_specs.base.g")
g.user.id = 1
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 6e221e1494..9fd3a0ac0e 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -18,6 +18,7 @@
import json # noqa: TID251
from unittest.mock import MagicMock
+from urllib.parse import parse_qs, urlparse
from uuid import UUID
import pytest
@@ -201,6 +202,13 @@ def test_get_sql_results_oauth2(mocker: MockerFixture,
app) -> None:
"superset.db_engine_specs.base.uuid4",
return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
)
+ mocker.patch(
+ "superset.db_engine_specs.base.generate_code_verifier",
+
return_value="xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ",
+ )
+ mocker.patch("superset.daos.key_value.KeyValueDAO.delete_expired_entries")
+ mocker.patch("superset.daos.key_value.KeyValueDAO.create_entry")
+ mocker.patch("superset.db_engine_specs.base.db.session.commit")
g = mocker.patch("superset.db_engine_specs.base.g")
g.user = mocker.MagicMock()
@@ -222,22 +230,39 @@ def test_get_sql_results_oauth2(mocker: MockerFixture,
app) -> None:
mocker.patch("superset.sql_lab.get_query", return_value=query)
payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
- assert payload == {
- "status": QueryStatus.FAILED,
- "error": "You don't have permission to access the data.",
- "errors": [
- {
- "message": "You don't have permission to access the data.",
- "error_type": SupersetErrorType.OAUTH2_REDIRECT,
- "level": ErrorLevel.WARNING,
- "extra": {
- "url":
"https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3AUSERADMIN&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocal
[...]
- "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
- "redirect_uri": "http://localhost/api/v1/database/oauth2/",
- },
- }
- ],
- }
+ assert payload["status"] == QueryStatus.FAILED
+ assert payload["error"] == "You don't have permission to access the data."
+ assert len(payload["errors"]) == 1
+
+ error = payload["errors"][0]
+ assert error["message"] == "You don't have permission to access the data."
+ assert error["error_type"] == SupersetErrorType.OAUTH2_REDIRECT
+ assert error["level"] == ErrorLevel.WARNING
+ assert error["extra"]["tab_id"] == "fb11f528-6eba-4a8a-837e-6b0d39ee9187"
+ assert error["extra"]["redirect_uri"] ==
"http://localhost/api/v1/database/oauth2/"
+
+ # Parse the OAuth2 authorization URL and verify components individually rather
+ # than comparing the full string: the JWT-encoded state depends on library
+ # internals, so only the PKCE code_challenge is pinned exactly below.
+ url = urlparse(error["extra"]["url"])
+ assert url.scheme == "https"
+ assert url.netloc == "abcd1234.snowflakecomputing.com"
+ assert url.path == "/oauth/authorize"
+
+ params = parse_qs(url.query)
+ assert params["scope"] == ["refresh_token session:role:USERADMIN"]
+ assert params["response_type"] == ["code"]
+ assert params["redirect_uri"] ==
["http://localhost/api/v1/database/oauth2/"]
+ assert params["client_id"] == ["my_client_id"]
+ assert params["code_challenge_method"] == ["S256"]
+
+ # Verify PKCE code_challenge matches the mocked code_verifier
+ from superset.utils.oauth2 import generate_code_challenge
+
+ expected_code_challenge = generate_code_challenge(
+
"xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ"
+ )
+ assert params["code_challenge"] == [expected_code_challenge]
def test_apply_rls(mocker: MockerFixture) -> None:
diff --git a/tests/unit_tests/utils/oauth2_tests.py
b/tests/unit_tests/utils/oauth2_tests.py
index fc3ed7a651..33a0c0c266 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -17,6 +17,8 @@
# pylint: disable=invalid-name, disallowed-name
+import base64
+import hashlib
from datetime import datetime
from typing import cast
@@ -25,7 +27,14 @@ from freezegun import freeze_time
from pytest_mock import MockerFixture
from superset.superset_typing import OAuth2ClientConfig
-from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token
+from superset.utils.oauth2 import (
+ decode_oauth2_state,
+ encode_oauth2_state,
+ generate_code_challenge,
+ generate_code_verifier,
+ get_oauth2_access_token,
+ refresh_oauth2_token,
+)
DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {})
@@ -177,3 +186,96 @@ def test_refresh_oauth2_token_no_access_token_in_response(
result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec,
token)
assert result is None
+
+
+def test_generate_code_verifier_length() -> None:
+ """
+ Test that generate_code_verifier produces a string of valid length (RFC
7636).
+ """
+ code_verifier = generate_code_verifier()
+ # RFC 7636 requires 43-128 characters
+ assert 43 <= len(code_verifier) <= 128
+
+
+def test_generate_code_verifier_uniqueness() -> None:
+ """
+ Test that generate_code_verifier produces unique values.
+ """
+ verifiers = {generate_code_verifier() for _ in range(100)}
+ # All generated verifiers should be unique
+ assert len(verifiers) == 100
+
+
+def test_generate_code_verifier_valid_characters() -> None:
+ """
+ Test that generate_code_verifier only uses valid characters (RFC 7636).
+ """
+ code_verifier = generate_code_verifier()
+ # RFC 7636 allows: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
+ # URL-safe base64 uses: [A-Z] / [a-z] / [0-9] / "-" / "_"
+ valid_chars = set(
+ "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
+ )
+ assert all(char in valid_chars for char in code_verifier)
+
+
+def test_generate_code_challenge_s256() -> None:
+ """
+ Test that generate_code_challenge produces correct S256 challenge.
+ """
+ # Use a known code_verifier to verify the challenge computation
+ code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
+
+ # Compute expected challenge manually
+ digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+ expected_challenge =
base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+
+ code_challenge = generate_code_challenge(code_verifier)
+ assert code_challenge == expected_challenge
+
+
+def test_generate_code_challenge_rfc_example() -> None:
+ """
+ Test PKCE code challenge against RFC 7636 Appendix B example.
+
+ See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B
+ """
+ # RFC 7636 example code_verifier (Appendix B)
+ code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
+ # RFC 7636 expected code_challenge for S256 method
+ expected_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
+
+ code_challenge = generate_code_challenge(code_verifier)
+ assert code_challenge == expected_challenge
+
+
+def test_encode_decode_oauth2_state(
+ mocker: MockerFixture,
+) -> None:
+ """
+ Test that encode/decode cycle preserves state fields.
+ """
+ from superset.superset_typing import OAuth2State
+
+ mocker.patch(
+ "flask.current_app.config",
+ {
+ "SECRET_KEY": "test-secret-key",
+ "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
+ },
+ )
+
+ state: OAuth2State = {
+ "database_id": 1,
+ "user_id": 2,
+ "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+ "tab_id": "test-tab-id",
+ }
+
+ with freeze_time("2024-01-01"):
+ encoded = encode_oauth2_state(state)
+ decoded = decode_oauth2_state(encoded)
+
+ assert "code_verifier" not in decoded
+ assert decoded["database_id"] == 1
+ assert decoded["user_id"] == 2