This is an automated email from the ASF dual-hosted git repository. fanng pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/gravitino.git
The following commit(s) were added to refs/heads/main by this push: new 94ee0a356 [#5895] support ADLSToken/AzureAccountKey credential for python client (#5940) 94ee0a356 is described below commit 94ee0a35634aad0ae9d22920a62941d9113a56e7 Author: JUN <oren....@gmail.com> AuthorDate: Tue Dec 24 09:26:24 2024 +0800 [#5895] support ADLSToken/AzureAccountKey credential for python client (#5940) ### What changes were proposed in this pull request? Support ADLS credential for python client ### Why are the changes needed? These changes are necessary to support authentication using ADLS credentials and to allow the CredentialFactory to generate ADLS credentials correctly. It ensures proper functionality and integration. Fix: #5895 ### Does this PR introduce _any_ user-facing change? No ### How was this patch tested? Unit tests. --- ...oken_credential.py => adls_token_credential.py} | 45 ++++++--- ...edential.py => azure_account_key_credential.py} | 44 ++++----- .../api/credential/gcs_token_credential.py | 2 +- .../api/credential/oss_secret_key_credential.py | 16 +-- .../api/credential/oss_token_credential.py | 22 ++--- .../api/credential/s3_secret_key_credential.py | 16 +-- .../api/credential/s3_token_credential.py | 22 ++--- .../gravitino/utils/credential_factory.py | 39 ++++++-- .../tests/unittests/test_credential_factory.py | 110 +++++++++++++++++---- 9 files changed, 209 insertions(+), 107 deletions(-) diff --git a/clients/client-python/gravitino/api/credential/gcs_token_credential.py b/clients/client-python/gravitino/api/credential/adls_token_credential.py similarity index 60% copy from clients/client-python/gravitino/api/credential/gcs_token_credential.py copy to clients/client-python/gravitino/api/credential/adls_token_credential.py index 1362383f0..40ad0eebb 100644 --- a/clients/client-python/gravitino/api/credential/gcs_token_credential.py +++ b/clients/client-python/gravitino/api/credential/adls_token_credential.py @@ -22,23 +22,27 @@ from gravitino.api.credential.credential import Credential from gravitino.utils.precondition import Precondition -class GCSTokenCredential(Credential, ABC): - """Represents the GCS token credential.""" +class ADLSTokenCredential(Credential, ABC): + """Represents ADLS token credential.""" - GCS_TOKEN_CREDENTIAL_TYPE: str = "gcs-token" - _GCS_TOKEN_NAME: str = "token" - - _expire_time_in_ms: int = 0 + ADLS_SAS_TOKEN_CREDENTIAL_TYPE: str = "adls-sas-token" + ADLS_DOMAIN: str = "dfs.core.windows.net" + _STORAGE_ACCOUNT_NAME: str = "azure-storage-account-name" + _SAS_TOKEN: str = "adls-sas-token" def __init__(self, credential_info: Dict[str, str], expire_time_in_ms: int): - self._token = credential_info[self._GCS_TOKEN_NAME] + self._account_name = credential_info.get(self._STORAGE_ACCOUNT_NAME, None) + self._sas_token = credential_info.get(self._SAS_TOKEN, None) self._expire_time_in_ms = expire_time_in_ms Precondition.check_string_not_empty( - self._token, "GCS token should not be empty" + self._account_name, "The ADLS account name should not be empty." + ) + Precondition.check_string_not_empty( + self._sas_token, "The ADLS SAS token should not be empty." ) Precondition.check_argument( self._expire_time_in_ms > 0, - "The expiration time of GCS token credential should be greater than 0", + "The expiration time of ADLS token credential should be greater than 0", ) def credential_type(self) -> str: @@ -47,7 +51,7 @@ class GCSTokenCredential(Credential, ABC): Returns: the type of the credential. """ - return self.GCS_TOKEN_CREDENTIAL_TYPE + return self.ADLS_SAS_TOKEN_CREDENTIAL_TYPE def expire_time_in_ms(self) -> int: """Returns the expiration time of the credential in milliseconds since @@ -64,12 +68,23 @@ class GCSTokenCredential(Credential, ABC): Returns: The credential information. """ - return {self._GCS_TOKEN_NAME: self._token} + return { + self._STORAGE_ACCOUNT_NAME: self._account_name, + self._SAS_TOKEN: self._sas_token, + } + + def account_name(self) -> str: + """The ADLS account name. + + Returns: + The ADLS account name. + """ + return self._account_name - def token(self) -> str: - """The GCS token. + def sas_token(self) -> str: + """The ADLS sas token. Returns: - The GCS token. + The ADLS sas token. """ - return self._token + return self._sas_token diff --git a/clients/client-python/gravitino/api/credential/oss_secret_key_credential.py b/clients/client-python/gravitino/api/credential/azure_account_key_credential.py similarity index 60% copy from clients/client-python/gravitino/api/credential/oss_secret_key_credential.py copy to clients/client-python/gravitino/api/credential/azure_account_key_credential.py index 919a3782e..aa60e3015 100644 --- a/clients/client-python/gravitino/api/credential/oss_secret_key_credential.py +++ b/clients/client-python/gravitino/api/credential/azure_account_key_credential.py @@ -22,27 +22,25 @@ from gravitino.api.credential.credential import Credential from gravitino.utils.precondition import Precondition -class OSSSecretKeyCredential(Credential, ABC): - """Represents OSS secret key credential.""" +class AzureAccountKeyCredential(Credential, ABC): + """Represents Azure account key credential.""" - OSS_SECRET_KEY_CREDENTIAL_TYPE: str = "oss-secret-key" - _GRAVITINO_OSS_STATIC_ACCESS_KEY_ID: str = "oss-access-key-id" - _GRAVITINO_OSS_STATIC_SECRET_ACCESS_KEY: str = "oss-secret-access-key" + AZURE_ACCOUNT_KEY_CREDENTIAL_TYPE: str = "azure-account-key" + _STORAGE_ACCOUNT_NAME: str = "azure-storage-account-name" + _STORAGE_ACCOUNT_KEY: str = "azure-storage-account-key" def __init__(self, credential_info: Dict[str, str], expire_time_in_ms: int): - self._access_key_id = credential_info[self._GRAVITINO_OSS_STATIC_ACCESS_KEY_ID] - self._secret_access_key = credential_info[ - self._GRAVITINO_OSS_STATIC_SECRET_ACCESS_KEY - ] + self._account_name = credential_info.get(self._STORAGE_ACCOUNT_NAME, None) + self._account_key = credential_info.get(self._STORAGE_ACCOUNT_KEY, None) Precondition.check_string_not_empty( - self._access_key_id, "The OSS access key ID should not be empty" + self._account_name, "The Azure account name should not be empty" ) Precondition.check_string_not_empty( - self._secret_access_key, "The OSS secret access key should not be empty" + self._account_key, "The Azure account key should not be empty" ) Precondition.check_argument( expire_time_in_ms == 0, - "The expiration time of OSS secret key credential should be 0", + "The expiration time of Azure account key credential should be 0", ) def credential_type(self) -> str: @@ -51,7 +49,7 @@ class OSSSecretKeyCredential(Credential, ABC): Returns: The type of the credential. """ - return self.OSS_SECRET_KEY_CREDENTIAL_TYPE + return self.AZURE_ACCOUNT_KEY_CREDENTIAL_TYPE def expire_time_in_ms(self) -> int: """Returns the expiration time of the credential in milliseconds since @@ -69,22 +67,22 @@ class OSSSecretKeyCredential(Credential, ABC): The credential information. """ return { - self._GRAVITINO_OSS_STATIC_SECRET_ACCESS_KEY: self._secret_access_key, - self._GRAVITINO_OSS_STATIC_ACCESS_KEY_ID: self._access_key_id, + self._STORAGE_ACCOUNT_NAME: self._account_name, + self._STORAGE_ACCOUNT_KEY: self._account_key, } - def access_key_id(self) -> str: - """The OSS access key ID. + def account_name(self) -> str: + """The Azure account name. Returns: - The OSS access key ID. + The Azure account name. """ - return self._access_key_id + return self._account_name - def secret_access_key(self) -> str: - """The OSS secret access key. + def account_key(self) -> str: + """The Azure account key. Returns: - The OSS secret access key. + The Azure account key. """ - return self._secret_access_key + return self._account_key diff --git a/clients/client-python/gravitino/api/credential/gcs_token_credential.py b/clients/client-python/gravitino/api/credential/gcs_token_credential.py index 1362383f0..0221ac07c 100644 --- a/clients/client-python/gravitino/api/credential/gcs_token_credential.py +++ b/clients/client-python/gravitino/api/credential/gcs_token_credential.py @@ -31,7 +31,7 @@ class GCSTokenCredential(Credential, ABC): _expire_time_in_ms: int = 0 def __init__(self, credential_info: Dict[str, str], expire_time_in_ms: int): - self._token = credential_info[self._GCS_TOKEN_NAME] + self._token = credential_info.get(self._GCS_TOKEN_NAME, None) self._expire_time_in_ms = expire_time_in_ms Precondition.check_string_not_empty( self._token, "GCS token should not be empty" diff --git a/clients/client-python/gravitino/api/credential/oss_secret_key_credential.py b/clients/client-python/gravitino/api/credential/oss_secret_key_credential.py index 919a3782e..69a964649 100644 --- a/clients/client-python/gravitino/api/credential/oss_secret_key_credential.py +++ b/clients/client-python/gravitino/api/credential/oss_secret_key_credential.py @@ -26,14 +26,14 @@ class OSSSecretKeyCredential(Credential, ABC): """Represents OSS secret key credential.""" OSS_SECRET_KEY_CREDENTIAL_TYPE: str = "oss-secret-key" - _GRAVITINO_OSS_STATIC_ACCESS_KEY_ID: str = "oss-access-key-id" - _GRAVITINO_OSS_STATIC_SECRET_ACCESS_KEY: str = "oss-secret-access-key" + _STATIC_ACCESS_KEY_ID: str = "oss-access-key-id" + _STATIC_SECRET_ACCESS_KEY: str = "oss-secret-access-key" def __init__(self, credential_info: Dict[str, str], expire_time_in_ms: int): - self._access_key_id = credential_info[self._GRAVITINO_OSS_STATIC_ACCESS_KEY_ID] - self._secret_access_key = credential_info[ - self._GRAVITINO_OSS_STATIC_SECRET_ACCESS_KEY - ] + self._access_key_id = credential_info.get(self._STATIC_ACCESS_KEY_ID, None) + self._secret_access_key = credential_info.get( + self._STATIC_SECRET_ACCESS_KEY, None + ) Precondition.check_string_not_empty( self._access_key_id, "The OSS access key ID should not be empty" ) @@ -69,8 +69,8 @@ class OSSSecretKeyCredential(Credential, ABC): The credential information. """ return { - self._GRAVITINO_OSS_STATIC_SECRET_ACCESS_KEY: self._secret_access_key, - self._GRAVITINO_OSS_STATIC_ACCESS_KEY_ID: self._access_key_id, + self._STATIC_ACCESS_KEY_ID: self._access_key_id, + self._STATIC_SECRET_ACCESS_KEY: self._secret_access_key, } def access_key_id(self) -> str: diff --git a/clients/client-python/gravitino/api/credential/oss_token_credential.py b/clients/client-python/gravitino/api/credential/oss_token_credential.py index 70dad14a1..d217ad8c8 100644 --- a/clients/client-python/gravitino/api/credential/oss_token_credential.py +++ b/clients/client-python/gravitino/api/credential/oss_token_credential.py @@ -26,16 +26,16 @@ class OSSTokenCredential(Credential, ABC): """Represents OSS token credential.""" OSS_TOKEN_CREDENTIAL_TYPE: str = "oss-token" - _GRAVITINO_OSS_SESSION_ACCESS_KEY_ID: str = "oss-access-key-id" - _GRAVITINO_OSS_SESSION_SECRET_ACCESS_KEY: str = "oss-secret-access-key" - _GRAVITINO_OSS_TOKEN: str = "oss-security-token" + _STATIC_ACCESS_KEY_ID: str = "oss-access-key-id" + _STATIC_SECRET_ACCESS_KEY: str = "oss-secret-access-key" + _OSS_TOKEN: str = "oss-security-token" def __init__(self, credential_info: Dict[str, str], expire_time_in_ms: int): - self._access_key_id = credential_info[self._GRAVITINO_OSS_SESSION_ACCESS_KEY_ID] - self._secret_access_key = credential_info[ - self._GRAVITINO_OSS_SESSION_SECRET_ACCESS_KEY - ] - self._security_token = credential_info[self._GRAVITINO_OSS_TOKEN] + self._access_key_id = credential_info.get(self._STATIC_ACCESS_KEY_ID, None) + self._secret_access_key = credential_info.get( + self._STATIC_SECRET_ACCESS_KEY, None + ) + self._security_token = credential_info.get(self._OSS_TOKEN, None) self._expire_time_in_ms = expire_time_in_ms Precondition.check_string_not_empty( self._access_key_id, "The OSS access key ID should not be empty" @@ -75,9 +75,9 @@ class OSSTokenCredential(Credential, ABC): The credential information. """ return { - self._GRAVITINO_OSS_TOKEN: self._security_token, - self._GRAVITINO_OSS_SESSION_ACCESS_KEY_ID: self._access_key_id, - self._GRAVITINO_OSS_SESSION_SECRET_ACCESS_KEY: self._secret_access_key, + self._STATIC_ACCESS_KEY_ID: self._access_key_id, + self._STATIC_SECRET_ACCESS_KEY: self._secret_access_key, + self._OSS_TOKEN: self._security_token, } def access_key_id(self) -> str: diff --git a/clients/client-python/gravitino/api/credential/s3_secret_key_credential.py b/clients/client-python/gravitino/api/credential/s3_secret_key_credential.py index 735c41e2e..05c221fe2 100644 --- a/clients/client-python/gravitino/api/credential/s3_secret_key_credential.py +++ b/clients/client-python/gravitino/api/credential/s3_secret_key_credential.py @@ -26,14 +26,14 @@ class S3SecretKeyCredential(Credential, ABC): """Represents S3 secret key credential.""" S3_SECRET_KEY_CREDENTIAL_TYPE: str = "s3-secret-key" - _GRAVITINO_S3_STATIC_ACCESS_KEY_ID: str = "s3-access-key-id" - _GRAVITINO_S3_STATIC_SECRET_ACCESS_KEY: str = "s3-secret-access-key" + _STATIC_ACCESS_KEY_ID: str = "s3-access-key-id" + _STATIC_SECRET_ACCESS_KEY: str = "s3-secret-access-key" def __init__(self, credential_info: Dict[str, str], expire_time: int): - self._access_key_id = credential_info[self._GRAVITINO_S3_STATIC_ACCESS_KEY_ID] - self._secret_access_key = credential_info[ - self._GRAVITINO_S3_STATIC_SECRET_ACCESS_KEY - ] + self._access_key_id = credential_info.get(self._STATIC_ACCESS_KEY_ID, None) + self._secret_access_key = credential_info.get( + self._STATIC_SECRET_ACCESS_KEY, None + ) Precondition.check_string_not_empty( self._access_key_id, "S3 access key id should not be empty" ) @@ -70,8 +70,8 @@ class S3SecretKeyCredential(Credential, ABC): The credential information. """ return { - self._GRAVITINO_S3_STATIC_SECRET_ACCESS_KEY: self._secret_access_key, - self._GRAVITINO_S3_STATIC_ACCESS_KEY_ID: self._access_key_id, + self._STATIC_ACCESS_KEY_ID: self._access_key_id, + self._STATIC_SECRET_ACCESS_KEY: self._secret_access_key, } def access_key_id(self) -> str: diff --git a/clients/client-python/gravitino/api/credential/s3_token_credential.py b/clients/client-python/gravitino/api/credential/s3_token_credential.py index c72d9f02a..d95919f66 100644 --- a/clients/client-python/gravitino/api/credential/s3_token_credential.py +++ b/clients/client-python/gravitino/api/credential/s3_token_credential.py @@ -26,9 +26,9 @@ class S3TokenCredential(Credential, ABC): """Represents the S3 token credential.""" S3_TOKEN_CREDENTIAL_TYPE: str = "s3-token" - _GRAVITINO_S3_SESSION_ACCESS_KEY_ID: str = "s3-access-key-id" - _GRAVITINO_S3_SESSION_SECRET_ACCESS_KEY: str = "s3-secret-access-key" - _GRAVITINO_S3_TOKEN: str = "s3-session-token" + _SESSION_ACCESS_KEY_ID: str = "s3-access-key-id" + _SESSION_SECRET_ACCESS_KEY: str = "s3-secret-access-key" + _SESSION_TOKEN: str = "s3-session-token" _expire_time_in_ms: int = 0 _access_key_id: str = None @@ -36,11 +36,11 @@ class S3TokenCredential(Credential, ABC): _session_token: str = None def __init__(self, credential_info: Dict[str, str], expire_time_in_ms: int): - self._access_key_id = credential_info[self._GRAVITINO_S3_SESSION_ACCESS_KEY_ID] - self._secret_access_key = credential_info[ - self._GRAVITINO_S3_SESSION_SECRET_ACCESS_KEY - ] - self._session_token = credential_info[self._GRAVITINO_S3_TOKEN] + self._access_key_id = credential_info.get(self._SESSION_ACCESS_KEY_ID, None) + self._secret_access_key = credential_info.get( + self._SESSION_SECRET_ACCESS_KEY, None + ) + self._session_token = credential_info.get(self._SESSION_TOKEN, None) self._expire_time_in_ms = expire_time_in_ms Precondition.check_string_not_empty( self._access_key_id, "The S3 access key ID should not be empty" @@ -80,9 +80,9 @@ class S3TokenCredential(Credential, ABC): The credential information. """ return { - self._GRAVITINO_S3_TOKEN: self._session_token, - self._GRAVITINO_S3_SESSION_ACCESS_KEY_ID: self._access_key_id, - self._GRAVITINO_S3_SESSION_SECRET_ACCESS_KEY: self._secret_access_key, + self._SESSION_ACCESS_KEY_ID: self._access_key_id, + self._SESSION_SECRET_ACCESS_KEY: self._secret_access_key, + self._SESSION_TOKEN: self._session_token, } def access_key_id(self) -> str: diff --git a/clients/client-python/gravitino/utils/credential_factory.py b/clients/client-python/gravitino/utils/credential_factory.py index 7a584caa3..32d7465b8 100644 --- a/clients/client-python/gravitino/utils/credential_factory.py +++ b/clients/client-python/gravitino/utils/credential_factory.py @@ -16,12 +16,17 @@ # under the License. from typing import Dict + from gravitino.api.credential.credential import Credential from gravitino.api.credential.gcs_token_credential import GCSTokenCredential from gravitino.api.credential.oss_token_credential import OSSTokenCredential from gravitino.api.credential.s3_secret_key_credential import S3SecretKeyCredential from gravitino.api.credential.s3_token_credential import S3TokenCredential from gravitino.api.credential.oss_secret_key_credential import OSSSecretKeyCredential +from gravitino.api.credential.adls_token_credential import ADLSTokenCredential +from gravitino.api.credential.azure_account_key_credential import ( + AzureAccountKeyCredential, +) class CredentialFactory: @@ -29,14 +34,28 @@ class CredentialFactory: def create( credential_type: str, credential_info: Dict[str, str], expire_time_in_ms: int ) -> Credential: + credential = None + if credential_type == S3TokenCredential.S3_TOKEN_CREDENTIAL_TYPE: - return S3TokenCredential(credential_info, expire_time_in_ms) - if credential_type == S3SecretKeyCredential.S3_SECRET_KEY_CREDENTIAL_TYPE: - return S3SecretKeyCredential(credential_info, expire_time_in_ms) - if credential_type == GCSTokenCredential.GCS_TOKEN_CREDENTIAL_TYPE: - return GCSTokenCredential(credential_info, expire_time_in_ms) - if credential_type == OSSTokenCredential.OSS_TOKEN_CREDENTIAL_TYPE: - return OSSTokenCredential(credential_info, expire_time_in_ms) - if credential_type == OSSSecretKeyCredential.OSS_SECRET_KEY_CREDENTIAL_TYPE: - return OSSSecretKeyCredential(credential_info, expire_time_in_ms) - raise NotImplementedError(f"Credential type {credential_type} is not supported") + credential = S3TokenCredential(credential_info, expire_time_in_ms) + elif credential_type == S3SecretKeyCredential.S3_SECRET_KEY_CREDENTIAL_TYPE: + credential = S3SecretKeyCredential(credential_info, expire_time_in_ms) + elif credential_type == GCSTokenCredential.GCS_TOKEN_CREDENTIAL_TYPE: + credential = GCSTokenCredential(credential_info, expire_time_in_ms) + elif credential_type == OSSTokenCredential.OSS_TOKEN_CREDENTIAL_TYPE: + credential = OSSTokenCredential(credential_info, expire_time_in_ms) + elif credential_type == OSSSecretKeyCredential.OSS_SECRET_KEY_CREDENTIAL_TYPE: + credential = OSSSecretKeyCredential(credential_info, expire_time_in_ms) + elif credential_type == ADLSTokenCredential.ADLS_SAS_TOKEN_CREDENTIAL_TYPE: + credential = ADLSTokenCredential(credential_info, expire_time_in_ms) + elif ( + credential_type + == AzureAccountKeyCredential.AZURE_ACCOUNT_KEY_CREDENTIAL_TYPE + ): + credential = AzureAccountKeyCredential(credential_info, expire_time_in_ms) + else: + raise NotImplementedError( + f"Credential type {credential_type} is not supported" + ) + + return credential diff --git a/clients/client-python/tests/unittests/test_credential_factory.py b/clients/client-python/tests/unittests/test_credential_factory.py index 94fd02d1d..4c4a91495 100644 --- a/clients/client-python/tests/unittests/test_credential_factory.py +++ b/clients/client-python/tests/unittests/test_credential_factory.py @@ -25,15 +25,19 @@ from gravitino.api.credential.s3_secret_key_credential import S3SecretKeyCredent from gravitino.api.credential.s3_token_credential import S3TokenCredential from gravitino.utils.credential_factory import CredentialFactory from gravitino.api.credential.oss_secret_key_credential import OSSSecretKeyCredential +from gravitino.api.credential.adls_token_credential import ADLSTokenCredential +from gravitino.api.credential.azure_account_key_credential import ( + AzureAccountKeyCredential, +) class TestCredentialFactory(unittest.TestCase): def test_s3_token_credential(self): s3_credential_info = { - S3TokenCredential._GRAVITINO_S3_SESSION_ACCESS_KEY_ID: "access_key", - S3TokenCredential._GRAVITINO_S3_SESSION_SECRET_ACCESS_KEY: "secret_key", - S3TokenCredential._GRAVITINO_S3_TOKEN: "session_token", + S3TokenCredential._SESSION_ACCESS_KEY_ID: "access_key", + S3TokenCredential._SESSION_SECRET_ACCESS_KEY: "secret_key", + S3TokenCredential._SESSION_TOKEN: "session_token", } s3_credential = S3TokenCredential(s3_credential_info, 1000) credential_info = s3_credential.credential_info() @@ -42,6 +46,12 @@ class TestCredentialFactory(unittest.TestCase): check_credential = CredentialFactory.create( s3_credential.S3_TOKEN_CREDENTIAL_TYPE, credential_info, expire_time ) + self.assertEqual( + S3TokenCredential.S3_TOKEN_CREDENTIAL_TYPE, + check_credential.credential_type(), + ) + + self.assertIsInstance(check_credential, S3TokenCredential) self.assertEqual("access_key", check_credential.access_key_id()) self.assertEqual("secret_key", check_credential.secret_access_key()) self.assertEqual("session_token", check_credential.session_token()) @@ -49,8 +59,8 @@ class TestCredentialFactory(unittest.TestCase): def test_s3_secret_key_credential(self): s3_credential_info = { - S3SecretKeyCredential._GRAVITINO_S3_STATIC_ACCESS_KEY_ID: "access_key", - S3SecretKeyCredential._GRAVITINO_S3_STATIC_SECRET_ACCESS_KEY: "secret_key", + S3SecretKeyCredential._STATIC_ACCESS_KEY_ID: "access_key", + S3SecretKeyCredential._STATIC_SECRET_ACCESS_KEY: "secret_key", } s3_credential = S3SecretKeyCredential(s3_credential_info, 0) credential_info = s3_credential.credential_info() @@ -59,43 +69,53 @@ class TestCredentialFactory(unittest.TestCase): check_credential = CredentialFactory.create( s3_credential.S3_SECRET_KEY_CREDENTIAL_TYPE, credential_info, expire_time ) + self.assertEqual( + S3SecretKeyCredential.S3_SECRET_KEY_CREDENTIAL_TYPE, + check_credential.credential_type(), + ) + + self.assertIsInstance(check_credential, S3SecretKeyCredential) self.assertEqual("access_key", check_credential.access_key_id()) self.assertEqual("secret_key", check_credential.secret_access_key()) self.assertEqual(0, check_credential.expire_time_in_ms()) def test_gcs_token_credential(self): - credential_info = {GCSTokenCredential._GCS_TOKEN_NAME: "token"} - credential = GCSTokenCredential(credential_info, 1000) - credential_info = credential.credential_info() - expire_time = credential.expire_time_in_ms() + gcs_credential_info = {GCSTokenCredential._GCS_TOKEN_NAME: "token"} + gcs_credential = GCSTokenCredential(gcs_credential_info, 1000) + credential_info = gcs_credential.credential_info() + expire_time = gcs_credential.expire_time_in_ms() check_credential = CredentialFactory.create( - credential.credential_type(), credential_info, expire_time + gcs_credential.credential_type(), credential_info, expire_time ) self.assertEqual( GCSTokenCredential.GCS_TOKEN_CREDENTIAL_TYPE, check_credential.credential_type(), ) + + self.assertIsInstance(check_credential, GCSTokenCredential) self.assertEqual("token", check_credential.token()) self.assertEqual(1000, check_credential.expire_time_in_ms()) def test_oss_token_credential(self): - credential_info = { - OSSTokenCredential._GRAVITINO_OSS_TOKEN: "token", - OSSTokenCredential._GRAVITINO_OSS_SESSION_ACCESS_KEY_ID: "access_id", - OSSTokenCredential._GRAVITINO_OSS_SESSION_SECRET_ACCESS_KEY: "secret_key", + oss_credential_info = { + OSSTokenCredential._STATIC_ACCESS_KEY_ID: "access_id", + OSSTokenCredential._STATIC_SECRET_ACCESS_KEY: "secret_key", + OSSTokenCredential._OSS_TOKEN: "token", } - credential = OSSTokenCredential(credential_info, 1000) - credential_info = credential.credential_info() - expire_time = credential.expire_time_in_ms() + oss_credential = OSSTokenCredential(oss_credential_info, 1000) + credential_info = oss_credential.credential_info() + expire_time = oss_credential.expire_time_in_ms() check_credential = CredentialFactory.create( - credential.credential_type(), credential_info, expire_time + oss_credential.credential_type(), credential_info, expire_time ) self.assertEqual( OSSTokenCredential.OSS_TOKEN_CREDENTIAL_TYPE, check_credential.credential_type(), ) + + self.assertIsInstance(check_credential, OSSTokenCredential) self.assertEqual("token", check_credential.security_token()) self.assertEqual("access_id", check_credential.access_key_id()) self.assertEqual("secret_key", check_credential.secret_access_key()) @@ -103,8 +123,8 @@ class TestCredentialFactory(unittest.TestCase): def test_oss_secret_key_credential(self): oss_credential_info = { - OSSSecretKeyCredential._GRAVITINO_OSS_STATIC_ACCESS_KEY_ID: "access_key", - OSSSecretKeyCredential._GRAVITINO_OSS_STATIC_SECRET_ACCESS_KEY: "secret_key", + OSSSecretKeyCredential._STATIC_ACCESS_KEY_ID: "access_key", + OSSSecretKeyCredential._STATIC_SECRET_ACCESS_KEY: "secret_key", } oss_credential = OSSSecretKeyCredential(oss_credential_info, 0) credential_info = oss_credential.credential_info() @@ -113,6 +133,56 @@ class TestCredentialFactory(unittest.TestCase): check_credential = CredentialFactory.create( oss_credential.OSS_SECRET_KEY_CREDENTIAL_TYPE, credential_info, expire_time ) + self.assertEqual( + OSSSecretKeyCredential.OSS_SECRET_KEY_CREDENTIAL_TYPE, + check_credential.credential_type(), + ) + + self.assertIsInstance(check_credential, OSSSecretKeyCredential) self.assertEqual("access_key", check_credential.access_key_id()) self.assertEqual("secret_key", check_credential.secret_access_key()) self.assertEqual(0, check_credential.expire_time_in_ms()) + + def test_adls_token_credential(self): + adls_credential_info = { + ADLSTokenCredential._STORAGE_ACCOUNT_NAME: "account_name", + ADLSTokenCredential._SAS_TOKEN: "sas_token", + } + adls_credential = ADLSTokenCredential(adls_credential_info, 1000) + credential_info = adls_credential.credential_info() + expire_time = adls_credential.expire_time_in_ms() + + check_credential = CredentialFactory.create( + adls_credential.credential_type(), credential_info, expire_time + ) + self.assertEqual( + ADLSTokenCredential.ADLS_SAS_TOKEN_CREDENTIAL_TYPE, + check_credential.credential_type(), + ) + + self.assertIsInstance(check_credential, ADLSTokenCredential) + self.assertEqual("account_name", check_credential.account_name()) + self.assertEqual("sas_token", check_credential.sas_token()) + self.assertEqual(1000, check_credential.expire_time_in_ms()) + + def test_azure_account_key_credential(self): + azure_credential_info = { + AzureAccountKeyCredential._STORAGE_ACCOUNT_NAME: "account_name", + AzureAccountKeyCredential._STORAGE_ACCOUNT_KEY: "account_key", + } + azure_credential = AzureAccountKeyCredential(azure_credential_info, 0) + credential_info = azure_credential.credential_info() + expire_time = azure_credential.expire_time_in_ms() + + check_credential = CredentialFactory.create( + azure_credential.credential_type(), credential_info, expire_time + ) + self.assertEqual( + AzureAccountKeyCredential.AZURE_ACCOUNT_KEY_CREDENTIAL_TYPE, + check_credential.credential_type(), + ) + + self.assertIsInstance(check_credential, AzureAccountKeyCredential) + self.assertEqual("account_name", check_credential.account_name()) + self.assertEqual("account_key", check_credential.account_key()) + self.assertEqual(0, check_credential.expire_time_in_ms())