This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch support-pkce
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 20211f29b2fdbb6a44b9c19fb71991adc2153944
Author: Beto Dealmeida <[email protected]>
AuthorDate: Mon Jan 12 17:22:30 2026 -0500

    feat: support PKCE in OAuth2 flow
---
 superset/commands/database/oauth2.py          |   3 +
 superset/db_engine_specs/base.py              |  39 +++++++-
 superset/superset_typing.py                   |   4 +-
 superset/utils/oauth2.py                      |  58 +++++++++--
 tests/unit_tests/db_engine_specs/test_base.py | 120 +++++++++++++++++++++++
 tests/unit_tests/utils/oauth2_tests.py        | 135 +++++++++++++++++++++++++-
 6 files changed, 345 insertions(+), 14 deletions(-)

diff --git a/superset/commands/database/oauth2.py b/superset/commands/database/oauth2.py
index f7259077bc..71908b3caa 100644
--- a/superset/commands/database/oauth2.py
+++ b/superset/commands/database/oauth2.py
@@ -50,9 +50,12 @@ class OAuth2StoreTokenCommand(BaseCommand):
         if oauth2_config is None:
             raise OAuth2Error("No configuration found for OAuth2")
 
+        # PKCE verifier (RFC 7636) or None; get_oauth2_token overrides must accept it.
+        code_verifier = self._state.get("code_verifier")
         token_response = self._database.db_engine_spec.get_oauth2_token(
             oauth2_config,
             self._parameters["code"],
+            code_verifier=code_verifier,
         )
 
         # delete old tokens
diff --git a/superset/db_engine_specs/base.py b/superset/db_engine_specs/base.py
index 6c0cd77478..430e0bee9a 100644
--- a/superset/db_engine_specs/base.py
+++ b/superset/db_engine_specs/base.py
@@ -83,7 +83,11 @@ from superset.utils.core import ColumnSpec, GenericDataType, QuerySource
 from superset.utils.hashing import hash_from_str
 from superset.utils.json import redact_sensitive, reveal_sensitive
 from superset.utils.network import is_hostname_valid, is_port_open
-from superset.utils.oauth2 import encode_oauth2_state
+from superset.utils.oauth2 import (
+    encode_oauth2_state,
+    generate_code_challenge,
+    generate_code_verifier,
+)
 
 if TYPE_CHECKING:
     from superset.connectors.sqla.models import TableColumn
@@ -474,10 +478,17 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         tab sends a message to the original tab informing that authorization was
         successful (or not), and then closes. The original tab will automatically
         re-run the query after authorization.
+
+        PKCE (RFC 7636): a code_verifier is stored in the state and its S256
+        code_challenge is sent to the provider. NOTE(review): the state is a signed,
+        unencrypted JWT returned with the code; an interceptor can read the verifier.
         """
         tab_id = str(uuid4())
         default_redirect_uri = url_for("DatabaseRestApi.oauth2", _external=True)
 
+        # Generate PKCE code verifier (RFC 7636)
+        code_verifier = generate_code_verifier()
+
         # The state is passed to the OAuth2 provider, and sent back to Superset after
         # the user authorizes the access. The redirect endpoint in Superset can then
         # inspect the state to figure out to which user/database the access token
@@ -499,6 +510,8 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
             # UUID to the original tab, and the second tab will use it when sending the
             # message.
             "tab_id": tab_id,
+            # PKCE code verifier stored in state to be retrieved during token exchange
+            "code_verifier": code_verifier,
         }
         oauth2_config = database.get_oauth2_config()
         if oauth2_config is None:
@@ -552,17 +565,24 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         """
         Return URI for initial OAuth2 request.
 
-        Uses standard OAuth 2.0 parameters only. Subclasses can override
-        to add provider-specific parameters (e.g., Google's prompt=consent).
+        Uses standard OAuth 2.0 parameters plus PKCE (RFC 7636) parameters.
+        Subclasses can override to add provider-specific parameters
+        (e.g., Google's prompt=consent).
         """
         uri = config["authorization_request_uri"]
-        params = {
+        params: dict[str, str] = {
             "scope": config["scope"],
             "response_type": "code",
             "state": encode_oauth2_state(state),
             "redirect_uri": config["redirect_uri"],
             "client_id": config["id"],
         }
+
+        # Add PKCE parameters (RFC 7636) if code_verifier is present in state
+        if "code_verifier" in state:
+            params["code_challenge"] = generate_code_challenge(state["code_verifier"])
+            params["code_challenge_method"] = "S256"
+
         return urljoin(uri, "?" + urlencode(params))
 
     @classmethod
@@ -570,19 +590,28 @@ class BaseEngineSpec:  # pylint: disable=too-many-public-methods
         cls,
         config: OAuth2ClientConfig,
         code: str,
+        code_verifier: str | None = None,
     ) -> OAuth2TokenResponse:
         """
         Exchange authorization code for refresh/access tokens.
+
+        If code_verifier is provided (PKCE flow), it will be included in the
+        token request per RFC 7636.
         """
         timeout = app.config["DATABASE_OAUTH2_TIMEOUT"].total_seconds()
         uri = config["token_request_uri"]
-        req_body = {
+        req_body: dict[str, str] = {
             "code": code,
             "client_id": config["id"],
             "client_secret": config["secret"],
             "redirect_uri": config["redirect_uri"],
             "grant_type": "authorization_code",
         }
+
+        # Add PKCE code_verifier if present (RFC 7636)
+        if code_verifier:
+            req_body["code_verifier"] = code_verifier
+
         if config["request_content_type"] == "data":
             return requests.post(uri, data=req_body, timeout=timeout).json()
         return requests.post(uri, json=req_body, timeout=timeout).json()
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 105a28d4cf..a1d36e811f 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -354,7 +354,7 @@ class OAuth2TokenResponse(TypedDict, total=False):
     refresh_token: str
 
 
-class OAuth2State(TypedDict):
+class OAuth2State(TypedDict, total=False):
     """
     Type for the state passed during OAuth2.
     """
@@ -363,3 +363,5 @@ class OAuth2State(TypedDict):
     user_id: int
     default_redirect_uri: str
     tab_id: str
+    # Optional PKCE verifier (RFC 7636); note total=False makes every key optional
+    code_verifier: str
diff --git a/superset/utils/oauth2.py b/superset/utils/oauth2.py
index ebe1f4012e..02fba6ac0a 100644
--- a/superset/utils/oauth2.py
+++ b/superset/utils/oauth2.py
@@ -17,6 +17,9 @@
 
 from __future__ import annotations
 
+import base64
+import hashlib
+import secrets
 from contextlib import contextmanager
 from datetime import datetime, timedelta, timezone
 from typing import Any, Iterator, TYPE_CHECKING
@@ -37,6 +40,37 @@ if TYPE_CHECKING:
 
 JWT_EXPIRATION = timedelta(minutes=5)
 
+# Random bytes per PKCE code verifier; encodes to 86 chars (RFC 7636: 43-128)
+PKCE_CODE_VERIFIER_LENGTH = 64
+
+
+def generate_code_verifier() -> str:
+    """
+    Generate a PKCE code verifier (RFC 7636).
+
+    The code verifier is a high-entropy cryptographic random string using
+    unreserved characters [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~",
+    with a minimum length of 43 characters and a maximum length of 128.
+    """
+    # Generate random bytes and encode as URL-safe base64
+    random_bytes = secrets.token_bytes(PKCE_CODE_VERIFIER_LENGTH)
+    # Use URL-safe base64 encoding without padding
+    code_verifier = base64.urlsafe_b64encode(random_bytes).rstrip(b"=").decode("ascii")
+    return code_verifier
+
+
+def generate_code_challenge(code_verifier: str) -> str:
+    """
+    Generate a PKCE code challenge from a code verifier (RFC 7636).
+
+    Uses the S256 method: BASE64URL(SHA256(code_verifier))
+    """
+    # Compute SHA-256 hash of the code verifier
+    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+    # Encode as URL-safe base64 without padding
+    code_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+    return code_challenge
+
 
 @backoff.on_exception(
     backoff.expo,
@@ -119,13 +153,17 @@ def encode_oauth2_state(state: OAuth2State) -> str:
     """
     Encode the OAuth2 state.
     """
-    payload = {
+    payload: dict[str, Any] = {
         "exp": datetime.now(tz=timezone.utc) + JWT_EXPIRATION,
         "database_id": state["database_id"],
         "user_id": state["user_id"],
         "default_redirect_uri": state["default_redirect_uri"],
         "tab_id": state["tab_id"],
     }
+    # Include PKCE code_verifier if present; NOTE(review): JWT is signed, not encrypted
+    if "code_verifier" in state:
+        payload["code_verifier"] = state["code_verifier"]
+
     encoded_state = jwt.encode(
         payload=payload,
         key=app.config["SECRET_KEY"],
@@ -143,6 +181,8 @@ class OAuth2StateSchema(Schema):
     user_id = fields.Int(required=True)
     default_redirect_uri = fields.Str(required=True)
     tab_id = fields.Str(required=True)
+    # PKCE code verifier (RFC 7636) - optional for backward compatibility
+    code_verifier = fields.Str(required=False, load_default=None)
 
     # pylint: disable=unused-argument
     @post_load
@@ -151,12 +191,16 @@ class OAuth2StateSchema(Schema):
         data: dict[str, Any],
         **kwargs: Any,
     ) -> OAuth2State:
-        return OAuth2State(
-            database_id=data["database_id"],
-            user_id=data["user_id"],
-            default_redirect_uri=data["default_redirect_uri"],
-            tab_id=data["tab_id"],
-        )
+        state: OAuth2State = {
+            "database_id": data["database_id"],
+            "user_id": data["user_id"],
+            "default_redirect_uri": data["default_redirect_uri"],
+            "tab_id": data["tab_id"],
+        }
+        # Include code_verifier if present (PKCE)
+        if data.get("code_verifier"):
+            state["code_verifier"] = data["code_verifier"]
+        return state
 
     class Meta:  # pylint: disable=too-few-public-methods
         # ignore `exp`
diff --git a/tests/unit_tests/db_engine_specs/test_base.py b/tests/unit_tests/db_engine_specs/test_base.py
index bff4c93117..7a4b8f9206 100644
--- a/tests/unit_tests/db_engine_specs/test_base.py
+++ b/tests/unit_tests/db_engine_specs/test_base.py
@@ -943,3 +943,123 @@ def test_get_oauth2_authorization_uri_standard_params(mocker: MockerFixture) ->
     assert "prompt" not in query
     assert "access_type" not in query
     assert "include_granted_scopes" not in query
+
+    # Verify PKCE parameters are NOT included when code_verifier is not in state
+    assert "code_challenge" not in query
+    assert "code_challenge_method" not in query
+
+
+def test_get_oauth2_authorization_uri_with_pkce(mocker: MockerFixture) -> None:
+    """
+    Test that BaseEngineSpec.get_oauth2_authorization_uri includes PKCE parameters
+    when code_verifier is present in state (RFC 7636).
+    """
+    from urllib.parse import parse_qs, urlparse
+
+    from superset.db_engine_specs.base import BaseEngineSpec
+    from superset.superset_typing import OAuth2ClientConfig, OAuth2State
+    from superset.utils.oauth2 import generate_code_challenge, generate_code_verifier
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    code_verifier = generate_code_verifier()
+    state: OAuth2State = {
+        "database_id": 1,
+        "user_id": 1,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "1234",
+        "code_verifier": code_verifier,
+    }
+
+    url = BaseEngineSpec.get_oauth2_authorization_uri(config, state)
+    parsed = urlparse(url)
+    query = parse_qs(parsed.query)
+
+    # Verify PKCE parameters are included (RFC 7636)
+    assert "code_challenge" in query
+    assert query["code_challenge_method"][0] == "S256"
+    # Verify the code_challenge matches the expected value
+    expected_challenge = generate_code_challenge(code_verifier)
+    assert query["code_challenge"][0] == expected_challenge
+
+
+def test_get_oauth2_token_without_pkce(mocker: MockerFixture) -> None:
+    """
+    Test that BaseEngineSpec.get_oauth2_token works without PKCE code_verifier.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+    from superset.superset_typing import OAuth2ClientConfig
+
+    mocker.patch(
+        "flask.current_app.config",
+        {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
+    )
+    mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
+    mock_post.return_value.json.return_value = {
+        "access_token": "test-access-token",  # noqa: S105
+        "expires_in": 3600,
+    }
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    result = BaseEngineSpec.get_oauth2_token(config, "auth-code")
+
+    assert result["access_token"] == "test-access-token"  # noqa: S105
+    # Verify code_verifier is NOT in the request body
+    call_kwargs = mock_post.call_args
+    request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data")
+    assert "code_verifier" not in request_body
+
+
+def test_get_oauth2_token_with_pkce(mocker: MockerFixture) -> None:
+    """
+    Test BaseEngineSpec.get_oauth2_token includes code_verifier when provided.
+    """
+    from superset.db_engine_specs.base import BaseEngineSpec
+    from superset.superset_typing import OAuth2ClientConfig
+    from superset.utils.oauth2 import generate_code_verifier
+
+    mocker.patch(
+        "flask.current_app.config",
+        {"DATABASE_OAUTH2_TIMEOUT": mocker.MagicMock(total_seconds=lambda: 30)},
+    )
+    mock_post = mocker.patch("superset.db_engine_specs.base.requests.post")
+    mock_post.return_value.json.return_value = {
+        "access_token": "test-access-token",  # noqa: S105
+        "expires_in": 3600,
+    }
+
+    config: OAuth2ClientConfig = {
+        "id": "client-id",
+        "secret": "client-secret",
+        "scope": "read write",
+        "redirect_uri": "http://localhost:8088/api/v1/database/oauth2/",
+        "authorization_request_uri": "https://oauth.example.com/authorize",
+        "token_request_uri": "https://oauth.example.com/token",
+        "request_content_type": "json",
+    }
+
+    code_verifier = generate_code_verifier()
+    result = BaseEngineSpec.get_oauth2_token(config, "auth-code", code_verifier)
+
+    assert result["access_token"] == "test-access-token"  # noqa: S105
+    # Verify code_verifier IS in the request body (PKCE)
+    call_kwargs = mock_post.call_args
+    request_body = call_kwargs.kwargs.get("json") or call_kwargs.kwargs.get("data")
+    assert request_body["code_verifier"] == code_verifier
diff --git a/tests/unit_tests/utils/oauth2_tests.py b/tests/unit_tests/utils/oauth2_tests.py
index e9aa283b1a..0b14e1868a 100644
--- a/tests/unit_tests/utils/oauth2_tests.py
+++ b/tests/unit_tests/utils/oauth2_tests.py
@@ -17,12 +17,20 @@
 
 # pylint: disable=invalid-name, disallowed-name
 
+import base64
+import hashlib
 from datetime import datetime
 
 from freezegun import freeze_time
 from pytest_mock import MockerFixture
 
-from superset.utils.oauth2 import get_oauth2_access_token
+from superset.utils.oauth2 import (
+    decode_oauth2_state,
+    encode_oauth2_state,
+    generate_code_challenge,
+    generate_code_verifier,
+    get_oauth2_access_token,
+)
 
 
 def test_get_oauth2_access_token_base_no_token(mocker: MockerFixture) -> None:
@@ -93,3 +101,128 @@ def test_get_oauth2_access_token_base_no_refresh(mocker: MockerFixture) -> None:
 
     # check that token was deleted
     db.session.delete.assert_called_with(token)
+
+
+def test_generate_code_verifier_length() -> None:
+    """
+    Test that generate_code_verifier produces a string of valid length (RFC 7636).
+    """
+    code_verifier = generate_code_verifier()
+    # RFC 7636 requires 43-128 characters
+    assert 43 <= len(code_verifier) <= 128
+
+
+def test_generate_code_verifier_uniqueness() -> None:
+    """
+    Test that generate_code_verifier produces unique values.
+    """
+    verifiers = {generate_code_verifier() for _ in range(100)}
+    # All generated verifiers should be unique
+    assert len(verifiers) == 100
+
+
+def test_generate_code_verifier_valid_characters() -> None:
+    """
+    Test that generate_code_verifier only uses valid characters (RFC 7636).
+    """
+    code_verifier = generate_code_verifier()
+    # RFC 7636 allows: [A-Z] / [a-z] / [0-9] / "-" / "." / "_" / "~"
+    # URL-safe base64 uses: [A-Z] / [a-z] / [0-9] / "-" / "_"
+    valid_chars = set(
+        "ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_"
+    )
+    assert all(char in valid_chars for char in code_verifier)
+
+
+def test_generate_code_challenge_s256() -> None:
+    """
+    Test that generate_code_challenge produces correct S256 challenge.
+    """
+    # Use a known code_verifier to verify the challenge computation
+    code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
+
+    # Compute expected challenge manually
+    digest = hashlib.sha256(code_verifier.encode("ascii")).digest()
+    expected_challenge = base64.urlsafe_b64encode(digest).rstrip(b"=").decode("ascii")
+
+    code_challenge = generate_code_challenge(code_verifier)
+    assert code_challenge == expected_challenge
+
+
+def test_generate_code_challenge_rfc_example() -> None:
+    """
+    Test PKCE code challenge against RFC 7636 Appendix B example.
+
+    See: https://datatracker.ietf.org/doc/html/rfc7636#appendix-B
+    """
+    # RFC 7636 example code_verifier (Appendix B)
+    code_verifier = "dBjftJeZ4CVP-mB92K27uhbUJU1p1r_wW1gFWFOEjXk"
+    # RFC 7636 expected code_challenge for S256 method
+    expected_challenge = "E9Melhoa2OwvFrEMTJguCHaoeK1t8URWbuGJSstw-cM"
+
+    code_challenge = generate_code_challenge(code_verifier)
+    assert code_challenge == expected_challenge
+
+
+def test_encode_decode_oauth2_state_with_code_verifier(mocker: MockerFixture) -> None:
+    """
+    Test that code_verifier is preserved through encode/decode cycle.
+    """
+    from superset.superset_typing import OAuth2State
+
+    mocker.patch(
+        "flask.current_app.config",
+        {
+            "SECRET_KEY": "test-secret-key",
+            "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
+        },
+    )
+
+    code_verifier = generate_code_verifier()
+    state: OAuth2State = {
+        "database_id": 1,
+        "user_id": 2,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "test-tab-id",
+        "code_verifier": code_verifier,
+    }
+
+    with freeze_time("2024-01-01"):
+        encoded = encode_oauth2_state(state)
+        decoded = decode_oauth2_state(encoded)
+
+    assert decoded["code_verifier"] == code_verifier
+    assert decoded["database_id"] == 1
+    assert decoded["user_id"] == 2
+
+
+def test_encode_decode_oauth2_state_without_code_verifier(
+    mocker: MockerFixture,
+) -> None:
+    """
+    Test backward compatibility: state without code_verifier still works.
+    """
+    from superset.superset_typing import OAuth2State
+
+    mocker.patch(
+        "flask.current_app.config",
+        {
+            "SECRET_KEY": "test-secret-key",
+            "DATABASE_OAUTH2_JWT_ALGORITHM": "HS256",
+        },
+    )
+
+    state: OAuth2State = {
+        "database_id": 1,
+        "user_id": 2,
+        "default_redirect_uri": "http://localhost:8088/api/v1/oauth2/",
+        "tab_id": "test-tab-id",
+    }
+
+    with freeze_time("2024-01-01"):
+        encoded = encode_oauth2_state(state)
+        decoded = decode_oauth2_state(encoded)
+
+    assert "code_verifier" not in decoded
+    assert decoded["database_id"] == 1
+    assert decoded["user_id"] == 2

Reply via email to