This is an automated email from the ASF dual-hosted git repository. hope pushed a commit to branch release-1.4 in repository https://gitbox.apache.org/repos/asf/paimon.git
commit a413c697c42d5a4739674777d96b712b233366f0
Author: shyjsarah <[email protected]>
AuthorDate: Tue Mar 31 00:52:57 2026 -0700

[python] Fix token cache pollution in Python `RESTTokenFileIO` by aligning with the Java implementation. (#7562)

Fix token cache pollution in Python `RESTTokenFileIO` by aligning with the Java implementation.

**Problem:**
Python's `RESTTokenFileIO` had a class-level `_TOKEN_CACHE` (keyed by the table identifier string) that caused tokens to be shared across different catalog instances. When two catalogs with different AK/SK credentials operated on the same table within one process, the second catalog would reuse the first catalog's token, leading to permission errors (e.g., a token obtained with read-only AK/SK credentials being used for write operations).

**Root Cause:**
Java's `RESTTokenFileIO` has **no token cache** — each instance manages its own `token` field independently. Python added an extra class-level `_TOKEN_CACHE` dict that Java never had.

**Changes:**
- Remove the class-level `_TOKEN_CACHE`, `_TOKEN_LOCKS`, and `_TOKEN_LOCKS_LOCK` attributes and their associated methods (`_get_cached_token`, `_set_cached_token`, `_get_global_token_lock`, `_is_token_expired`).
- Simplify `try_to_refresh_token()` to use an instance-level lock with a double-checked pattern, aligned with Java's `tryToRefreshToken()`.
- Merge `should_refresh()` and `_is_token_expired()` into a single private `_should_refresh()` method (the public `should_refresh()` is removed).
- Add system-table identifier handling in `refresh_token()`: strip the system-table suffix (e.g. `$snapshots`) from the object name before requesting the token, aligned with Java.

The `_FILE_IO_CACHE` (keyed by the `RESTToken` object) is kept unchanged — it correctly isolates different credentials because different tokens produce different cache keys.
--- .../pypaimon/catalog/rest/rest_token_file_io.py | 95 ++++-------- .../pypaimon/tests/rest/rest_token_file_io_test.py | 164 ++++++++++++++++++--- 2 files changed, 177 insertions(+), 82 deletions(-) diff --git a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py index 7bf984d1f5..16b16e6972 100644 --- a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py +++ b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py @@ -27,7 +27,7 @@ from pypaimon.api.rest_util import RESTUtil from pypaimon.catalog.rest.rest_token import RESTToken from pypaimon.common.file_io import FileIO from pypaimon.filesystem.pyarrow_file_io import PyArrowFileIO -from pypaimon.common.identifier import Identifier +from pypaimon.common.identifier import Identifier, SYSTEM_TABLE_SPLITTER from pypaimon.common.options import Options from pypaimon.common.options.config import CatalogOptions, OssOptions from pypaimon.common.uri_reader import UriReaderFactory @@ -37,17 +37,13 @@ class RESTTokenFileIO(FileIO): """ A FileIO to support getting token from REST Server. 
""" - + _FILE_IO_CACHE_MAXSIZE = 1000 _FILE_IO_CACHE_TTL = 36000 # 10 hours in seconds - + _FILE_IO_CACHE: TTLCache = None _FILE_IO_CACHE_LOCK = threading.Lock() - - _TOKEN_CACHE: dict = {} - _TOKEN_LOCKS: dict = {} - _TOKEN_LOCKS_LOCK = threading.Lock() - + @classmethod def _get_file_io_cache(cls) -> TTLCache: if cls._FILE_IO_CACHE is None: @@ -58,7 +54,7 @@ class RESTTokenFileIO(FileIO): ttl=cls._FILE_IO_CACHE_TTL ) return cls._FILE_IO_CACHE - + def __init__(self, identifier: Identifier, path: str, catalog_options: Optional[Union[dict, Options]] = None): self.identifier = identifier @@ -99,26 +95,26 @@ class RESTTokenFileIO(FileIO): if self.token is None: return FileIO.get(self.path, self.catalog_options or Options({})) - + cache_key = self.token cache = self._get_file_io_cache() - + file_io = cache.get(cache_key) if file_io is not None: return file_io - + with self._FILE_IO_CACHE_LOCK: self.try_to_refresh_token() - + if self.token is None: return FileIO.get(self.path, self.catalog_options or Options({})) - + cache_key = self.token cache = self._get_file_io_cache() file_io = cache.get(cache_key) if file_io is not None: return file_io - + merged_properties = RESTUtil.merge( self.catalog_options.to_map() if self.catalog_options else {}, self.token.token @@ -128,7 +124,7 @@ class RESTTokenFileIO(FileIO): if dlf_oss_endpoint and dlf_oss_endpoint.strip(): merged_properties[OssOptions.OSS_ENDPOINT.key()] = dlf_oss_endpoint merged_options = Options(merged_properties) - + file_io = PyArrowFileIO(self.path, merged_options) cache[cache_key] = file_io return file_io @@ -198,7 +194,7 @@ class RESTTokenFileIO(FileIO): if self._uri_reader_factory_cache is None: catalog_options = self.catalog_options or Options({}) self._uri_reader_factory_cache = UriReaderFactory(catalog_options) - + return self._uri_reader_factory_cache @property @@ -206,66 +202,35 @@ class RESTTokenFileIO(FileIO): return self.file_io().filesystem def try_to_refresh_token(self): - identifier_str = 
str(self.identifier) - - if self.token is not None and not self._is_token_expired(self.token): - return - - cached_token = self._get_cached_token(identifier_str) - if cached_token and not self._is_token_expired(cached_token): - self.token = cached_token - return - - global_lock = self._get_global_token_lock(identifier_str) - - with global_lock: - cached_token = self._get_cached_token(identifier_str) - if cached_token and not self._is_token_expired(cached_token): - self.token = cached_token - return - - token_to_check = cached_token if cached_token else self.token - if token_to_check is None or self._is_token_expired(token_to_check): - self.refresh_token() - self._set_cached_token(identifier_str, self.token) - - def _get_cached_token(self, identifier_str: str) -> Optional[RESTToken]: - with self._TOKEN_LOCKS_LOCK: - return self._TOKEN_CACHE.get(identifier_str) - - def _set_cached_token(self, identifier_str: str, token: RESTToken): - with self._TOKEN_LOCKS_LOCK: - self._TOKEN_CACHE[identifier_str] = token - - def _is_token_expired(self, token: Optional[RESTToken]) -> bool: - if token is None: - return True - current_time = int(time.time() * 1000) - return (token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS - - def _get_global_token_lock(self, identifier_str: str) -> threading.Lock: - with self._TOKEN_LOCKS_LOCK: - if identifier_str not in self._TOKEN_LOCKS: - self._TOKEN_LOCKS[identifier_str] = threading.Lock() - return self._TOKEN_LOCKS[identifier_str] - - def should_refresh(self): + if self._should_refresh(): + with self.lock: + if self._should_refresh(): + self.refresh_token() + + def _should_refresh(self): if self.token is None: return True current_time = int(time.time() * 1000) - time_until_expiry = self.token.expire_at_millis - current_time - return time_until_expiry < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS + return (self.token.expire_at_millis - current_time) < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS def 
refresh_token(self): self.log.info(f"begin refresh data token for identifier [{self.identifier}]") if self.api_instance is None: self.api_instance = RESTApi(self.properties, False) - response = self.api_instance.load_table_token(self.identifier) + table_identifier = self.identifier + if SYSTEM_TABLE_SPLITTER in self.identifier.get_object_name(): + base_table = self.identifier.get_object_name().split(SYSTEM_TABLE_SPLITTER)[0] + table_identifier = Identifier( + database=self.identifier.get_database_name(), + object=base_table, + branch=self.identifier.get_branch_name()) + + response = self.api_instance.load_table_token(table_identifier) self.log.info( f"end refresh data token for identifier [{self.identifier}] expiresAtMillis [{response.expires_at_millis}]" ) - + merged_token_dict = self._merge_token_with_catalog_options(response.token) new_token = RESTToken(merged_token_dict, response.expires_at_millis) self.token = new_token diff --git a/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py b/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py index 47ea8e6cb6..cdcd5ed36c 100644 --- a/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py +++ b/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py @@ -18,10 +18,12 @@ import os import pickle import tempfile +import time import unittest -from unittest.mock import patch +from unittest.mock import patch, MagicMock from pypaimon.catalog.rest.rest_token_file_io import RESTTokenFileIO +from pypaimon.catalog.rest.rest_token import RESTToken from pypaimon.common.identifier import Identifier from pypaimon.common.options import Options @@ -101,13 +103,13 @@ class RESTTokenFileIOTest(unittest.TestCase): target_dir = os.path.join(self.temp_dir, "target_dir") os.makedirs(target_dir) - + result = file_io.try_to_write_atomic(f"file://{target_dir}", "test content") self.assertFalse(result, "try_to_write_atomic should return False when target is a directory") - + 
self.assertTrue(os.path.isdir(target_dir)) self.assertEqual(len(os.listdir(target_dir)), 0, "No file should be created inside the directory") - + normal_file = os.path.join(self.temp_dir, "normal_file.txt") result = file_io.try_to_write_atomic(f"file://{normal_file}", "test content") self.assertTrue(result, "try_to_write_atomic should succeed for a normal file path") @@ -223,35 +225,35 @@ class RESTTokenFileIOTest(unittest.TestCase): CatalogOptions.URI.key(): "http://test-uri", "custom.key": "custom.value" }) - + catalog_options_copy = Options(original_catalog_options.to_map()) - + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): file_io = RESTTokenFileIO( self.identifier, self.warehouse_path, original_catalog_options ) - + token_dict = { OssOptions.OSS_ACCESS_KEY_ID.key(): "token-access-key", OssOptions.OSS_ACCESS_KEY_SECRET.key(): "token-secret-key", OssOptions.OSS_ENDPOINT.key(): "token-endpoint" } - + merged_token = file_io._merge_token_with_catalog_options(token_dict) - + self.assertEqual( original_catalog_options.to_map(), catalog_options_copy.to_map(), "Original catalog_options should not be modified" ) - + merged_properties = RESTUtil.merge( original_catalog_options.to_map(), merged_token ) - + self.assertIn("custom.key", merged_properties) self.assertEqual(merged_properties["custom.key"], "custom.value") self.assertIn(OssOptions.OSS_ACCESS_KEY_ID.key(), merged_properties) @@ -264,11 +266,11 @@ class RESTTokenFileIOTest(unittest.TestCase): self.warehouse_path, self.catalog_options ) - + self.assertTrue(hasattr(file_io, 'filesystem'), "RESTTokenFileIO should have filesystem property") filesystem = file_io.filesystem self.assertIsNotNone(filesystem, "filesystem should not be None") - + self.assertTrue(hasattr(filesystem, 'open_input_file'), "filesystem should support open_input_file method") @@ -279,12 +281,12 @@ class RESTTokenFileIOTest(unittest.TestCase): self.warehouse_path, self.catalog_options ) - + self.assertTrue(hasattr(file_io, 
'uri_reader_factory'), "RESTTokenFileIO should have uri_reader_factory property") uri_reader_factory = file_io.uri_reader_factory self.assertIsNotNone(uri_reader_factory, "uri_reader_factory should not be None") - + self.assertTrue(hasattr(uri_reader_factory, 'create'), "uri_reader_factory should support create method") @@ -295,15 +297,143 @@ class RESTTokenFileIOTest(unittest.TestCase): self.warehouse_path, self.catalog_options ) - + pickled = pickle.dumps(original_file_io) restored_file_io = pickle.loads(pickled) - + self.assertIsNotNone(restored_file_io.filesystem, "filesystem should work after deserialization") self.assertIsNotNone(restored_file_io.uri_reader_factory, "uri_reader_factory should work after deserialization") + def test_should_refresh_when_token_is_none(self): + """_should_refresh() returns True when token is None.""" + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): + file_io = RESTTokenFileIO( + self.identifier, self.warehouse_path, self.catalog_options) + self.assertIsNone(file_io.token) + self.assertTrue(file_io._should_refresh()) + + def test_should_refresh_when_token_not_expired(self): + """_should_refresh() returns False when token is far from expiry.""" + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): + file_io = RESTTokenFileIO( + self.identifier, self.warehouse_path, self.catalog_options) + # Token that expires 2 hours from now (well beyond the 1-hour safe margin) + future_millis = int(time.time() * 1000) + 7200_000 + file_io.token = RESTToken({'ak': 'v'}, future_millis) + self.assertFalse(file_io._should_refresh()) + + def test_should_refresh_when_token_expired(self): + """_should_refresh() returns True when token is already expired.""" + with patch.object(RESTTokenFileIO, 'try_to_refresh_token'): + file_io = RESTTokenFileIO( + self.identifier, self.warehouse_path, self.catalog_options) + # Token that expired 1 second ago + past_millis = int(time.time() * 1000) - 1000 + file_io.token = RESTToken({'ak': 'v'}, 
past_millis) + self.assertTrue(file_io._should_refresh()) + + def test_try_to_refresh_token_calls_refresh_once(self): + """try_to_refresh_token() calls refresh_token() exactly once via double-check.""" + file_io = RESTTokenFileIO( + self.identifier, self.warehouse_path, self.catalog_options) + self.assertIsNone(file_io.token) + + mock_response = MagicMock() + mock_response.token = {'ak': 'test-ak'} + mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000 + + mock_api = MagicMock() + mock_api.load_table_token.return_value = mock_response + file_io.api_instance = mock_api + + file_io.try_to_refresh_token() + + mock_api.load_table_token.assert_called_once() + self.assertIsNotNone(file_io.token) + + # Second call should NOT trigger refresh again (token is valid) + file_io.try_to_refresh_token() + mock_api.load_table_token.assert_called_once() + + def test_refresh_token_strips_system_table_suffix(self): + """refresh_token() strips $snapshots suffix before requesting token.""" + system_identifier = Identifier.create("db", "my_table$snapshots") + file_io = RESTTokenFileIO( + system_identifier, self.warehouse_path, self.catalog_options) + + mock_response = MagicMock() + mock_response.token = {'ak': 'test-ak'} + mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000 + + mock_api = MagicMock() + mock_api.load_table_token.return_value = mock_response + file_io.api_instance = mock_api + + file_io.refresh_token() + + # Verify load_table_token was called with base table identifier (no $snapshots) + called_identifier = mock_api.load_table_token.call_args[0][0] + self.assertEqual(called_identifier.get_database_name(), "db") + self.assertEqual(called_identifier.get_object_name(), "my_table") + + def test_refresh_token_keeps_normal_identifier(self): + """refresh_token() does not modify normal (non-system) identifiers.""" + normal_identifier = Identifier.create("db", "my_table") + file_io = RESTTokenFileIO( + normal_identifier, self.warehouse_path, 
self.catalog_options) + + mock_response = MagicMock() + mock_response.token = {'ak': 'test-ak'} + mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000 + + mock_api = MagicMock() + mock_api.load_table_token.return_value = mock_response + file_io.api_instance = mock_api + + file_io.refresh_token() + + called_identifier = mock_api.load_table_token.call_args[0][0] + self.assertEqual(called_identifier.get_object_name(), "my_table") + + def test_different_instances_do_not_share_token(self): + """Two instances with same identifier get independent tokens (no class-level cache).""" + same_identifier = Identifier.from_string("db.shared_table") + + file_io_a = RESTTokenFileIO( + same_identifier, self.warehouse_path, self.catalog_options) + file_io_b = RESTTokenFileIO( + same_identifier, self.warehouse_path, self.catalog_options) + + token_a = RESTToken({'ak': 'ak-A'}, int(time.time() * 1000) + 7200_000) + token_b = RESTToken({'ak': 'ak-B'}, int(time.time() * 1000) + 7200_000) + + mock_response_a = MagicMock() + mock_response_a.token = token_a.token + mock_response_a.expires_at_millis = token_a.expire_at_millis + + mock_response_b = MagicMock() + mock_response_b.token = token_b.token + mock_response_b.expires_at_millis = token_b.expire_at_millis + + mock_api_a = MagicMock() + mock_api_a.load_table_token.return_value = mock_response_a + file_io_a.api_instance = mock_api_a + + mock_api_b = MagicMock() + mock_api_b.load_table_token.return_value = mock_response_b + file_io_b.api_instance = mock_api_b + + # Refresh both + file_io_a.try_to_refresh_token() + file_io_b.try_to_refresh_token() + + # Each instance should hold its own token + self.assertEqual(file_io_a.token.token['ak'], 'ak-A') + self.assertEqual(file_io_b.token.token['ak'], 'ak-B') + self.assertIsNot(file_io_a.token, file_io_b.token) + if __name__ == '__main__': unittest.main()
