This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch engine-manager
in repository https://gitbox.apache.org/repos/asf/superset.git

commit be31abeb7ee06f76868f990db86290fdf7b5cc48
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Dec 3 17:26:28 2025 -0500

    Hash key
---
 superset/engines/manager.py              | 144 +++++++++++++--
 superset/extensions/engine_manager.py    |  11 +-
 superset/models/core.py                  |   8 +-
 superset/superset_typing.py              |   4 +-
 tests/unit_tests/engines/manager_test.py | 295 +++++++++++++++++++++++++++++++
 5 files changed, 432 insertions(+), 30 deletions(-)

diff --git a/superset/engines/manager.py b/superset/engines/manager.py
index adbe461e8bd..1d5d0efec65 100644
--- a/superset/engines/manager.py
+++ b/superset/engines/manager.py
@@ -16,6 +16,8 @@
 # under the License.
 
 import enum
+import hashlib
+import json
 import logging
 import threading
 from contextlib import contextmanager
@@ -33,7 +35,6 @@ from sshtunnel import SSHTunnelForwarder
 from superset.databases.utils import make_url_safe
 from superset.superset_typing import DBConnectionMutator, EngineContextManager
 from superset.utils.core import get_query_source_from_request, get_user_id, 
QuerySource
-from superset.utils.json import dumps
 
 if TYPE_CHECKING:
     from superset.databases.ssh_tunnel.models import SSHTunnel
@@ -48,7 +49,16 @@ class _LockManager:
     Manages per-key locks safely without defaultdict race conditions.
 
     This class provides a thread-safe way to create and manage locks for 
specific keys,
-    avoiding race conditions.
+    avoiding the race conditions that occur when using defaultdict with 
threading.Lock.
+
+    The implementation uses a two-level locking strategy:
+    1. A meta-lock to protect the lock dictionary itself
+    2. Per-key locks to protect specific resources
+
+    This ensures that:
+    - Different keys can be locked concurrently (scalability)
+    - Lock creation is thread-safe (no race conditions)
+    - The same key always gets the same lock instance
     """
 
     def __init__(self) -> None:
@@ -58,6 +68,16 @@ class _LockManager:
     def get_lock(self, key: str) -> threading.RLock:
         """
         Get or create a lock for the given key.
+
+        This method uses double-checked locking to ensure thread safety:
+        1. First check without lock (fast path)
+        2. Acquire meta-lock if needed
+        3. Double-check inside the lock to prevent race conditions
+
+        This approach minimizes lock contention while ensuring correctness.
+
+        :param key: The key to get a lock for
+        :returns: An RLock instance for the given key
         """
         if lock := self._locks.get(key):
             return lock
@@ -73,6 +93,11 @@ class _LockManager:
     def cleanup(self, active_keys: set[str]) -> None:
         """
         Remove locks for keys that are no longer in use.
+
+        This prevents memory leaks from accumulating locks for resources
+        that have been disposed.
+
+        :param active_keys: Set of keys that are still active
         """
         with self._meta_lock:
             # Find locks to remove
@@ -85,6 +110,64 @@ EngineKey = str
 TunnelKey = str
 
 
+def _normalize_value(value: Any) -> str:
+    """
+    Normalize a value into a deterministic string for hashing.
+
+    The representation keeps distinct inputs distinct, so that different
+    engine/tunnel configurations never share a cache key:
+
+    - Scalars are rendered with ``repr`` so their type stays visible
+      (``1`` vs ``"1"``, ``None`` vs ``"None"`` no longer collapse).
+    - ``bytes``/``bytearray`` (e.g. private keys) are replaced by a truncated
+      SHA-256 digest so raw binary never has to be decoded.
+    - Classes (e.g. pool classes) are identified by module and qualified name,
+      so same-named classes from different modules differ.
+    - Dict keys are normalized as well, so keys of mixed or non-string types
+      neither break sorting/JSON serialization nor collide (``{1: x}`` vs
+      ``{"1": x}``).
+    - Sets are sorted after normalization so iteration order does not matter.
+      Lists, tuples and sets all render as JSON arrays.
+
+    Objects without a meaningful ``repr`` fall back to the default one, which
+    includes the object's id.
+
+    :param value: The value to normalize
+    :returns: String representation suitable for hashing
+    """
+    if isinstance(value, (bytes, bytearray)):
+        # For binary data (like private keys), hash it to avoid encoding issues
+        return "bytes:" + hashlib.sha256(value).hexdigest()[:16]
+    if isinstance(value, type):
+        # For class objects (like pool classes), use the fully qualified name
+        return f"type:{value.__module__}.{value.__qualname__}"
+    if isinstance(value, dict):
+        # Normalize keys too: they become strings, so json.dumps can sort them
+        # even when the original keys had mixed, non-comparable types
+        normalized_dict = {
+            _normalize_value(key): _normalize_value(item)
+            for key, item in value.items()
+        }
+        return json.dumps(normalized_dict, sort_keys=True, separators=(",", ":"))
+    if isinstance(value, (set, frozenset)):
+        # Sets have no stable order; sort the normalized members
+        normalized_members = sorted(_normalize_value(item) for item in value)
+        return json.dumps(normalized_members, separators=(",", ":"))
+    if isinstance(value, (list, tuple)):
+        # For lists/tuples, normalize each item and keep the order
+        normalized_list = [_normalize_value(item) for item in value]
+        return json.dumps(normalized_list, separators=(",", ":"))
+    # For everything else, repr keeps the type distinguishable (strings are
+    # quoted, None/numbers/bools are not)
+    return repr(value)
+
+
+def _generate_secure_key(components: dict[str, Any]) -> str:
+    """
+    Generate a secure hash-based key from components.
+
+    Creates a SHA-256 hash of the components to ensure:
+    1. The key includes all parameters for proper caching
+    2. Sensitive data is not exposed in logs or errors
+    3. The key is deterministic for the same inputs
+
+    Every component value, including ``None``, is normalized the same way as
+    values nested inside it. ``None`` is therefore not mapped to ``""``, and a
+    missing value (e.g. ``catalog=None``) stays distinguishable from an empty
+    string (``catalog=""``).
+
+    :param components: Dictionary of components to hash; keys must be strings
+    :returns: 32-character hex string representing the secure key
+    """
+    # Normalize each component; json.dumps(sort_keys=True) below provides the
+    # deterministic ordering, so the items do not need to be pre-sorted
+    key_data = {name: _normalize_value(value) for name, value in components.items()}
+
+    # Create compact JSON representation
+    key_string = json.dumps(key_data, sort_keys=True, separators=(",", ":"))
+
+    # Generate SHA-256 hash and return the first 32 hex characters (128 bits),
+    # which keeps accidental collisions negligible
+    return hashlib.sha256(key_string.encode("utf-8")).hexdigest()[:32]
+
+
 class EngineModes(enum.Enum):
     # reuse existing engine if available, otherwise create a new one; this 
mode should
     # have a connection pool configured in the database
@@ -231,19 +314,37 @@ class EngineManager:
         user_id: int | None,
     ) -> EngineKey:
         """
-        Generate a unique key for the engine based on the database and context.
+        Generate a secure hash-based key for the engine.
+
+        The key includes all parameters (including OAuth tokens and other 
sensitive
+        data) to ensure proper cache isolation, but uses a one-way hash to 
prevent
+        credential exposure in logs or errors.
+
+        :returns: 32-character hex string representing the secure key
         """
-        uri, keys = self._get_engine_args(
+        # Get all parameters that affect the engine
+        uri, kwargs = self._get_engine_args(
             database,
             catalog,
             schema,
             source,
             user_id,
         )
-        keys["uri"] = uri
-        keys["source"] = source
 
-        return dumps(keys, sort_keys=True)
+        # Create components for the key
+        # Include all parameters to ensure proper cache isolation
+        key_components = {
+            "database_id": database.id,
+            "catalog": catalog,
+            "schema": schema,
+            "uri": str(uri),  # SQLAlchemy URLs mask passwords
+            "source": str(source) if source else None,
+            "user_id": user_id,
+            "kwargs": kwargs,  # Includes OAuth tokens and other sensitive 
params
+        }
+
+        # Generate secure hash-based key
+        return _generate_secure_key(key_components)
 
     def _get_engine_args(
         self,
@@ -432,11 +533,20 @@ class EngineManager:
 
     def _get_tunnel_key(self, ssh_tunnel: "SSHTunnel", uri: URL) -> TunnelKey:
         """
-        Build a unique key for the SSH tunnel.
+        Generate a secure hash-based key for the SSH tunnel.
+
+        The key is a one-way hash of the keyword arguments returned by
+        ``_get_tunnel_kwargs`` for this tunnel and URI. Tunnels with different
+        settings therefore get different cache entries. That includes credentials
+        such as passwords or private keys, presumably part of those kwargs
+        (confirm in ``_get_tunnel_kwargs``). The raw values never appear in the
+        key itself.
+
+        :param ssh_tunnel: The SSH tunnel configuration
+        :param uri: The database URI the tunnel is built for
+        :returns: 32-character hex string representing the secure key
         """
-        keys = self._get_tunnel_kwargs(ssh_tunnel, uri)
+        # Get all tunnel parameters
+        tunnel_kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)

-        return dumps(keys, sort_keys=True)
+        # Generate secure hash-based key.
+        # The tunnel_kwargs may contain sensitive data like passwords and private
+        # keys, which is why they are hashed instead of serialized verbatim.
+        return _generate_secure_key(tunnel_kwargs)
 
     def _create_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> 
SSHTunnelForwarder:
         kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
@@ -565,12 +675,10 @@ class EngineManager:
             try:
                 # Remove engine from cache - no per-key locks to clean up 
anymore
                 if self._engines.pop(engine_key, None):
-                    logger.info(
-                        "Engine disposed and removed from cache: %s", 
engine_key
-                    )
+                    # Log only first 8 chars of hash for safety
+                    # (still enough for debugging, but doesn't expose full key)
+                    log_key = engine_key[:8] + "..."
+                    logger.info("Engine disposed and removed from cache: %s", 
log_key)
             except Exception as ex:
-                logger.error(
-                    "Error during engine disposal cleanup for %s: %s",
-                    engine_key,
-                    str(ex),
-                )
+                logger.error("Error during engine disposal cleanup: %s", 
str(ex))
+                # Don't log engine_key to avoid exposing credential hash
diff --git a/superset/extensions/engine_manager.py 
b/superset/extensions/engine_manager.py
index 5a4cd0301b5..df391ad5d8e 100644
--- a/superset/extensions/engine_manager.py
+++ b/superset/extensions/engine_manager.py
@@ -17,15 +17,11 @@
 
 import logging
 from datetime import timedelta
-from typing import TYPE_CHECKING
 
 from flask import Flask
 
 from superset.engines.manager import EngineManager, EngineModes
 
-if TYPE_CHECKING:
-    pass
-
 logger = logging.getLogger(__name__)
 
 
@@ -74,7 +70,12 @@ class EngineManagerExtension:
         def shutdown_engine_manager() -> None:
+            """Stop the engine manager's cleanup thread, if a manager is set."""
             if self.engine_manager:
                 self.engine_manager.stop_cleanup_thread()
-                logger.info("Stopped EngineManager cleanup thread")
+                # Logging can raise ValueError when a handler's stream is already
+                # closed (e.g. during test teardown), so guard the log call.
+                try:
+                    logger.info("Stopped EngineManager cleanup thread")
+                except ValueError:
+                    # Best-effort: stop_cleanup_thread() has already been called,
+                    # so losing this log message is harmless.
+                    pass
 
         app.teardown_appcontext_funcs.append(lambda exc: None)
 
diff --git a/superset/models/core.py b/superset/models/core.py
index 04d3c533bd4..f20e0fce1e5 100755
--- a/superset/models/core.py
+++ b/superset/models/core.py
@@ -136,9 +136,7 @@ class ConfigurationMethod(StrEnum):
     DYNAMIC_FORM = "dynamic_form"
 
 
-class Database(
-    CoreDatabase, AuditMixinNullable, ImportExportMixin
-):  # pylint: disable=too-many-public-methods
+class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin):  # 
pylint: disable=too-many-public-methods
     """An ORM object that stores Database related information"""
 
     __tablename__ = "dbs"
@@ -415,7 +413,9 @@ class Database(
         return (
             username
             if (username := get_username())
-            else object_url.username if self.impersonate_user else None
+            else object_url.username
+            if self.impersonate_user
+            else None
         )
 
     @contextmanager
diff --git a/superset/superset_typing.py b/superset/superset_typing.py
index 5a69250b8bb..e3252483e84 100644
--- a/superset/superset_typing.py
+++ b/superset/superset_typing.py
@@ -34,6 +34,7 @@ from typing_extensions import NotRequired
 from werkzeug.wrappers import Response
 
 if TYPE_CHECKING:
+    from superset.models.core import Database
     from superset.utils.core import (
         GenericDataType,
         QueryObjectFilterClause,
@@ -49,9 +50,6 @@ DBConnectionMutator: TypeAlias = Callable[
 ]
 
 # Type alias for engine context manager
-if TYPE_CHECKING:
-    from superset.models.core import Database
-
 EngineContextManager: TypeAlias = Callable[
     ["Database", str | None, str | None], ContextManager[None]
 ]
diff --git a/tests/unit_tests/engines/manager_test.py 
b/tests/unit_tests/engines/manager_test.py
new file mode 100644
index 00000000000..287820eaf1b
--- /dev/null
+++ b/tests/unit_tests/engines/manager_test.py
@@ -0,0 +1,295 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for EngineManager."""
+
+import threading
+from collections.abc import Iterator
+from unittest.mock import MagicMock, patch
+
+import pytest
+from sqlalchemy.pool import NullPool
+
+from superset.engines.manager import _LockManager, EngineManager, EngineModes
+
+
+class TestLockManager:
+    """Test the _LockManager class."""
+
+    def test_get_lock_creates_new_lock(self):
+        """Test that get_lock creates a new lock when needed."""
+        manager = _LockManager()
+        lock1 = manager.get_lock("key1")
+
+        assert isinstance(lock1, type(threading.RLock()))
+        assert lock1 is manager.get_lock("key1")  # Same lock returned
+
+    def test_get_lock_different_keys_different_locks(self):
+        """Test that different keys get different locks."""
+        manager = _LockManager()
+        lock1 = manager.get_lock("key1")
+        lock2 = manager.get_lock("key2")
+
+        assert lock1 is not lock2
+
+    def test_cleanup_removes_unused_locks(self):
+        """Test that cleanup drops inactive locks and keeps active ones."""
+        manager = _LockManager()
+
+        # Create locks
+        lock1 = manager.get_lock("key1")
+        lock2 = manager.get_lock("key2")
+
+        # Cleanup with only key1 active
+        manager.cleanup({"key1"})
+
+        # key1 is still active, so its lock must survive the cleanup
+        assert manager.get_lock("key1") is lock1
+
+        # key2 lock should be removed, so a fresh one is created on demand
+        lock3 = manager.get_lock("key2")
+        assert lock3 is not lock2  # New lock created
+
+    def test_concurrent_lock_creation(self):
+        """Test that concurrent lock creation doesn't create duplicates."""
+        manager = _LockManager()
+        num_threads = 10
+        # Hold every thread at the barrier so the get_lock calls actually overlap;
+        # without it the threads may simply run one after another
+        barrier = threading.Barrier(num_threads)
+        locks_created = []
+        exceptions = []
+
+        def create_lock():
+            try:
+                barrier.wait(timeout=5)
+                lock = manager.get_lock("concurrent_key")
+                locks_created.append(lock)
+            except Exception as e:
+                exceptions.append(e)
+
+        # Create multiple threads trying to get the same lock
+        threads = [threading.Thread(target=create_lock) for _ in range(num_threads)]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert len(exceptions) == 0
+        assert len(locks_created) == num_threads
+
+        # All should be the same lock
+        first_lock = locks_created[0]
+        for lock in locks_created[1:]:
+            assert lock is first_lock
+
+
+class TestEngineManager:
+    """Test the EngineManager class."""
+
+    @pytest.fixture
+    def engine_manager(self):
+        """Create an EngineManager with a no-op engine context manager."""
+        from contextlib import contextmanager
+
+        # EngineContextManager is typed as returning a context manager, so the
+        # generator must be wrapped; a bare generator function would hand back
+        # a generator object that cannot be used in a ``with`` statement.
+        @contextmanager
+        def dummy_context_manager(
+            database: MagicMock, catalog: str | None, schema: str | None
+        ) -> Iterator[None]:
+            yield
+
+        return EngineManager(engine_context_manager=dummy_context_manager)
+
+    @pytest.fixture
+    def mock_database(self):
+        """Create a mock database."""
+        database = MagicMock()
+        database.sqlalchemy_uri_decrypted = "postgresql://user:pass@localhost/test"
+        database.get_extra.return_value = {"engine_params": {"poolclass": NullPool}}
+        database.get_effective_user.return_value = "test_user"
+        database.impersonate_user = False
+        database.update_params_from_encrypted_extra = MagicMock()
+        database.db_engine_spec = MagicMock()
+        database.db_engine_spec.adjust_engine_params.return_value = (MagicMock(), {})
+        database.db_engine_spec.impersonate_user = MagicMock(
+            return_value=(MagicMock(), {})
+        )
+        database.db_engine_spec.validate_database_uri = MagicMock()
+        database.ssh_tunnel = None
+        return database
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_get_engine_new_mode(
+        self, mock_make_url, mock_create_engine, engine_manager, mock_database
+    ):
+        """Test getting an engine in NEW mode (no caching)."""
+        engine_manager.mode = EngineModes.NEW
+
+        mock_make_url.return_value = MagicMock()
+        mock_engine1 = MagicMock()
+        mock_engine2 = MagicMock()
+        mock_create_engine.side_effect = [mock_engine1, mock_engine2]
+
+        result = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
+
+        assert result is mock_engine1
+        mock_create_engine.assert_called_once()
+
+        # Calling again with the *same* parameters must still create a new engine.
+        # Different parameters would not prove the absence of caching, because
+        # they would get a separate engine in a caching mode as well.
+        mock_create_engine.reset_mock()
+        result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
+
+        assert result2 is mock_engine2  # Different engine
+        mock_create_engine.assert_called_once()
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_get_engine_singleton_mode_caching(
+        self, mock_make_url, mock_create_engine, engine_manager, mock_database
+    ):
+        """Test that engines are cached per parameter set in SINGLETON mode."""
+        engine_manager.mode = EngineModes.SINGLETON
+
+        # Use a real engine instead of MagicMock to avoid event listener issues
+        from sqlalchemy import create_engine
+        from sqlalchemy.pool import StaticPool
+
+        real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
+        mock_create_engine.return_value = real_engine
+        mock_make_url.return_value = real_engine
+
+        # Call twice with same params - should be cached
+        result1 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
+        result2 = engine_manager._get_engine(mock_database, "catalog1", "schema1", None)
+
+        assert result1 is result2  # Same engine returned (cached)
+        mock_create_engine.assert_called_once()  # Only created once
+
+        # Call with different params - should create new engine
+        engine_manager._get_engine(mock_database, "catalog2", "schema2", None)
+
+        # create_engine is mocked to always return ``real_engine``, so count the
+        # calls instead of comparing identities
+        assert mock_create_engine.call_count == 2
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_concurrent_engine_creation(
+        self, mock_make_url, mock_create_engine, engine_manager, mock_database
+    ):
+        """Test concurrent engine creation doesn't create duplicates."""
+        engine_manager.mode = EngineModes.SINGLETON
+
+        # Use a real engine to avoid event listener issues with MagicMock
+        from sqlalchemy import create_engine
+        from sqlalchemy.pool import StaticPool
+
+        real_engine = create_engine("sqlite:///:memory:", poolclass=StaticPool)
+        mock_make_url.return_value = real_engine
+
+        create_count = [0]
+        # ``+=`` is not atomic; guard the counter so a real race is not hidden
+        count_lock = threading.Lock()
+
+        def counting_create_engine(*args, **kwargs):
+            with count_lock:
+                create_count[0] += 1
+            return real_engine
+
+        mock_create_engine.side_effect = counting_create_engine
+
+        num_threads = 10
+        # Hold every thread at the barrier so the _get_engine calls overlap
+        barrier = threading.Barrier(num_threads)
+        results = []
+        exceptions = []
+
+        def get_engine_thread():
+            try:
+                barrier.wait(timeout=5)
+                engine = engine_manager._get_engine(
+                    mock_database, "catalog1", "schema1", None
+                )
+                results.append(engine)
+            except Exception as e:
+                exceptions.append(e)
+
+        # Run multiple threads
+        threads = [
+            threading.Thread(target=get_engine_thread) for _ in range(num_threads)
+        ]
+        for t in threads:
+            t.start()
+        for t in threads:
+            t.join()
+
+        assert len(exceptions) == 0
+        assert len(results) == num_threads
+        assert create_count[0] == 1  # Engine created only once
+
+        # All results should be the same engine
+        for engine in results:
+            assert engine is real_engine
+
+    @patch("superset.engines.manager.SSHTunnelForwarder")
+    def test_ssh_tunnel_creation(self, mock_tunnel_class, engine_manager):
+        """Test SSH tunnel creation and caching."""
+        ssh_tunnel = MagicMock()
+        ssh_tunnel.server_address = "ssh.example.com"
+        ssh_tunnel.server_port = 22
+        ssh_tunnel.username = "ssh_user"
+        ssh_tunnel.password = "ssh_pass"
+        ssh_tunnel.private_key = None
+        ssh_tunnel.private_key_password = None
+
+        tunnel_instance = MagicMock()
+        tunnel_instance.is_active = True
+        tunnel_instance.local_bind_address = ("127.0.0.1", 12345)
+        mock_tunnel_class.return_value = tunnel_instance
+
+        uri = MagicMock()
+        uri.host = "db.example.com"
+        uri.port = 5432
+        uri.get_backend_name.return_value = "postgresql"
+
+        result = engine_manager._get_tunnel(ssh_tunnel, uri)
+
+        assert result is tunnel_instance
+        mock_tunnel_class.assert_called_once()
+
+        # Getting same tunnel again should return cached version
+        mock_tunnel_class.reset_mock()
+        result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
+
+        assert result2 is tunnel_instance
+        mock_tunnel_class.assert_not_called()
+
+    @patch("superset.engines.manager.SSHTunnelForwarder")
+    def test_ssh_tunnel_recreation_when_inactive(
+        self, mock_tunnel_class, engine_manager
+    ):
+        """Test that inactive tunnels are replaced."""
+        ssh_tunnel = MagicMock()
+        ssh_tunnel.server_address = "ssh.example.com"
+        ssh_tunnel.server_port = 22
+        ssh_tunnel.username = "ssh_user"
+        ssh_tunnel.password = "ssh_pass"
+        ssh_tunnel.private_key = None
+        ssh_tunnel.private_key_password = None
+
+        # First tunnel is inactive
+        inactive_tunnel = MagicMock()
+        inactive_tunnel.is_active = False
+        inactive_tunnel.local_bind_address = ("127.0.0.1", 12345)
+
+        # Second tunnel is active
+        active_tunnel = MagicMock()
+        active_tunnel.is_active = True
+        active_tunnel.local_bind_address = ("127.0.0.1", 23456)
+
+        mock_tunnel_class.side_effect = [inactive_tunnel, active_tunnel]
+
+        uri = MagicMock()
+        uri.host = "db.example.com"
+        uri.port = 5432
+        uri.get_backend_name.return_value = "postgresql"
+
+        # First call creates inactive tunnel
+        result1 = engine_manager._get_tunnel(ssh_tunnel, uri)
+        assert result1 is inactive_tunnel
+
+        # Second call should create new tunnel since first is inactive
+        result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
+        assert result2 is active_tunnel
+        assert mock_tunnel_class.call_count == 2

Reply via email to