This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new 5d20dc57d7 feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)
5d20dc57d7 is described below

commit 5d20dc57d76467cb5511ced07433d0b0049b21bb
Author: Beto Dealmeida <[email protected]>
AuthorDate: Fri Jan 30 23:28:10 2026 -0500

    feat(oauth2): add PKCE support for database OAuth2 authentication (#37067)
---
 superset/commands/database/oauth2.py          |  22 +++++
 superset/db_engine_specs/base.py              |  66 ++++++++++++--
 superset/db_engine_specs/gsheets.py           |  11 ++-
 superset/key_value/types.py                   |   1 +
 superset/superset_typing.py                   |   2 +-
 superset/utils/oauth2.py                      |  49 +++++++++--
 tests/unit_tests/databases/api_test.py        |  26 +++++-
 tests/unit_tests/db_engine_specs/test_base.py | 122 ++++++++++++++++++++++++++
 tests/unit_tests/sql_lab_test.py              |  57 ++++++++----
 tests/unit_tests/utils/oauth2_tests.py        | 104 +++++++++++++++++++++-
 10 files changed, 422 insertions(+), 38 deletions(-)

diff --git a/superset/commands/database/oauth2.py b/superset/commands/database/oauth2.py
index f7259077bc..8355bc0098 100644
--- a/superset/commands/database/oauth2.py
+++ b/superset/commands/database/oauth2.py
@@ -18,12 +18,15 @@
 from datetime import datetime, timedelta
 from functools import partial
 from typing import cast
+from uuid import UUID
 
 from superset.commands.base import BaseCommand
 from superset.commands.database.exceptions import DatabaseNotFoundError
 from superset.daos.database import DatabaseUserOAuth2TokensDAO
+from superset.daos.key_value import KeyValueDAO
 from superset.databases.schemas import OAuth2ProviderResponseSchema
 from superset.exceptions import OAuth2Error
+from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
 from superset.models.core import Database, DatabaseUserOAuth2Tokens
 from superset.superset_typing import OAuth2State
 from superset.utils.decorators import on_error, transaction
@@ -50,9 +53,28 @@ class OAuth2StoreTokenCommand(BaseCommand):
         if oauth2_config is None:
             raise OAuth2Error("No configuration found for OAuth2")
 
+        # Look up PKCE code_verifier from KV store (RFC 7636)
+        code_verifier = None
+        tab_id = self._state["tab_id"]
+        try:
+            tab_uuid = UUID(tab_id)
+        except ValueError:
+            tab_uuid = None
+
+        if tab_uuid:
+            kv_value = KeyValueDAO.get_value(
+                resource=KeyValueResource.PKCE_CODE_VERIFIER,
+                key=tab_uuid,
+                codec=JsonKeyValueCodec(),
+            )
+            if kv_value:
+                code_verifier = kv_value.get("code_verifier")
+                KeyValueDAO.delete_entry(KeyValueResource.PKCE_CODE_VERIFIER, tab_uuid)
+
         token_response = self._database.db_engine_spec.get_oauth2_token(
             oauth2_config,
             self._parameters["code"],
+            code_verifier=code_verifier,
         )
 
         # delete old tokens
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index ac355c9366..ac9c397e5b 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -21,7 +21,7 @@ from __future__ import annotations
 import logging
 import re
 import warnings
-from datetime import datetime
+from datetime import datetime, timedelta
 from inspect import signature
 from re import Match, Pattern
 from typing import (
@@ -36,7 +36,7 @@ from typing import (
     Union,
 )
 from urllib.parse import urlencode, urljoin
-from uuid import uuid4
+from uuid import UUID, uuid4
 
 import pandas as pd
 import requests
@@ -63,6 +63,7 @@ from superset.constants import QUERY_CANCEL_KEY, TimeGrain as 
TimeGrainConstants
 from superset.databases.utils import get_table_metadata, make_url_safe
 from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
 from superset.exceptions import OAuth2Error, OAuth2RedirectError
+from superset.key_value.types import JsonKeyValueCodec, KeyValueResource
 from superset.sql.parse import (
     BaseSQLStatement,
     LimitMethod,
@@ -83,7 +84,11 @@ from superset.utils.core import ColumnSpec, GenericDataType, 
QuerySource
 from superset.utils.hashing import hash_from_str
 from superset.utils.json import redact_sensitive, reveal_sensitive
 from superset.utils.network import is_hostname_valid, is_port_open
-from superset.utils.oauth2 import encode_oauth2_state
+from superset.utils.oauth2 import (
+    encode_oauth2_state,
+    generate_code_challenge,
+    generate_code_verifier,
+)
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
@@ -608,13 +613,38 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         tab sends a message to the original tab informing that authorization 
was
         successful (or not), and then closes. The original tab will 
automatically
         re-run the query after authorization.
+
+        PKCE (RFC 7636) is used to protect against authorization code 
interception
+        attacks. A code_verifier is generated and stored server-side in the KV 
store,
+        while the code_challenge (derived from the verifier) is sent to the
+        authorization server.
         """
+        # Prevent circular import.
+        from superset.daos.key_value import KeyValueDAO
+
         tab_id = str(uuid4())
         default_redirect_uri = app.config.get(
             "DATABASE_OAUTH2_REDIRECT_URI",
             url_for("DatabaseRestApi.oauth2", _external=True),
         )
 
+        # Generate PKCE code verifier (RFC 7636)
+        code_verifier = generate_code_verifier()
+
+        # Store the code_verifier server-side in the KV store, keyed by tab_id.
+        # This avoids exposing it in the URL/browser history via the JWT state.
+        KeyValueDAO.delete_expired_entries(KeyValueResource.PKCE_CODE_VERIFIER)
+        KeyValueDAO.create_entry(
+            resource=KeyValueResource.PKCE_CODE_VERIFIER,
+            value={"code_verifier": code_verifier},
+            codec=JsonKeyValueCodec(),
+            key=UUID(tab_id),
+            expires_on=datetime.now() + timedelta(minutes=5),
+        )
+        # We need to commit here because we're going to raise an exception, which will
+        # revert any non-committed changes.
+        db.session.commit()
+
         # The state is passed to the OAuth2 provider, and sent back to 
Superset after
         # the user authorizes the access. The redirect endpoint in Superset 
can then
         # inspect the state to figure out to which user/database the access 
token
@@ -641,7 +671,11 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         if oauth2_config is None:
             raise OAuth2Error("No configuration found for OAuth2")
 
-        oauth_url = cls.get_oauth2_authorization_uri(oauth2_config, state)
+        oauth_url = cls.get_oauth2_authorization_uri(
+            oauth2_config,
+            state,
+            code_verifier=code_verifier,
+        )
 
         raise OAuth2RedirectError(oauth_url, tab_id, default_redirect_uri)
 
@@ -685,21 +719,29 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         cls,
         config: OAuth2ClientConfig,
         state: OAuth2State,
+        code_verifier: str | None = None,
     ) -> str:
         """
         Return URI for initial OAuth2 request.
 
-        Uses standard OAuth 2.0 parameters only. Subclasses can override
-        to add provider-specific parameters (e.g., Google's prompt=consent).
+        Uses standard OAuth 2.0 parameters plus PKCE (RFC 7636) parameters.
+        Subclasses can override to add provider-specific parameters
+        (e.g., Google's prompt=consent).
         """
         uri = config["authorization_request_uri"]
-        params = {
+        params: dict[str, str] = {
             "scope": config["scope"],
             "response_type": "code",
             "state": encode_oauth2_state(state),
             "redirect_uri": config["redirect_uri"],
             "client_id": config["id"],
         }
+
+        # Add PKCE parameters (RFC 7636) if code_verifier is provided
+        if code_verifier:
+            params["code_challenge"] = generate_code_challenge(code_verifier)
+            params["code_challenge_method"] = "S256"
+
         return urljoin(uri, "?" + urlencode(params))
 
     @classmethod
@@ -707,19 +749,27 @@ class BaseEngineSpec:  # pylint: 
disable=too-many-public-methods
         cls,
         config: OAuth2ClientConfig,
         code: str,
+        code_verifier: str | None = None,
     ) -> OAuth2TokenResponse:
         """
         Exchange authorization code for refresh/access tokens.
+
+        If code_verifier is provided (PKCE flow), it will be included in the
+        token request per RFC 7636.
         """
         timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
         uri = config["token_request_uri"]
-        req_body = {
+        req_body: dict[str, str] = {
             "code": code,
             "client_id": config["id"],
             "client_secret": config["secret"],
             "redirect_uri": config["redirect_uri"],
             "grant_type": "authorization_code",
         }
+        # Add PKCE code_verifier if present (RFC 7636)
+        if code_verifier:
+            req_body["code_verifier"] = code_verifier
+
         response = (
             requests.post(uri, data=req_body, timeout=timeout)
             if config["request_content_type"] == "data"
diff --git a/superset/db_engine_specs/gsheets.py 
b/superset/db_engine_specs/gsheets.py
index 86eadf6c5a..780f92cc75 100644
--- a/superset/db_engine_specs/gsheets.py
+++ b/superset/db_engine_specs/gsheets.py
@@ -161,6 +161,7 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
         cls,
         config: "OAuth2ClientConfig",
         state: "OAuth2State",
+        code_verifier: str | None = None,
     ) -> str:
         """
         Return URI for initial OAuth2 request with Google-specific parameters.
@@ -172,10 +173,10 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
         """
         from urllib.parse import urlencode, urljoin
 
-        from superset.utils.oauth2 import encode_oauth2_state
+        from superset.utils.oauth2 import encode_oauth2_state, 
generate_code_challenge
 
         uri = config["authorization_request_uri"]
-        params = {
+        params: dict[str, str] = {
             "scope": config["scope"],
             "response_type": "code",
             "state": encode_oauth2_state(state),
@@ -186,6 +187,12 @@ class GSheetsEngineSpec(ShillelaghEngineSpec):
             "include_granted_scopes": "false",
             "prompt": "consent",
         }
+
+        # Add PKCE parameters (RFC 7636) if code_verifier is provided
+        if code_verifier:
+            params["code_challenge"] = generate_code_challenge(code_verifier)
+            params["code_challenge_method"] = "S256"
+
         return urljoin(uri, "?" + urlencode(params))
 
     @classmethod
diff --git a/superset/key_value/types.py b/superset/key_value/types.py
index 3b2da06493..2cc025e426 100644
--- a/superset/key_value/types.py
+++ b/superset/key_value/types.py
@@ -45,6 +45,7 @@ class KeyValueResource(StrEnum):
     EXPLORE_PERMALINK = "explore_permalink"
     METASTORE_CACHE = "superset_metastore_cache"
     LOCK = "lock"
+    PKCE_CODE_VERIFIER = "pkce_code_verifier"
     SQLLAB_PERMALINK = "sqllab_permalink"
 
 
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 4d409398d1..02e294a08c 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -356,7 +356,7 @@ class OAuth2TokenResponse(TypedDict, total=False):
     refresh_token: str
 
 
-class OAuth2State(TypedDict):
+class OAuth2State(TypedDict, total=False):
     """
     Type for the state passed during OAuth2.
     """
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index 0124f57308..9a24f0c095 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -17,7 +17,10 @@
 
 from __future__ import annotations
 
+import base64
+import hashlib
 import logging
+import secrets
 from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
 from typing import Any, Iterator, TYPE_CHECKING
@@ -40,6 +43,37 @@ JWT_EXPIRATION = timedelta(minutes=5)
 
 logger = logging.getLogger(__name__)
 
+# Random bytes per PKCE code verifier (86 base64url chars; RFC 7636 allows 43-128)
+PKCE_CODE_VERIFIER_LENGTH = 64
+
+
+def generate_code_verifier() -> str:
+    """
+    Generate a PKCE code verifier (RFC 7636).
+
+    The code verifier is a high-entropy cryptographic random string using
+    unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~",
+    with a minimum length of 43 characters and a maximum length of 128.
+    """
+    # Generate random bytes and encode as URL-safe base64
+    random_bytes = secrets.token_bytes(PKCE_CODE_VERIFIER_LENGTH)
+    # Use URL-safe base64 encoding without padding
+    code_verifier = base64.urlsafe_b64encode(random_bytes).rstrip(b"=").decode("ascii")
+    return code_verifier
+
+
+def generate_code_challenge(code_verifier: str) -> str:
+    """
+    Generate a PKCE code challenge from a code verifier (RFC 7636).
+
+    Uses the S256 method: BASE64URL(SHA256(code_verifier))
+    """
+    # Compute SHA-256 hash of the code verifier
+    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+    # Encode as URL-safe base64 without padding
+    code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+    return code_challenge
+
 
 @backoff.on_exception(
     backoff.expo,
@@ -140,13 +174,14 @@ def encode_oauth2_state(state: OAuth2State) -> str:
     """
     Encode the OAuth2 state.
     """
-    payload = {
+    payload: dict[str, Any] = {
         "exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
         "database_id": state["database_id"],
         "user_id": state["user_id"],
         "default_redirect_uri": state["default_redirect_uri"],
         "tab_id": state["tab_id"],
     }
+
     encoded_state = jwt.encode(
         payload=payload,
         key=app.config["SECRET_KEY"],
@@ -172,12 +207,12 @@ class OAuth2StateSchema(Schema):
         data: dict[str, Any],
         **kwargs: Any,
     ) -> OAuth2State:
-        return OAuth2State(
-            database_id=data["database_id"],
-            user_id=data["user_id"],
-            default_redirect_uri=data["default_redirect_uri"],
-            tab_id=data["tab_id"],
-        )
+        return {
+            "database_id": data["database_id"],
+            "user_id": data["user_id"],
+            "default_redirect_uri": data["default_redirect_uri"],
+            "tab_id": data["tab_id"],
+        }
 
     class Meta:  # pylint: disable=too-few-public-methods
         # ignore `exp`
diff --git a/tests/unit_tests/databases/api_test.py 
b/tests/unit_tests/databases/api_test.py
index 7b77f5099d..244d75f7e2 100644
--- a/tests/unit_tests/databases/api_test.py
+++ b/tests/unit_tests/databases/api_test.py
@@ -255,7 +255,7 @@ def test_database_connection(
                     "service_account_info": {
                         "type": "service_account",
                         "project_id": "black-sanctum-314419",
-                        "private_key_id": 
"259b0d419a8f840056158763ff54d8b08f7b8173",
+                        "private_key_id": 
"259b0d419a8f840056158763ff54d8b08f7b8173",  # noqa: E501
                         "private_key": "XXXXXXXXXX",
                         "client_email": 
"google-spreadsheets-demo-se...@black-sanctum-314419.iam.gserviceaccount.com",  
# noqa: E501
                         "client_id": "114567578578109757129",
@@ -621,6 +621,10 @@ def test_oauth2_happy_path(
         "expires_in": 3600,
         "refresh_token": "ZZZ",
     }
+    mocker.patch(
+        "superset.commands.database.oauth2.KeyValueDAO.get_value",
+        return_value=None,
+    )
 
     state: OAuth2State = {
         "user_id": 1,
@@ -641,7 +645,11 @@ def test_oauth2_happy_path(
         )
 
     assert response.status_code == 200
-    get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
+    get_oauth2_token.assert_called_with(
+        {"id": "one", "secret": "two"},
+        "XXX",
+        code_verifier=None,
+    )
 
     token = db.session.query(DatabaseUserOAuth2Tokens).one()
     assert token.user_id == 1
@@ -689,6 +697,10 @@ def test_oauth2_permissions(
         "expires_in": 3600,
         "refresh_token": "ZZZ",
     }
+    mocker.patch(
+        "superset.commands.database.oauth2.KeyValueDAO.get_value",
+        return_value=None,
+    )
 
     state: OAuth2State = {
         "user_id": 1,
@@ -709,7 +721,11 @@ def test_oauth2_permissions(
         )
 
     assert response.status_code == 200
-    get_oauth2_token.assert_called_with({"id": "one", "secret": "two"}, "XXX")
+    get_oauth2_token.assert_called_with(
+        {"id": "one", "secret": "two"},
+        "XXX",
+        code_verifier=None,
+    )
 
     token = db.session.query(DatabaseUserOAuth2Tokens).one()
     assert token.user_id == 1
@@ -762,6 +778,10 @@ def test_oauth2_multiple_tokens(
             "refresh_token": "ZZZ2",
         },
     ]
+    mocker.patch(
+        "superset.commands.database.oauth2.KeyValueDAO.get_value",
+        return_value=None,
+    )
 
     state: OAuth2State = {
         "user_id": 1,
diff --git a/tests/unit_tests/db_engine_specs/test_base.py 
b/tests/unit_tests/db_engine_specs/test_base.py
index ccfe2b337f..5d8c330410 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -889,6 +889,124 @@ def 
test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
     assert "access_type" not in query
     assert "include_granted_scopes" not in query
 
+    # Verify PKCE parameters are NOT included when code_verifier is not 
provided
+    assert "code_challenge" not in query
+    assert "code_challenge_method" not in query
+
+
+def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None:
+    """
+    Test that BaseEngineSpec.get_oauth2_authorization_uri includes PKCE 
parameters
+    when code_verifier is passed as a parameter (RFC 7636).
+    """
+    from urllib.parse import parse_qs, urlparse
+
+    from superset.db_engine_specs.base import BaseEngineSpec
+    from superset.utils.oauth2 import generate_code_challenge, 
generate_code_verifier
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    code_verifier = generate_code_verifier()
+    state: OAuth2State = {
+        "database_id": 1,
+        "user_id": 1,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "1234",
+    }
+
+    url = BaseEngineSpec.get_oauth2_authorization_uri(
+        config, state, code_verifier=code_verifier
+    )
+    parsed = urlparse(url)
+    query = parse_qs(parsed.query)
+
+    # Verify PKCE parameters are included (RFC 7636)
+    assert "code_challenge" in query
+    assert query["code_challenge_method"][0] == "S256"
+    # Verify the code_challenge matches the expected value
+    expected_challenge = generate_code_challenge(code_verifier)
+    assert query["code_challenge"][0] == expected_challenge
+
+
+def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
+    """
+    Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+
+    mocker.patch(
+        "flask.current_app.config",
+        {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 
30)},
+    )
+    mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
+    mock_post.return_value.json.return_value = {
+        "access_token": "test-access-token",  # noqa: S105
+        "expires_in": 3600,
+    }
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    result = BaseEngineSpec.get_oauth2_token(config, "auth-code")
+
+    assert result["access_token"] == "test-access-token"  # noqa: S105
+    # Verify code_verifier is NOT in the request body
+    call_kwargs = mock_post.call_args
+    request_body = call_kwargs.kwargs.get("json") or 
call_kwargs.kwargs.get("data")
+    assert "code_verifier" not in request_body
+
+
+def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
+    """
+    Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+    from superset.utils.oauth2 import generate_code_verifier
+
+    mocker.patch(
+        "flask.current_app.config",
+        {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 
30)},
+    )
+    mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
+    mock_post.return_value.json.return_value = {
+        "access_token": "test-access-token",  # noqa: S105
+        "expires_in": 3600,
+    }
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    code_verifier = generate_code_verifier()
+    result = BaseEngineSpec.get_oauth2_token(config, "auth-code", 
code_verifier)
+
+    assert result["access_token"] == "test-access-token"  # noqa: S105
+    # Verify code_verifier IS in the request body (PKCE)
+    call_kwargs = mock_post.call_args
+    request_body = call_kwargs.kwargs.get("json") or 
call_kwargs.kwargs.get("data")
+    assert request_body["code_verifier"] == code_verifier
+
 
 def test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> 
None:
     """
@@ -904,6 +1022,8 @@ def 
test_start_oauth2_dance_uses_config_redirect_uri(mocker: MockerFixture) -> N
             "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
         },
     )
+    mocker.patch("superset.daos.key_value.KeyValueDAO")
+    mocker.patch("superset.db_engine_specs.base.db")
 
     g = mocker.patch("superset.db_engine_specs.base.g")
     g.user.id = 1
@@ -944,6 +1064,8 @@ def test_start_oauth2_dance_falls_back_to_url_for(mocker: 
MockerFixture) -> None
         "superset.db_engine_specs.base.url_for",
         return_value=fallback_uri,
     )
+    mocker.patch("superset.daos.key_value.KeyValueDAO")
+    mocker.patch("superset.db_engine_specs.base.db")
 
     g = mocker.patch("superset.db_engine_specs.base.g")
     g.user.id = 1
diff --git a/tests/unit_tests/sql_lab_test.py b/tests/unit_tests/sql_lab_test.py
index 6e221e1494..9fd3a0ac0e 100644
--- a/tests/unit_tests/sql_lab_test.py
+++ b/tests/unit_tests/sql_lab_test.py
@@ -18,6 +18,7 @@
 
 import json  # noqa: TID251
 from unittest.mock import MagicMock
+from urllib.parse import parse_qs, urlparse
 from uuid import UUID
 
 import pytest
@@ -201,6 +202,13 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, 
app) -> None:
         "superset.db_engine_specs.base.uuid4",
         return_value=UUID("fb11f528-6eba-4a8a-837e-6b0d39ee9187"),
     )
+    mocker.patch(
+        "superset.db_engine_specs.base.generate_code_verifier",
+        
return_value="xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ",
+    )
+    mocker.patch("superset.daos.key_value.KeyValueDAO.delete_expired_entries")
+    mocker.patch("superset.daos.key_value.KeyValueDAO.create_entry")
+    mocker.patch("superset.db_engine_specs.base.db.session.commit")
 
     g = mocker.patch("superset.db_engine_specs.base.g")
     g.user = mocker.MagicMock()
@@ -222,22 +230,39 @@ def test_get_sql_results_oauth2(mocker: MockerFixture, 
app) -> None:
     mocker.patch("superset.sql_lab.get_query", return_value=query)
 
     payload = get_sql_results(query_id=1, rendered_query="SELECT 1")
-    assert payload == {
-        "status": QueryStatus.FAILED,
-        "error": "You don't have permission to access the data.",
-        "errors": [
-            {
-                "message": "You don't have permission to access the data.",
-                "error_type": SupersetErrorType.OAUTH2_REDIRECT,
-                "level": ErrorLevel.WARNING,
-                "extra": {
-                    "url": 
"https://abcd1234.snowflakecomputing.com/oauth/authorize?scope=refresh_token+session%3Arole%3AUSERADMIN&response_type=code&state=eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9%252EeyJleHAiOjE2MTcyMzU1MDAsImRhdGFiYXNlX2lkIjoxLCJ1c2VyX2lkIjo0MiwiZGVmYXVsdF9yZWRpcmVjdF91cmkiOiJodHRwOi8vbG9jYWxob3N0L2FwaS92MS9kYXRhYmFzZS9vYXV0aDIvIiwidGFiX2lkIjoiZmIxMWY1MjgtNmViYS00YThhLTgzN2UtNmIwZDM5ZWU5MTg3In0%252E7nLkei6-V8sVk_Pgm8cFhk0tnKRKayRE1Vc7RxuM9mw&redirect_uri=http%3A%2F%2Flocal
 [...]
-                    "tab_id": "fb11f528-6eba-4a8a-837e-6b0d39ee9187",
-                    "redirect_uri": "http://localhost/api/v1/database/oauth2/",
-                },
-            }
-        ],
-    }
+    assert payload["status"] == QueryStatus.FAILED
+    assert payload["error"] == "You don't have permission to access the data."
+    assert len(payload["errors"]) == 1
+
+    error = payload["errors"][0]
+    assert error["message"] == "You don't have permission to access the data."
+    assert error["error_type"] == SupersetErrorType.OAUTH2_REDIRECT
+    assert error["level"] == ErrorLevel.WARNING
+    assert error["extra"]["tab_id"] == "fb11f528-6eba-4a8a-837e-6b0d39ee9187"
+    assert error["extra"]["redirect_uri"] == "http://localhost/api/v1/database/oauth2/"
+
+    # Parse the OAuth2 authorization URL and verify components individually,
+    # since the JWT state and PKCE code_challenge are computed 
deterministically
+    # from mocked inputs but their exact encoding depends on library internals.
+    url = urlparse(error["extra"]["url"])
+    assert url.scheme == "https"
+    assert url.netloc == "abcd1234.snowflakecomputing.com"
+    assert url.path == "/oauth/authorize"
+
+    params = parse_qs(url.query)
+    assert params["scope"] == ["refresh_token session:role:USERADMIN"]
+    assert params["response_type"] == ["code"]
+    assert params["redirect_uri"] == ["http://localhost/api/v1/database/oauth2/"]
+    assert params["client_id"] == ["my_client_id"]
+    assert params["code_challenge_method"] == ["S256"]
+
+    # Verify PKCE code_challenge matches the mocked code_verifier
+    from superset.utils.oauth2 import generate_code_challenge
+
+    expected_code_challenge = generate_code_challenge(
+        
"xkBPVZoFChVcy3VZ2l5u7d0FZPTU-olO7HtsAOok2IUGigyoZ62tG_oldy2xg9_HdqPKrWUmKZLmU-CUqz_SQ"
+    )
+    assert params["code_challenge"] == [expected_code_challenge]
 
 
 def test_apply_rls(mocker: MockerFixture) -> None:
diff --git a/tests/unit_tests/utils/oauth2_tests.py 
b/tests/unit_tests/utils/oauth2_tests.py
index fc3ed7a651..33a0c0c266 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -17,6 +17,8 @@
 
 # pylint: disable=invalid-name, disallowed-name
 
+import base64
+import hashlib
 from datetime import datetime
 from typing import cast
 
@@ -25,7 +27,14 @@ from freezegun import freeze_time
 from pytest_mock import MockerFixture
 
 from superset.superset_typing import OAuth2ClientConfig
-from superset.utils.oauth2 import get_oauth2_access_token, refresh_oauth2_token
+from superset.utils.oauth2 import (
+    decode_oauth2_state,
+    encode_oauth2_state,
+    generate_code_challenge,
+    generate_code_verifier,
+    get_oauth2_access_token,
+    refresh_oauth2_token,
+)
 
 DUMMY_OAUTH2_CONFIG = cast(OAuth2ClientConfig, {})
 
@@ -177,3 +186,96 @@ def test_refresh_oauth2_token_no_access_token_in_response(
     result = refresh_oauth2_token(DUMMY_OAUTH2_CONFIG, 1, 1, db_engine_spec, 
token)
 
     assert result is None
+
+
+def test_generate_code_verifier_length() -> None:
+    """
+    Test that generate_code_verifier produces a string of valid length (RFC 
7636).
+    """
+    code_verifier = generate_code_verifier()
+    # RFC 7636 requires 43-128 characters
+    assert 43 <= len(code_verifier) <= 128
+
+
+def test_generate_code_verifier_uniqueness() -> None:
+    """
+    Test that generate_code_verifier produces unique values.
+    """
+    verifiers = {generate_code_verifier() for _ in range(100)}
+    # All generated verifiers should be unique
+    assert len(verifiers) == 100
+
+
+def test_generate_code_verifier_valid_characters() -> None:
+    """
+    Test that generate_code_verifier only uses valid characters (RFC 7636).
+    """
+    code_verifier = generate_code_verifier()
+    # RFC 7636 allows: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
+    # URL-safe base64 uses: [A-Z] / [a-z] / [0-9] / "-" / "_"
+    valid_chars = set(
+        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
+    )
+    assert all(char in valid_chars for char in code_verifier)
+
+
+def test_generate_code_challenge_s256() -> None:
+    """
+    Test that generate_code_challenge produces correct S256 challenge.
+    """
+    # Use a known code_verifier to verify the challenge computation
+    code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
+
+    # Compute expected challenge manually
+    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+    expected_challenge = 
base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+
+    code_challenge = generate_code_challenge(code_verifier)
+    assert code_challenge == expected_challenge
+
+
+def test_generate_code_challenge_rfc_example() -> None:
+    """
+    Test PKCE code challenge against RFC 7636 Appendix B example.
+
+    See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B
+    """
+    # RFC 7636 example code_verifier (Appendix B)
+    code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
+    # RFC 7636 expected code_challenge for S256 method
+    expected_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
+
+    code_challenge = generate_code_challenge(code_verifier)
+    assert code_challenge == expected_challenge
+
+
+def test_encode_decode_oauth2_state(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test that encode/decode cycle preserves state fields.
+    """
+    from superset.superset_typing import OAuth2State
+
+    mocker.patch(
+        "flask.current_app.config",
+        {
+            "SECRET_KEY": "test-secret-key",
+            "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
+        },
+    )
+
+    state: OAuth2State = {
+        "database_id": 1,
+        "user_id": 2,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "test-tab-id",
+    }
+
+    with freeze_time("2024-01-01"):
+        encoded = encode_oauth2_state(state)
+        decoded = decode_oauth2_state(encoded)
+
+    assert "code_verifier" not in decoded
+    assert decoded["database_id"] == 1
+    assert decoded["user_id"] == 2

Reply via email to