This is an automated email from the ASF dual-hosted git repository.
jshao pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push:
new 213bcc9f2 [#3755] improvement(client-python): Support
OAuth2TokenProvider for Python client (#4011)
213bcc9f2 is described below
commit 213bcc9f28102a3b472a8b2d9629525e9d00d269
Author: noidname01 <[email protected]>
AuthorDate: Fri Jul 19 10:53:57 2024 +0800
[#3755] improvement(client-python): Support OAuth2TokenProvider for Python
client (#4011)
### What changes were proposed in this pull request?
* Add `OAuth2TokenProvider` and `DefaultOAuth2TokenProvider` in
`client-python`
* Some components and tests are intentionally left out, because including
them would make this PR too large; they will be added in follow-up PRs:
- [ ] Error Handling: #4173
- [ ] Integration Test: #4208
* Restructure the unit-test files. This surfaced issue #4136, which is fixed
by resetting the `GRAVITINO_USER` environment variable after the tests that set it.
### Why are the changes needed?
Fix: #3755, #4136
### Does this PR introduce _any_ user-facing change?
No
### How was this patch tested?
Added unit tests; verified with `./gradlew clients:client-python:unittest`
---------
Co-authored-by: TimWang <[email protected]>
---
.../client-python/gravitino/auth/auth_constants.py | 2 +
.../auth/default_oauth2_token_provider.py | 133 +++++++++++++++++++
.../gravitino/auth/oauth2_token_provider.py | 75 +++++++++++
.../gravitino/auth/simple_auth_provider.py | 4 +-
.../requests/oauth2_client_credential_request.py} | 15 ++-
.../dto/responses/oauth2_token_response.py | 55 ++++++++
.../client-python/gravitino/utils/http_client.py | 36 ++++--
clients/client-python/requirements-dev.txt | 3 +-
.../tests/integration/test_simple_auth_client.py | 2 +
.../unittests/auth/__init__.py} | 6 -
.../tests/unittests/auth/mock_base.py | 144 +++++++++++++++++++++
.../unittests/auth/test_oauth2_token_provider.py | 93 +++++++++++++
.../{ => auth}/test_simple_auth_provider.py | 4 +
13 files changed, 551 insertions(+), 21 deletions(-)
diff --git a/clients/client-python/gravitino/auth/auth_constants.py
b/clients/client-python/gravitino/auth/auth_constants.py
index 2494030fc..247abcaaa 100644
--- a/clients/client-python/gravitino/auth/auth_constants.py
+++ b/clients/client-python/gravitino/auth/auth_constants.py
@@ -21,4 +21,6 @@ under the License.
class AuthConstants:
HTTP_HEADER_AUTHORIZATION: str = "Authorization"
+ AUTHORIZATION_BEARER_HEADER: str = "Bearer "
+
AUTHORIZATION_BASIC_HEADER: str = "Basic "
diff --git
a/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
new file mode 100644
index 000000000..3fb730395
--- /dev/null
+++ b/clients/client-python/gravitino/auth/default_oauth2_token_provider.py
@@ -0,0 +1,133 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import time
+import json
+import base64
+from typing import Optional
+from gravitino.auth.oauth2_token_provider import OAuth2TokenProvider
+from gravitino.dto.responses.oauth2_token_response import OAuth2TokenResponse
+from gravitino.dto.requests.oauth2_client_credential_request import (
+ OAuth2ClientCredentialRequest,
+)
+from gravitino.exceptions.base import GravitinoRuntimeException
+
+CLIENT_CREDENTIALS = "client_credentials"
+CREDENTIAL_SPLITTER = ":"
+TOKEN_SPLITTER = "."
+JWT_EXPIRE = "exp"
+
+
+class DefaultOAuth2TokenProvider(OAuth2TokenProvider):
+ """This class is the default implement of OAuth2TokenProvider."""
+
+ _credential: Optional[str]
+ _scope: Optional[str]
+ _path: Optional[str]
+ _token: Optional[str]
+
+ def __init__(
+ self,
+ uri: str = None,
+ credential: str = None,
+ scope: str = None,
+ path: str = None,
+ ):
+ super().__init__(uri)
+
+ self._credential = credential
+ self._scope = scope
+ self._path = path
+
+ self.validate()
+
+ self._token = self._fetch_token()
+
+ def validate(self):
+ assert (
+ self._credential and self._credential.strip()
+ ), "OAuth2TokenProvider must set credential"
+ assert self._scope and self._scope.strip(), "OAuth2TokenProvider must
set scope"
+ assert self._path and self._path.strip(), "OAuth2TokenProvider must
set path"
+
+ def _get_access_token(self) -> Optional[str]:
+
+ expires = self._expires_at_millis()
+
+ if expires is None:
+ return None
+
+ if expires > time.time() * 1000:
+ return self._token
+
+ self._token = self._fetch_token()
+ return self._token
+
+ def _parse_credential(self):
+ assert self._credential is not None, "Invalid credential: None"
+
+ credential_info = self._credential.split(CREDENTIAL_SPLITTER,
maxsplit=1)
+ client_id = None
+ client_secret = None
+
+ if len(credential_info) == 2:
+ client_id, client_secret = credential_info
+ elif len(credential_info) == 1:
+ client_secret = credential_info[0]
+ else:
+ raise GravitinoRuntimeException(f"Invalid credential:
{self._credential}")
+
+ return client_id, client_secret
+
+ def _fetch_token(self) -> str:
+
+ client_id, client_secret = self._parse_credential()
+
+ client_credential_request = OAuth2ClientCredentialRequest(
+ grant_type=CLIENT_CREDENTIALS,
+ client_id=client_id,
+ client_secret=client_secret,
+ scope=self._scope,
+ )
+
+ resp = self._client.post_form(
+ self._path, data=client_credential_request.to_dict()
+ )
+ oauth2_resp = OAuth2TokenResponse.from_json(resp.body,
infer_missing=True)
+ oauth2_resp.validate()
+
+ return oauth2_resp.access_token()
+
+ def _expires_at_millis(self) -> int:
+ if self._token is None:
+ return None
+
+ parts = self._token.split(TOKEN_SPLITTER)
+
+ if len(parts) != 3:
+ return None
+
+ jwt = json.loads(
+ base64.b64decode(parts[1] + "=" * (-len(parts[1]) %
4)).decode("utf-8")
+ )
+
+ if JWT_EXPIRE not in jwt or not isinstance(jwt[JWT_EXPIRE], int):
+ return None
+
+ return jwt[JWT_EXPIRE] * 1000
diff --git a/clients/client-python/gravitino/auth/oauth2_token_provider.py
b/clients/client-python/gravitino/auth/oauth2_token_provider.py
new file mode 100644
index 000000000..5d243053f
--- /dev/null
+++ b/clients/client-python/gravitino/auth/oauth2_token_provider.py
@@ -0,0 +1,75 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+from abc import abstractmethod
+from typing import Optional
+
+from gravitino.utils.http_client import HTTPClient
+from gravitino.auth.auth_data_provider import AuthDataProvider
+from gravitino.auth.auth_constants import AuthConstants
+
+
+class OAuth2TokenProvider(AuthDataProvider):
+ """OAuth2TokenProvider will request the access token from the
authorization server and then provide
+ the access token for every request.
+ """
+
+ # The HTTP client used to request the access token from the authorization
server.
+ _client: HTTPClient
+
+ def __init__(self, uri: str):
+ self._client = HTTPClient(uri)
+
+ def has_token_data(self) -> bool:
+ """Judge whether AuthDataProvider can provide token data.
+
+ Returns:
+ true if the AuthDataProvider can provide token data otherwise
false.
+ """
+ return True
+
+ def get_token_data(self) -> Optional[bytes]:
+ """Acquire the data of token for authentication. The client will set
the token data as HTTP header
+ Authorization directly. So the return value should ensure token data
contain the token header
+ (eg: Bearer, Basic) if necessary.
+
+ Returns:
+ the token data is used for authentication.
+ """
+ access_token = self._get_access_token()
+
+ if access_token is None:
+ return None
+
+ return (AuthConstants.AUTHORIZATION_BEARER_HEADER +
access_token).encode(
+ "utf-8"
+ )
+
+ def close(self):
+ """Closes the OAuth2TokenProvider and releases any underlying
resources."""
+ if self._client is not None:
+ self._client.close()
+
+ @abstractmethod
+ def _get_access_token(self) -> Optional[str]:
+ """Get the access token from the authorization server."""
+
+ @abstractmethod
+ def validate(self):
+ """Validate the OAuth2TokenProvider"""
diff --git a/clients/client-python/gravitino/auth/simple_auth_provider.py
b/clients/client-python/gravitino/auth/simple_auth_provider.py
index ef013a7fe..96aae06a0 100644
--- a/clients/client-python/gravitino/auth/simple_auth_provider.py
+++ b/clients/client-python/gravitino/auth/simple_auth_provider.py
@@ -20,8 +20,8 @@ under the License.
import base64
import os
-from .auth_constants import AuthConstants
-from .auth_data_provider import AuthDataProvider
+from gravitino.auth.auth_constants import AuthConstants
+from gravitino.auth.auth_data_provider import AuthDataProvider
class SimpleAuthProvider(AuthDataProvider):
diff --git a/clients/client-python/gravitino/auth/auth_constants.py
b/clients/client-python/gravitino/dto/requests/oauth2_client_credential_request.py
similarity index 71%
copy from clients/client-python/gravitino/auth/auth_constants.py
copy to
clients/client-python/gravitino/dto/requests/oauth2_client_credential_request.py
index 2494030fc..4d4de57a4 100644
--- a/clients/client-python/gravitino/auth/auth_constants.py
+++
b/clients/client-python/gravitino/dto/requests/oauth2_client_credential_request.py
@@ -17,8 +17,17 @@ specific language governing permissions and limitations
under the License.
"""
+from typing import Optional
+from dataclasses import dataclass
-class AuthConstants:
- HTTP_HEADER_AUTHORIZATION: str = "Authorization"
- AUTHORIZATION_BASIC_HEADER: str = "Basic "
+@dataclass
+class OAuth2ClientCredentialRequest:
+
+ grant_type: str
+ client_id: Optional[str]
+ client_secret: str
+ scope: str
+
+ def to_dict(self, **kwarg):
+ return {k: v for k, v in self.__dict__.items() if v is not None}
diff --git
a/clients/client-python/gravitino/dto/responses/oauth2_token_response.py
b/clients/client-python/gravitino/dto/responses/oauth2_token_response.py
new file mode 100644
index 000000000..07869ec03
--- /dev/null
+++ b/clients/client-python/gravitino/dto/responses/oauth2_token_response.py
@@ -0,0 +1,55 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+from typing import Optional
+from dataclasses import dataclass, field
+from dataclasses_json import config
+
+from gravitino.dto.responses.base_response import BaseResponse
+from gravitino.auth.auth_constants import AuthConstants
+
+
+@dataclass
+class OAuth2TokenResponse(BaseResponse):
+
+ _access_token: str = field(metadata=config(field_name="access_token"))
+ _issue_token_type: Optional[str] = field(
+ metadata=config(field_name="issued_token_type")
+ )
+ _token_type: str = field(metadata=config(field_name="token_type"))
+ _expires_in: int = field(metadata=config(field_name="expires_in"))
+ _scope: str = field(metadata=config(field_name="scope"))
+ _refresh_token: Optional[str] =
field(metadata=config(field_name="refresh_token"))
+
+ def validate(self):
+ """Validates the response.
+
+ Raise:
+ IllegalArgumentException If the response is invalid, this
exception is thrown.
+ """
+ super().validate()
+
+ assert self._access_token is not None, "Invalid access token: None"
+ assert (
+ AuthConstants.AUTHORIZATION_BEARER_HEADER.strip().lower()
+ == self._token_type.lower()
+ ), f'Unsupported token type: {self._token_type} (must be "bearer")'
+
+ def access_token(self) -> str:
+ return self._access_token
diff --git a/clients/client-python/gravitino/utils/http_client.py
b/clients/client-python/gravitino/utils/http_client.py
index 67504f12d..89b75d641 100644
--- a/clients/client-python/gravitino/utils/http_client.py
+++ b/clients/client-python/gravitino/utils/http_client.py
@@ -78,6 +78,17 @@ class Response:
class HTTPClient:
+
+ FORMDATA_HEADER = {
+ "Content-Type": "application/x-www-form-urlencoded",
+ "Accept": "application/vnd.gravitino.v1+json",
+ }
+
+ JSON_HEADER = {
+ "Content-Type": "application/json",
+ "Accept": "application/vnd.gravitino.v1+json",
+ }
+
def __init__(
self,
host,
@@ -139,12 +150,14 @@ class HTTPClient:
return (False, err_resp)
+ # pylint: disable=too-many-locals
def _request(
self,
method,
endpoint,
params=None,
json=None,
+ data=None,
headers=None,
timeout=None,
error_handler: ErrorHandler = None,
@@ -152,17 +165,17 @@ class HTTPClient:
method = method.upper()
request_data = None
- if headers:
- self._update_headers(headers)
+ if data:
+ request_data = urlencode(data.to_dict()).encode()
+ self._update_headers(self.FORMDATA_HEADER)
else:
- headers = {
- "Content-Type": "application/json",
- "Accept": "application/vnd.gravitino.v1+json",
- }
- self._update_headers(headers)
+ if json:
+ request_data = json.to_json().encode("utf-8")
- if json:
- request_data = json.to_json().encode("utf-8")
+ self._update_headers(self.JSON_HEADER)
+
+ if headers:
+ self._update_headers(headers)
opener = build_opener()
request = Request(self._build_url(endpoint, params), data=request_data)
@@ -213,6 +226,11 @@ class HTTPClient:
"put", endpoint, json=json, error_handler=error_handler, **kwargs
)
+ def post_form(self, endpoint, data=None, error_handler=None, **kwargs):
+ return self._request(
+ "post", endpoint, data=data, error_handler=error_handler**kwargs
+ )
+
def close(self):
self._request("close", "/")
if self.auth_data_provider is not None:
diff --git a/clients/client-python/requirements-dev.txt
b/clients/client-python/requirements-dev.txt
index 06f634358..e91d966a4 100644
--- a/clients/client-python/requirements-dev.txt
+++ b/clients/client-python/requirements-dev.txt
@@ -27,4 +27,5 @@ llama-index==0.10.40
tenacity==8.3.0
cachetools==5.3.3
readerwriterlock==1.0.9
-docker==7.1.0
\ No newline at end of file
+docker==7.1.0
+pyjwt[crypto]==2.8.0
diff --git a/clients/client-python/tests/integration/test_simple_auth_client.py
b/clients/client-python/tests/integration/test_simple_auth_client.py
index a4ed77fe1..5dd8a553b 100644
--- a/clients/client-python/tests/integration/test_simple_auth_client.py
+++ b/clients/client-python/tests/integration/test_simple_auth_client.py
@@ -100,6 +100,8 @@ class TestSimpleAuthClient(IntegrationTestEnv):
)
except Exception as e:
logger.error("Clean test data failed: %s", e)
+ finally:
+ os.environ["GRAVITINO_USER"] = ""
def init_test_env(self):
self.gravitino_admin_client.create_metalake(
diff --git a/clients/client-python/gravitino/auth/auth_constants.py
b/clients/client-python/tests/unittests/auth/__init__.py
similarity index 86%
copy from clients/client-python/gravitino/auth/auth_constants.py
copy to clients/client-python/tests/unittests/auth/__init__.py
index 2494030fc..c206137f1 100644
--- a/clients/client-python/gravitino/auth/auth_constants.py
+++ b/clients/client-python/tests/unittests/auth/__init__.py
@@ -16,9 +16,3 @@ KIND, either express or implied. See the License for the
specific language governing permissions and limitations
under the License.
"""
-
-
-class AuthConstants:
- HTTP_HEADER_AUTHORIZATION: str = "Authorization"
-
- AUTHORIZATION_BASIC_HEADER: str = "Basic "
diff --git a/clients/client-python/tests/unittests/auth/mock_base.py
b/clients/client-python/tests/unittests/auth/mock_base.py
new file mode 100644
index 000000000..f7b66c6b3
--- /dev/null
+++ b/clients/client-python/tests/unittests/auth/mock_base.py
@@ -0,0 +1,144 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import time
+import json
+from dataclasses import dataclass
+from http import HTTPStatus
+
+from dataclasses_json import dataclass_json
+import jwt
+from cryptography.hazmat.primitives import serialization as
crypto_serialization
+from cryptography.hazmat.primitives.asymmetric import rsa
+from cryptography.hazmat.backends import default_backend as
crypto_default_backend
+
+
+@dataclass
+class TestResponse:
+ body: bytes
+ status_code: int
+
+
+@dataclass_json
+@dataclass
+class TestJWT:
+ sub: str
+ exp: int
+ aud: str
+
+
+def generate_private_key():
+ key = rsa.generate_private_key(
+ backend=crypto_default_backend(), public_exponent=65537, key_size=2048
+ )
+
+ private_key = key.private_bytes(
+ crypto_serialization.Encoding.PEM,
+ crypto_serialization.PrivateFormat.PKCS8,
+ crypto_serialization.NoEncryption(),
+ )
+
+ return private_key
+
+
+JWT_PRIVATE_KEY = generate_private_key()
+GENERATED_TIME = int(time.time())
+
+
+def mock_authentication_with_error_authentication_type():
+ return TestResponse(
+ body=json.dumps(
+ {
+ "code": 0,
+ "access_token": "1",
+ "issued_token_type": "2",
+ "token_type": "3",
+ "expires_in": 1,
+ "scope": "test",
+ "refresh_token": None,
+ }
+ ).encode("utf-8"),
+ status_code=HTTPStatus.OK.value,
+ )
+
+
+def mock_authentication_with_non_jwt():
+ return TestResponse(
+ body=json.dumps(
+ {
+ "code": 0,
+ "access_token": "1",
+ "issued_token_type": "2",
+ "token_type": "bearer",
+ "expires_in": 1,
+ "scope": "test",
+ "refresh_token": None,
+ }
+ ),
+ status_code=HTTPStatus.OK.value,
+ )
+
+
+def mock_jwt(sub, exp, aud):
+ return jwt.encode(
+ TestJWT(sub, exp, aud).to_dict(),
+ JWT_PRIVATE_KEY,
+ algorithm="RS256",
+ )
+
+
+def mock_old_new_jwt():
+ return [
+ mock_jwt(sub="gravitino", exp=GENERATED_TIME - 10000, aud="service1"),
+ mock_jwt(sub="gravitino", exp=GENERATED_TIME + 10000, aud="service1"),
+ ]
+
+
+def mock_authentication_with_jwt():
+ old_access_token, new_access_token = mock_old_new_jwt()
+ return [
+ TestResponse(
+ body=json.dumps(
+ {
+ "code": 0,
+ "access_token": old_access_token,
+ "issued_token_type": "2",
+ "token_type": "bearer",
+ "expires_in": 1,
+ "scope": "test",
+ "refresh_token": None,
+ }
+ ),
+ status_code=HTTPStatus.OK.value,
+ ),
+ TestResponse(
+ body=json.dumps(
+ {
+ "code": 0,
+ "access_token": new_access_token,
+ "issued_token_type": "2",
+ "token_type": "bearer",
+ "expires_in": 1,
+ "scope": "test",
+ "refresh_token": None,
+ }
+ ),
+ status_code=HTTPStatus.OK.value,
+ ),
+ ]
diff --git
a/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
new file mode 100644
index 000000000..b60efbf04
--- /dev/null
+++ b/clients/client-python/tests/unittests/auth/test_oauth2_token_provider.py
@@ -0,0 +1,93 @@
+"""
+Licensed to the Apache Software Foundation (ASF) under one
+or more contributor license agreements. See the NOTICE file
+distributed with this work for additional information
+regarding copyright ownership. The ASF licenses this file
+to you under the Apache License, Version 2.0 (the
+"License"); you may not use this file except in compliance
+with the License. You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+Unless required by applicable law or agreed to in writing,
+software distributed under the License is distributed on an
+"AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+KIND, either express or implied. See the License for the
+specific language governing permissions and limitations
+under the License.
+"""
+
+import unittest
+from unittest.mock import patch
+
+from gravitino.auth.auth_constants import AuthConstants
+from gravitino.auth.default_oauth2_token_provider import
DefaultOAuth2TokenProvider
+from tests.unittests.auth import mock_base
+
+OAUTH_PORT = 1082
+
+
+class TestOAuth2TokenProvider(unittest.TestCase):
+
+ def test_provider_init_exception(self):
+
+ with self.assertRaises(AssertionError):
+ _ = DefaultOAuth2TokenProvider(uri="test")
+
+ with self.assertRaises(AssertionError):
+ _ = DefaultOAuth2TokenProvider(uri="test", credential="xx")
+
+ with self.assertRaises(AssertionError):
+ _ = DefaultOAuth2TokenProvider(uri="test", credential="xx",
scope="test")
+
+ # TODO
+ # Error Test
+
+ @patch(
+ "gravitino.utils.http_client.HTTPClient.post_form",
+
return_value=mock_base.mock_authentication_with_error_authentication_type(),
+ )
+ def test_authentication_with_error_authentication_type(self,
*mock_methods):
+
+ with self.assertRaises(AssertionError):
+ _ = DefaultOAuth2TokenProvider(
+ uri=f"http://127.0.0.1:{OAUTH_PORT}",
+ credential="yy:xx",
+ path="oauth/token",
+ scope="test",
+ )
+
+ @patch(
+ "gravitino.utils.http_client.HTTPClient.post_form",
+ return_value=mock_base.mock_authentication_with_non_jwt(),
+ )
+ def test_authentication_with_non_jwt(self, *mock_methods):
+ token_provider = DefaultOAuth2TokenProvider(
+ uri=f"http://127.0.0.1:{OAUTH_PORT}",
+ credential="yy:xx",
+ path="oauth/token",
+ scope="test",
+ )
+
+ self.assertTrue(token_provider.has_token_data())
+ self.assertIsNone(token_provider.get_token_data())
+
+ @patch(
+ "gravitino.utils.http_client.HTTPClient.post_form",
+ side_effect=mock_base.mock_authentication_with_jwt(),
+ )
+ def test_authentication_with_jwt(self, *mock_methods):
+ old_access_token, new_access_token = mock_base.mock_old_new_jwt()
+
+ token_provider = DefaultOAuth2TokenProvider(
+ uri=f"http://127.0.0.1:{OAUTH_PORT}",
+ credential="yy:xx",
+ path="oauth/token",
+ scope="test",
+ )
+
+ self.assertNotEqual(old_access_token, new_access_token)
+ self.assertEqual(
+ token_provider.get_token_data().decode("utf-8"),
+ AuthConstants.AUTHORIZATION_BEARER_HEADER + new_access_token,
+ )
diff --git a/clients/client-python/tests/unittests/test_simple_auth_provider.py
b/clients/client-python/tests/unittests/auth/test_simple_auth_provider.py
similarity index 91%
rename from clients/client-python/tests/unittests/test_simple_auth_provider.py
rename to
clients/client-python/tests/unittests/auth/test_simple_auth_provider.py
index d8c10e467..c7e7fdc39 100644
--- a/clients/client-python/tests/unittests/test_simple_auth_provider.py
+++ b/clients/client-python/tests/unittests/auth/test_simple_auth_provider.py
@@ -40,6 +40,9 @@ class TestSimpleAuthProvider(unittest.TestCase):
).decode("utf-8")
self.assertEqual(f"{user}:dummy", token_string)
+ original_gravitino_user = (
+ os.environ["GRAVITINO_USER"] if "GRAVITINO_USER" in os.environ
else ""
+ )
os.environ["GRAVITINO_USER"] = "test_auth2"
provider: AuthDataProvider = SimpleAuthProvider()
self.assertTrue(provider.has_token_data())
@@ -50,3 +53,4 @@ class TestSimpleAuthProvider(unittest.TestCase):
token[len(AuthConstants.AUTHORIZATION_BASIC_HEADER) :]
).decode("utf-8")
self.assertEqual(f"{user}:dummy", token_string)
+ os.environ["GRAVITINO_USER"] = original_gravitino_user