This is an automated email from the ASF dual-hosted git repository.

hope pushed a commit to branch release-1.4
in repository https://gitbox.apache.org/repos/asf/paimon.git

commit a413c697c42d5a4739674777d96b712b233366f0
Author: shyjsarah <[email protected]>
AuthorDate: Tue Mar 31 00:52:57 2026 -0700

    [python] Fix token cache pollution in Python `RESTTokenFileIO` by aligning
    with the Java implementation. (#7562)
    
    Fix token cache pollution in Python `RESTTokenFileIO` by aligning with
    the Java implementation.
    
    **Problem:**
    Python's `RESTTokenFileIO` had a class-level `_TOKEN_CACHE` (keyed by
    table identifier string) that caused token sharing across different
    catalog instances. When two catalogs with different AK/SK credentials
    operated on the same table within one process, the second catalog would
    reuse the first catalog's token — leading to permission errors (e.g.,
    read AK/SK token used for write operations).
    
    **Root Cause:**
    Java's `RESTTokenFileIO` has **no token cache** — each instance manages
    its own `token` field independently. Python added an extra
    `_TOKEN_CACHE` class-level dict that Java never had.
    
    **Changes:**
    - Remove class-level `_TOKEN_CACHE`, `_TOKEN_LOCKS`, `_TOKEN_LOCKS_LOCK`
    and their associated methods (`_get_cached_token`, `_set_cached_token`,
    `_get_global_token_lock`, `_is_token_expired`)
    - Simplify `try_to_refresh_token()` to use instance-level lock with
    double-check pattern, aligned with Java's `tryToRefreshToken()`
    - Merge `should_refresh()` and `_is_token_expired()` into a single
    `_should_refresh()` method
    - Add system table identifier handling in `refresh_token()` (strip
    `$snapshots` suffix before requesting token), aligned with Java
    
    The `FILE_IO_CACHE` (keyed by `RESTToken` object) is kept unchanged — it
    correctly isolates different credentials since different tokens produce
    different cache keys.
---
 .../pypaimon/catalog/rest/rest_token_file_io.py    |  95 ++++--------
 .../pypaimon/tests/rest/rest_token_file_io_test.py | 164 ++++++++++++++++++---
 2 files changed, 177 insertions(+), 82 deletions(-)

diff --git a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py 
b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
index 7bf984d1f5..16b16e6972 100644
--- a/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
+++ b/paimon-python/pypaimon/catalog/rest/rest_token_file_io.py
@@ -27,7 +27,7 @@ from pypaimon.api.rest_util import RESTUtil
 from pypaimon.catalog.rest.rest_token import RESTToken
 from pypaimon.common.file_io import FileIO
 from pypaimon.filesystem.pyarrow_file_io import PyArrowFileIO
-from pypaimon.common.identifier import Identifier
+from pypaimon.common.identifier import Identifier, SYSTEM_TABLE_SPLITTER
 from pypaimon.common.options import Options
 from pypaimon.common.options.config import CatalogOptions, OssOptions
 from pypaimon.common.uri_reader import UriReaderFactory
@@ -37,17 +37,13 @@ class RESTTokenFileIO(FileIO):
     """
     A FileIO to support getting token from REST Server.
     """
-    
+
     _FILE_IO_CACHE_MAXSIZE = 1000
     _FILE_IO_CACHE_TTL = 36000  # 10 hours in seconds
-    
+
     _FILE_IO_CACHE: TTLCache = None
     _FILE_IO_CACHE_LOCK = threading.Lock()
-    
-    _TOKEN_CACHE: dict = {}
-    _TOKEN_LOCKS: dict = {}
-    _TOKEN_LOCKS_LOCK = threading.Lock()
-    
+
     @classmethod
     def _get_file_io_cache(cls) -> TTLCache:
         if cls._FILE_IO_CACHE is None:
@@ -58,7 +54,7 @@ class RESTTokenFileIO(FileIO):
                         ttl=cls._FILE_IO_CACHE_TTL
                     )
         return cls._FILE_IO_CACHE
-    
+
     def __init__(self, identifier: Identifier, path: str,
                  catalog_options: Optional[Union[dict, Options]] = None):
         self.identifier = identifier
@@ -99,26 +95,26 @@ class RESTTokenFileIO(FileIO):
 
         if self.token is None:
             return FileIO.get(self.path, self.catalog_options or Options({}))
-        
+
         cache_key = self.token
         cache = self._get_file_io_cache()
-        
+
         file_io = cache.get(cache_key)
         if file_io is not None:
             return file_io
-        
+
         with self._FILE_IO_CACHE_LOCK:
             self.try_to_refresh_token()
-            
+
             if self.token is None:
                 return FileIO.get(self.path, self.catalog_options or 
Options({}))
-            
+
             cache_key = self.token
             cache = self._get_file_io_cache()
             file_io = cache.get(cache_key)
             if file_io is not None:
                 return file_io
-            
+
             merged_properties = RESTUtil.merge(
                 self.catalog_options.to_map() if self.catalog_options else {},
                 self.token.token
@@ -128,7 +124,7 @@ class RESTTokenFileIO(FileIO):
                 if dlf_oss_endpoint and dlf_oss_endpoint.strip():
                     merged_properties[OssOptions.OSS_ENDPOINT.key()] = 
dlf_oss_endpoint
             merged_options = Options(merged_properties)
-            
+
             file_io = PyArrowFileIO(self.path, merged_options)
             cache[cache_key] = file_io
             return file_io
@@ -198,7 +194,7 @@ class RESTTokenFileIO(FileIO):
         if self._uri_reader_factory_cache is None:
             catalog_options = self.catalog_options or Options({})
             self._uri_reader_factory_cache = UriReaderFactory(catalog_options)
-        
+
         return self._uri_reader_factory_cache
 
     @property
@@ -206,66 +202,35 @@ class RESTTokenFileIO(FileIO):
         return self.file_io().filesystem
 
     def try_to_refresh_token(self):
-        identifier_str = str(self.identifier)
-        
-        if self.token is not None and not self._is_token_expired(self.token):
-            return
-        
-        cached_token = self._get_cached_token(identifier_str)
-        if cached_token and not self._is_token_expired(cached_token):
-            self.token = cached_token
-            return
-        
-        global_lock = self._get_global_token_lock(identifier_str)
-        
-        with global_lock:
-            cached_token = self._get_cached_token(identifier_str)
-            if cached_token and not self._is_token_expired(cached_token):
-                self.token = cached_token
-                return
-            
-            token_to_check = cached_token if cached_token else self.token
-            if token_to_check is None or 
self._is_token_expired(token_to_check):
-                self.refresh_token()
-                self._set_cached_token(identifier_str, self.token)
-
-    def _get_cached_token(self, identifier_str: str) -> Optional[RESTToken]:
-        with self._TOKEN_LOCKS_LOCK:
-            return self._TOKEN_CACHE.get(identifier_str)
-    
-    def _set_cached_token(self, identifier_str: str, token: RESTToken):
-        with self._TOKEN_LOCKS_LOCK:
-            self._TOKEN_CACHE[identifier_str] = token
-    
-    def _is_token_expired(self, token: Optional[RESTToken]) -> bool:
-        if token is None:
-            return True
-        current_time = int(time.time() * 1000)
-        return (token.expire_at_millis - current_time) < 
RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
-    
-    def _get_global_token_lock(self, identifier_str: str) -> threading.Lock:
-        with self._TOKEN_LOCKS_LOCK:
-            if identifier_str not in self._TOKEN_LOCKS:
-                self._TOKEN_LOCKS[identifier_str] = threading.Lock()
-            return self._TOKEN_LOCKS[identifier_str]
-
-    def should_refresh(self):
+        if self._should_refresh():
+            with self.lock:
+                if self._should_refresh():
+                    self.refresh_token()
+
+    def _should_refresh(self):
         if self.token is None:
             return True
         current_time = int(time.time() * 1000)
-        time_until_expiry = self.token.expire_at_millis - current_time
-        return time_until_expiry < RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
+        return (self.token.expire_at_millis - current_time) < 
RESTApi.TOKEN_EXPIRATION_SAFE_TIME_MILLIS
 
     def refresh_token(self):
         self.log.info(f"begin refresh data token for identifier 
[{self.identifier}]")
         if self.api_instance is None:
             self.api_instance = RESTApi(self.properties, False)
 
-        response = self.api_instance.load_table_token(self.identifier)
+        table_identifier = self.identifier
+        if SYSTEM_TABLE_SPLITTER in self.identifier.get_object_name():
+            base_table = 
self.identifier.get_object_name().split(SYSTEM_TABLE_SPLITTER)[0]
+            table_identifier = Identifier(
+                database=self.identifier.get_database_name(),
+                object=base_table,
+                branch=self.identifier.get_branch_name())
+
+        response = self.api_instance.load_table_token(table_identifier)
         self.log.info(
             f"end refresh data token for identifier [{self.identifier}] 
expiresAtMillis [{response.expires_at_millis}]"
         )
-        
+
         merged_token_dict = 
self._merge_token_with_catalog_options(response.token)
         new_token = RESTToken(merged_token_dict, response.expires_at_millis)
         self.token = new_token
diff --git a/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py 
b/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py
index 47ea8e6cb6..cdcd5ed36c 100644
--- a/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py
+++ b/paimon-python/pypaimon/tests/rest/rest_token_file_io_test.py
@@ -18,10 +18,12 @@
 import os
 import pickle
 import tempfile
+import time
 import unittest
-from unittest.mock import patch
+from unittest.mock import patch, MagicMock
 
 from pypaimon.catalog.rest.rest_token_file_io import RESTTokenFileIO
+from pypaimon.catalog.rest.rest_token import RESTToken
 
 from pypaimon.common.identifier import Identifier
 from pypaimon.common.options import Options
@@ -101,13 +103,13 @@ class RESTTokenFileIOTest(unittest.TestCase):
 
             target_dir = os.path.join(self.temp_dir, "target_dir")
             os.makedirs(target_dir)
-            
+
             result = file_io.try_to_write_atomic(f"file://{target_dir}", "test 
content")
             self.assertFalse(result, "try_to_write_atomic should return False 
when target is a directory")
-            
+
             self.assertTrue(os.path.isdir(target_dir))
             self.assertEqual(len(os.listdir(target_dir)), 0, "No file should 
be created inside the directory")
-            
+
             normal_file = os.path.join(self.temp_dir, "normal_file.txt")
             result = file_io.try_to_write_atomic(f"file://{normal_file}", 
"test content")
             self.assertTrue(result, "try_to_write_atomic should succeed for a 
normal file path")
@@ -223,35 +225,35 @@ class RESTTokenFileIOTest(unittest.TestCase):
             CatalogOptions.URI.key(): "http://test-uri",
             "custom.key": "custom.value"
         })
-        
+
         catalog_options_copy = Options(original_catalog_options.to_map())
-        
+
         with patch.object(RESTTokenFileIO, 'try_to_refresh_token'):
             file_io = RESTTokenFileIO(
                 self.identifier,
                 self.warehouse_path,
                 original_catalog_options
             )
-            
+
             token_dict = {
                 OssOptions.OSS_ACCESS_KEY_ID.key(): "token-access-key",
                 OssOptions.OSS_ACCESS_KEY_SECRET.key(): "token-secret-key",
                 OssOptions.OSS_ENDPOINT.key(): "token-endpoint"
             }
-            
+
             merged_token = 
file_io._merge_token_with_catalog_options(token_dict)
-            
+
             self.assertEqual(
                 original_catalog_options.to_map(),
                 catalog_options_copy.to_map(),
                 "Original catalog_options should not be modified"
             )
-            
+
             merged_properties = RESTUtil.merge(
                 original_catalog_options.to_map(),
                 merged_token
             )
-            
+
             self.assertIn("custom.key", merged_properties)
             self.assertEqual(merged_properties["custom.key"], "custom.value")
             self.assertIn(OssOptions.OSS_ACCESS_KEY_ID.key(), 
merged_properties)
@@ -264,11 +266,11 @@ class RESTTokenFileIOTest(unittest.TestCase):
                 self.warehouse_path,
                 self.catalog_options
             )
-            
+
             self.assertTrue(hasattr(file_io, 'filesystem'), "RESTTokenFileIO 
should have filesystem property")
             filesystem = file_io.filesystem
             self.assertIsNotNone(filesystem, "filesystem should not be None")
-            
+
             self.assertTrue(hasattr(filesystem, 'open_input_file'),
                             "filesystem should support open_input_file method")
 
@@ -279,12 +281,12 @@ class RESTTokenFileIOTest(unittest.TestCase):
                 self.warehouse_path,
                 self.catalog_options
             )
-            
+
             self.assertTrue(hasattr(file_io, 'uri_reader_factory'),
                             "RESTTokenFileIO should have uri_reader_factory 
property")
             uri_reader_factory = file_io.uri_reader_factory
             self.assertIsNotNone(uri_reader_factory, "uri_reader_factory 
should not be None")
-            
+
             self.assertTrue(hasattr(uri_reader_factory, 'create'),
                             "uri_reader_factory should support create method")
 
@@ -295,15 +297,143 @@ class RESTTokenFileIOTest(unittest.TestCase):
                 self.warehouse_path,
                 self.catalog_options
             )
-            
+
             pickled = pickle.dumps(original_file_io)
             restored_file_io = pickle.loads(pickled)
-            
+
             self.assertIsNotNone(restored_file_io.filesystem,
                                  "filesystem should work after 
deserialization")
             self.assertIsNotNone(restored_file_io.uri_reader_factory,
                                  "uri_reader_factory should work after 
deserialization")
 
+    def test_should_refresh_when_token_is_none(self):
+        """_should_refresh() returns True when token is None."""
+        with patch.object(RESTTokenFileIO, 'try_to_refresh_token'):
+            file_io = RESTTokenFileIO(
+                self.identifier, self.warehouse_path, self.catalog_options)
+            self.assertIsNone(file_io.token)
+            self.assertTrue(file_io._should_refresh())
+
+    def test_should_refresh_when_token_not_expired(self):
+        """_should_refresh() returns False when token is far from expiry."""
+        with patch.object(RESTTokenFileIO, 'try_to_refresh_token'):
+            file_io = RESTTokenFileIO(
+                self.identifier, self.warehouse_path, self.catalog_options)
+            # Token that expires 2 hours from now (well beyond the 1-hour safe 
margin)
+            future_millis = int(time.time() * 1000) + 7200_000
+            file_io.token = RESTToken({'ak': 'v'}, future_millis)
+            self.assertFalse(file_io._should_refresh())
+
+    def test_should_refresh_when_token_expired(self):
+        """_should_refresh() returns True when token is already expired."""
+        with patch.object(RESTTokenFileIO, 'try_to_refresh_token'):
+            file_io = RESTTokenFileIO(
+                self.identifier, self.warehouse_path, self.catalog_options)
+            # Token that expired 1 second ago
+            past_millis = int(time.time() * 1000) - 1000
+            file_io.token = RESTToken({'ak': 'v'}, past_millis)
+            self.assertTrue(file_io._should_refresh())
+
+    def test_try_to_refresh_token_calls_refresh_once(self):
+        """try_to_refresh_token() calls refresh_token() exactly once via 
double-check."""
+        file_io = RESTTokenFileIO(
+            self.identifier, self.warehouse_path, self.catalog_options)
+        self.assertIsNone(file_io.token)
+
+        mock_response = MagicMock()
+        mock_response.token = {'ak': 'test-ak'}
+        mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000
+
+        mock_api = MagicMock()
+        mock_api.load_table_token.return_value = mock_response
+        file_io.api_instance = mock_api
+
+        file_io.try_to_refresh_token()
+
+        mock_api.load_table_token.assert_called_once()
+        self.assertIsNotNone(file_io.token)
+
+        # Second call should NOT trigger refresh again (token is valid)
+        file_io.try_to_refresh_token()
+        mock_api.load_table_token.assert_called_once()
+
+    def test_refresh_token_strips_system_table_suffix(self):
+        """refresh_token() strips $snapshots suffix before requesting token."""
+        system_identifier = Identifier.create("db", "my_table$snapshots")
+        file_io = RESTTokenFileIO(
+            system_identifier, self.warehouse_path, self.catalog_options)
+
+        mock_response = MagicMock()
+        mock_response.token = {'ak': 'test-ak'}
+        mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000
+
+        mock_api = MagicMock()
+        mock_api.load_table_token.return_value = mock_response
+        file_io.api_instance = mock_api
+
+        file_io.refresh_token()
+
+        # Verify load_table_token was called with base table identifier (no 
$snapshots)
+        called_identifier = mock_api.load_table_token.call_args[0][0]
+        self.assertEqual(called_identifier.get_database_name(), "db")
+        self.assertEqual(called_identifier.get_object_name(), "my_table")
+
+    def test_refresh_token_keeps_normal_identifier(self):
+        """refresh_token() does not modify normal (non-system) identifiers."""
+        normal_identifier = Identifier.create("db", "my_table")
+        file_io = RESTTokenFileIO(
+            normal_identifier, self.warehouse_path, self.catalog_options)
+
+        mock_response = MagicMock()
+        mock_response.token = {'ak': 'test-ak'}
+        mock_response.expires_at_millis = int(time.time() * 1000) + 7200_000
+
+        mock_api = MagicMock()
+        mock_api.load_table_token.return_value = mock_response
+        file_io.api_instance = mock_api
+
+        file_io.refresh_token()
+
+        called_identifier = mock_api.load_table_token.call_args[0][0]
+        self.assertEqual(called_identifier.get_object_name(), "my_table")
+
+    def test_different_instances_do_not_share_token(self):
+        """Two instances with same identifier get independent tokens (no 
class-level cache)."""
+        same_identifier = Identifier.from_string("db.shared_table")
+
+        file_io_a = RESTTokenFileIO(
+            same_identifier, self.warehouse_path, self.catalog_options)
+        file_io_b = RESTTokenFileIO(
+            same_identifier, self.warehouse_path, self.catalog_options)
+
+        token_a = RESTToken({'ak': 'ak-A'}, int(time.time() * 1000) + 7200_000)
+        token_b = RESTToken({'ak': 'ak-B'}, int(time.time() * 1000) + 7200_000)
+
+        mock_response_a = MagicMock()
+        mock_response_a.token = token_a.token
+        mock_response_a.expires_at_millis = token_a.expire_at_millis
+
+        mock_response_b = MagicMock()
+        mock_response_b.token = token_b.token
+        mock_response_b.expires_at_millis = token_b.expire_at_millis
+
+        mock_api_a = MagicMock()
+        mock_api_a.load_table_token.return_value = mock_response_a
+        file_io_a.api_instance = mock_api_a
+
+        mock_api_b = MagicMock()
+        mock_api_b.load_table_token.return_value = mock_response_b
+        file_io_b.api_instance = mock_api_b
+
+        # Refresh both
+        file_io_a.try_to_refresh_token()
+        file_io_b.try_to_refresh_token()
+
+        # Each instance should hold its own token
+        self.assertEqual(file_io_a.token.token['ak'], 'ak-A')
+        self.assertEqual(file_io_b.token.token['ak'], 'ak-B')
+        self.assertIsNot(file_io_a.token, file_io_b.token)
+
 
 if __name__ == '__main__':
     unittest.main()

Reply via email to