This is an automated email from the ASF dual-hosted git repository. beto pushed a commit to branch engine-manager in repository https://gitbox.apache.org/repos/asf/superset.git
commit 5f61bb8d76a8e52c78f66c59ee8a941163591492 Author: Beto Dealmeida <[email protected]> AuthorDate: Wed Dec 3 15:59:46 2025 -0500 Small improvements --- superset/config.py | 11 ++- superset/engines/manager.py | 169 +++++++++++++++++++------------- superset/extensions/__init__.py | 2 - superset/extensions/engine_manager.py | 17 +++- superset/extensions/ssh.py | 94 ------------------ superset/initialization/__init__.py | 5 - superset/models/core.py | 17 ++-- superset/superset_typing.py | 31 +++++- tests/unit_tests/extensions/ssh_test.py | 36 ------- 9 files changed, 163 insertions(+), 219 deletions(-) diff --git a/superset/config.py b/superset/config.py index ee4dc92feca..4110d902ffa 100644 --- a/superset/config.py +++ b/superset/config.py @@ -56,7 +56,11 @@ from superset.engines.manager import EngineModes from superset.jinja_context import BaseTemplateProcessor from superset.key_value.types import JsonKeyValueCodec from superset.stats_logger import DummyStatsLogger -from superset.superset_typing import CacheConfig +from superset.superset_typing import ( + CacheConfig, + DBConnectionMutator, + EngineContextManager, +) from superset.tasks.types import ExecutorType from superset.themes.types import Theme from superset.utils import core as utils @@ -826,7 +830,6 @@ DEFAULT_FEATURE_FLAGS: dict[str, bool] = { # FIREWALL (only port 22 is open) # ---------------------------------------------------------------------- -SSH_TUNNEL_MANAGER_CLASS = "superset.extensions.ssh.SSHManager" SSH_TUNNEL_LOCAL_BIND_ADDRESS = "127.0.0.1" #: Timeout (seconds) for tunnel connection (open_channel timeout) SSH_TUNNEL_TIMEOUT_SEC = 10.0 @@ -1701,7 +1704,7 @@ def engine_context_manager( # pylint: disable=unused-argument yield None -ENGINE_CONTEXT_MANAGER = engine_context_manager +ENGINE_CONTEXT_MANAGER: EngineContextManager = engine_context_manager # A callable that allows altering the database connection URL and params # on the fly, at runtime. This allows for things like impersonation or @@ -1718,7 +1721,7 @@ ENGINE_CONTEXT_MANAGER = engine_context_manager # # Note that the returned uri and params are passed directly to sqlalchemy's # as such `create_engine(url, **params)` -DB_CONNECTION_MUTATOR = None +DB_CONNECTION_MUTATOR: DBConnectionMutator | None = None # A callable that is invoked for every invocation of DB Engine Specs diff --git a/superset/engines/manager.py b/superset/engines/manager.py index cbe855fd1fd..adbe461e8bd 100644 --- a/superset/engines/manager.py +++ b/superset/engines/manager.py @@ -18,13 +18,12 @@ import enum import logging import threading -from collections import defaultdict from contextlib import contextmanager from datetime import timedelta from io import StringIO -from typing import Any, TYPE_CHECKING +from typing import Any, Iterator, TYPE_CHECKING -from flask import current_app +import sshtunnel from paramiko import RSAKey from sqlalchemy import create_engine, event, pool from sqlalchemy.engine import Engine @@ -32,6 +31,7 @@ from sqlalchemy.engine.url import URL from sshtunnel import SSHTunnelForwarder from superset.databases.utils import make_url_safe +from superset.superset_typing import DBConnectionMutator, EngineContextManager from superset.utils.core import get_query_source_from_request, get_user_id, QuerySource from superset.utils.json import dumps @@ -43,6 +43,44 @@ if TYPE_CHECKING: logger = logging.getLogger(__name__) +class _LockManager: + """ + Manages per-key locks safely without defaultdict race conditions. + + This class provides a thread-safe way to create and manage locks for specific keys, + avoiding race conditions. + """ + + def __init__(self) -> None: + self._locks: dict[str, threading.RLock] = {} + self._meta_lock = threading.Lock() + + def get_lock(self, key: str) -> threading.RLock: + """ + Get or create a lock for the given key. + """ + if lock := self._locks.get(key): + return lock + + with self._meta_lock: + # Double-check inside the lock + lock = self._locks.get(key) + if lock is None: + lock = threading.RLock() + self._locks[key] = lock + return lock + + def cleanup(self, active_keys: set[str]) -> None: + """ + Remove locks for keys that are no longer in use. + """ + with self._meta_lock: + # Find locks to remove + locks_to_remove = self._locks.keys() - active_keys + for key in locks_to_remove: + self._locks.pop(key, None) + + EngineKey = str TunnelKey = str @@ -71,21 +109,27 @@ class EngineManager: def __init__( self, + engine_context_manager: EngineContextManager, + db_connection_mutator: DBConnectionMutator | None = None, mode: EngineModes = EngineModes.NEW, cleanup_interval: timedelta = timedelta(minutes=5), + local_bind_address: str = "127.0.0.1", + tunnel_timeout: timedelta = timedelta(seconds=30), + ssh_timeout: timedelta = timedelta(seconds=1), ) -> None: + self.engine_context_manager = engine_context_manager + self.db_connection_mutator = db_connection_mutator self.mode = mode self.cleanup_interval = cleanup_interval + self.local_bind_address = local_bind_address - self._engines: dict[EngineKey, Engine] = {} - self._engine_locks: dict[EngineKey, threading.Lock] = defaultdict( - threading.Lock - ) + sshtunnel.TUNNEL_TIMEOUT = tunnel_timeout.total_seconds() + sshtunnel.SSH_TIMEOUT = ssh_timeout.total_seconds() + self._engines: dict[EngineKey, Engine] = {} + self._engine_locks = _LockManager() self._tunnels: dict[TunnelKey, SSHTunnelForwarder] = {} - self._tunnel_locks: dict[TunnelKey, threading.Lock] = defaultdict( - threading.Lock - ) + self._tunnel_locks = _LockManager() # Background cleanup thread management self._cleanup_thread: threading.Thread | None = None @@ -113,14 +157,13 @@ class EngineManager: catalog: str | None, schema: str | None, source: QuerySource | None, - ) -> Engine: + ) -> Iterator[Engine]: """ Context manager to get a SQLAlchemy engine. """ # users can wrap the engine in their own context manager for different # reasons - customization = current_app.config["ENGINE_CONTEXT_MANAGER"] - with customization(database, catalog, schema): + with self.engine_context_manager(database, catalog, schema): # we need to check for errors indicating that OAuth2 is needed, and # return the proper exception so it starts the authentication flow from superset.utils.oauth2 import check_for_oauth2 @@ -158,22 +201,26 @@ class EngineManager: user_id, ) - if engine_key not in self._engines: - with self._engine_locks[engine_key]: - # double-checked locking to ensure thread safety and prevent unnecessary - # engine creation - if engine_key not in self._engines: - engine = self._create_engine( - database, - catalog, - schema, - source, - user_id, - ) - self._engines[engine_key] = engine - self._add_disposal_listener(engine, engine_key) + if engine := self._engines.get(engine_key): + return engine + + lock = self._engine_locks.get_lock(engine_key) + with lock: + # Double-check inside the lock + if engine := self._engines.get(engine_key): + return engine - return self._engines[engine_key] + # Create and cache the engine + engine = self._create_engine( + database, + catalog, + schema, + source, + user_id, + ) + self._engines[engine_key] = engine + self._add_disposal_listener(engine, engine_key) + return engine def _get_engine_key( self, @@ -284,12 +331,12 @@ class EngineManager: database.update_params_from_encrypted_extra(kwargs) # mutate URI - if mutator := current_app.config["DB_CONNECTION_MUTATOR"]: + if self.db_connection_mutator: source = source or get_query_source_from_request() # Import here to avoid circular imports from superset.extensions import security_manager - uri, kwargs = mutator( + uri, kwargs = self.db_connection_mutator( uri, kwargs, username, @@ -340,20 +387,19 @@ class EngineManager: def _get_tunnel(self, ssh_tunnel: "SSHTunnel", uri: URL) -> SSHTunnelForwarder: tunnel_key = self._get_tunnel_key(ssh_tunnel, uri) - # tunnel exists and is healthy - if tunnel_key in self._tunnels: - tunnel = self._tunnels[tunnel_key] - if tunnel.is_active: - return tunnel + tunnel = self._tunnels.get(tunnel_key) + if tunnel is not None and tunnel.is_active: + return tunnel - # create or recreate tunnel - with self._tunnel_locks[tunnel_key]: - existing_tunnel = self._tunnels.get(tunnel_key) - if existing_tunnel and existing_tunnel.is_active: - return existing_tunnel + lock = self._tunnel_locks.get_lock(tunnel_key) + with lock: + # Double-check inside the lock + tunnel = self._tunnels.get(tunnel_key) + if tunnel is not None and tunnel.is_active: + return tunnel - # replace inactive or missing tunnel - return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, existing_tunnel) + # Create or replace tunnel + return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, tunnel) def _replace_tunnel( self, @@ -400,15 +446,15 @@ class EngineManager: return tunnel def _get_tunnel_kwargs(self, ssh_tunnel: "SSHTunnel", uri: URL) -> dict[str, Any]: - backend = uri.get_backend_name() # Import here to avoid circular imports from superset.utils.ssh_tunnel import get_default_port + backend = uri.get_backend_name() kwargs = { "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port), "ssh_username": ssh_tunnel.username, "remote_bind_address": (uri.host, uri.port or get_default_port(backend)), - "local_bind_address": (ssh_tunnel.local_bind_address,), + "local_bind_address": (self.local_bind_address,), "debug_level": logging.getLogger("flask_appbuilder").level, } @@ -492,45 +538,36 @@ class EngineManager: def _cleanup_abandoned_locks(self) -> None: """ - Remove locks for engines and tunnels that no longer exist. + Clean up locks for engines and tunnels that no longer exist. - This prevents memory leaks from accumulating locks in defaultdict - when engines/tunnels are disposed outside of normal cleanup paths. + This prevents memory leaks from accumulating locks when engines/tunnels + are disposed outside of normal cleanup paths. """ - # Clean up engine locks + # Clean up engine locks for inactive engines active_engine_keys = set(self._engines.keys()) - abandoned_engine_locks = set(self._engine_locks.keys()) - active_engine_keys - for key in abandoned_engine_locks: - self._engine_locks.pop(key, None) - - if abandoned_engine_locks: - logger.debug( - "Cleaned up %d abandoned engine locks", - len(abandoned_engine_locks), - ) + self._engine_locks.cleanup(active_engine_keys) - # Clean up tunnel locks + # Clean up tunnel locks for inactive tunnels active_tunnel_keys = set(self._tunnels.keys()) - abandoned_tunnel_locks = set(self._tunnel_locks.keys()) - active_tunnel_keys - for key in abandoned_tunnel_locks: - self._tunnel_locks.pop(key, None) + self._tunnel_locks.cleanup(active_tunnel_keys) - if abandoned_tunnel_locks: + # Log for debugging + if active_engine_keys or active_tunnel_keys: logger.debug( - "Cleaned up %d abandoned tunnel locks", - len(abandoned_tunnel_locks), + "EngineManager resources - Engines: %d, Tunnels: %d", + len(active_engine_keys), + len(active_tunnel_keys), ) def _add_disposal_listener(self, engine: Engine, engine_key: EngineKey) -> None: @event.listens_for(engine, "engine_disposed") def on_engine_disposed(engine_instance: Engine) -> None: try: - # `pop` is atomic -- no lock needed + # Remove engine from cache - no per-key locks to clean up anymore if self._engines.pop(engine_key, None): logger.info( "Engine disposed and removed from cache: %s", engine_key ) - self._engine_locks.pop(engine_key, None) except Exception as ex: logger.error( "Error during engine disposal cleanup for %s: %s", diff --git a/superset/extensions/__init__.py b/superset/extensions/__init__.py index 07ffbe743a3..a396e207ee9 100644 --- a/superset/extensions/__init__.py +++ b/superset/extensions/__init__.py @@ -42,7 +42,6 @@ from werkzeug.local import LocalProxy from superset.async_events.async_query_manager import AsyncQueryManager from superset.async_events.async_query_manager_factory import AsyncQueryManagerFactory from superset.extensions.engine_manager import EngineManagerExtension -from superset.extensions.ssh import SSHManagerFactory from superset.extensions.stats_logger import BaseStatsLoggerManager from superset.security.manager import SupersetSecurityManager from superset.utils.cache_manager import CacheManager @@ -148,6 +147,5 @@ migrate = Migrate() profiling = ProfilingExtension() results_backend_manager = ResultsBackendManager() security_manager: SupersetSecurityManager = LocalProxy(lambda: appbuilder.sm) -ssh_manager_factory = SSHManagerFactory() stats_logger_manager = BaseStatsLoggerManager() talisman = Talisman() diff --git a/superset/extensions/engine_manager.py b/superset/extensions/engine_manager.py index c341364e2a5..5a4cd0301b5 100644 --- a/superset/extensions/engine_manager.py +++ b/superset/extensions/engine_manager.py @@ -16,6 +16,7 @@ # under the License. import logging +from datetime import timedelta from typing import TYPE_CHECKING from flask import Flask @@ -44,13 +45,25 @@ class EngineManagerExtension: """ Initialize the EngineManager with Flask app configuration. """ - # Get configuration values with defaults + engine_context_manager = app.config["ENGINE_CONTEXT_MANAGER"] + db_connection_mutator = app.config["DB_CONNECTION_MUTATOR"] mode = app.config["ENGINE_MANAGER_MODE"] cleanup_interval = app.config["ENGINE_MANAGER_CLEANUP_INTERVAL"] + local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"] + tunnel_timeout = timedelta(seconds=app.config["SSH_TUNNEL_TIMEOUT_SEC"]) + ssh_timeout = timedelta(seconds=app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"]) auto_start_cleanup = app.config["ENGINE_MANAGER_AUTO_START_CLEANUP"] # Create the engine manager - self.engine_manager = EngineManager(mode, cleanup_interval) + self.engine_manager = EngineManager( + engine_context_manager, + db_connection_mutator, + mode, + cleanup_interval, + local_bind_address, + tunnel_timeout, + ssh_timeout, + ) # Start cleanup thread if requested and in SINGLETON mode if auto_start_cleanup and mode == EngineModes.SINGLETON: diff --git a/superset/extensions/ssh.py b/superset/extensions/ssh.py deleted file mode 100644 index 74fb44cfd75..00000000000 --- a/superset/extensions/ssh.py +++ /dev/null @@ -1,94 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. - -import logging -from io import StringIO -from typing import TYPE_CHECKING - -import sshtunnel -from flask import Flask -from paramiko import RSAKey - -from superset.commands.database.ssh_tunnel.exceptions import SSHTunnelDatabasePortError -from superset.databases.utils import make_url_safe -from superset.utils.class_utils import load_class_from_name - -if TYPE_CHECKING: - from superset.databases.ssh_tunnel.models import SSHTunnel - - -class SSHManager: - def __init__(self, app: Flask) -> None: - super().__init__() - self.local_bind_address = app.config["SSH_TUNNEL_LOCAL_BIND_ADDRESS"] - sshtunnel.TUNNEL_TIMEOUT = app.config["SSH_TUNNEL_TIMEOUT_SEC"] - sshtunnel.SSH_TIMEOUT = app.config["SSH_TUNNEL_PACKET_TIMEOUT_SEC"] - - def build_sqla_url( - self, sqlalchemy_url: str, server: sshtunnel.SSHTunnelForwarder - ) -> str: - # override any ssh tunnel configuration object - url = make_url_safe(sqlalchemy_url) - return url.set( - host=server.local_bind_address[0], - port=server.local_bind_port, - ) - - def create_tunnel( - self, - ssh_tunnel: "SSHTunnel", - sqlalchemy_database_uri: str, - ) -> sshtunnel.SSHTunnelForwarder: - from superset.utils.ssh_tunnel import get_default_port - - url = make_url_safe(sqlalchemy_database_uri) - backend = url.get_backend_name() - port = url.port or get_default_port(backend) - if not port: - raise SSHTunnelDatabasePortError() - params = { - "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port), - "ssh_username": ssh_tunnel.username, - "remote_bind_address": (url.host, port), - "local_bind_address": (self.local_bind_address,), - "debug_level": logging.getLogger("flask_appbuilder").level, - } - - if ssh_tunnel.password: - params["ssh_password"] = ssh_tunnel.password - elif ssh_tunnel.private_key: - private_key_file = StringIO(ssh_tunnel.private_key) - private_key = RSAKey.from_private_key( - private_key_file, ssh_tunnel.private_key_password - ) - params["ssh_pkey"] = private_key - - return sshtunnel.open_tunnel(**params) - - -class SSHManagerFactory: - def __init__(self) -> None: - self._ssh_manager = None - - def init_app(self, app: Flask) -> None: - self._ssh_manager = load_class_from_name( - app.config["SSH_TUNNEL_MANAGER_CLASS"] - )(app) - - @property - def instance(self) -> SSHManager: - return self._ssh_manager # type: ignore diff --git a/superset/initialization/__init__.py b/superset/initialization/__init__.py index f18149d61f8..4cefa0c7337 100644 --- a/superset/initialization/__init__.py +++ b/superset/initialization/__init__.py @@ -56,7 +56,6 @@ from superset.extensions import ( migrate, profiling, results_backend_manager, - ssh_manager_factory, stats_logger_manager, talisman, ) @@ -588,7 +587,6 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods self.configure_auth_provider() self.configure_engine_manager() self.configure_async_queries() - self.configure_ssh_manager() self.configure_stats_manager() # Hook that provides administrators a handle on the Flask APP @@ -766,9 +764,6 @@ class SupersetAppInitializer: # pylint: disable=too-many-public-methods def configure_engine_manager(self) -> None: engine_manager_extension.init_app(self.superset_app) - def configure_ssh_manager(self) -> None: - ssh_manager_factory.init_app(self.superset_app) - def configure_stats_manager(self) -> None: stats_logger_manager.init_app(self.superset_app) diff --git a/superset/models/core.py b/superset/models/core.py index c6dcd9ad05b..04d3c533bd4 100755 --- a/superset/models/core.py +++ b/superset/models/core.py @@ -30,7 +30,7 @@ from copy import deepcopy from datetime import datetime from functools import lru_cache from inspect import signature -from typing import Any, Callable, cast, Optional, TYPE_CHECKING +from typing import Any, Callable, cast, Iterator, Optional, TYPE_CHECKING import numpy import pandas as pd @@ -136,7 +136,9 @@ class ConfigurationMethod(StrEnum): DYNAMIC_FORM = "dynamic_form" -class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: disable=too-many-public-methods +class Database( + CoreDatabase, AuditMixinNullable, ImportExportMixin +): # pylint: disable=too-many-public-methods """An ORM object that stores Database related information""" __tablename__ = "dbs" @@ -413,9 +415,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: return ( username if (username := get_username()) - else object_url.username - if self.impersonate_user - else None + else object_url.username if self.impersonate_user else None ) @contextmanager @@ -424,7 +424,7 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: catalog: str | None = None, schema: str | None = None, source: utils.QuerySource | None = None, - ) -> Engine: + ) -> Iterator[Engine]: """ Context manager for a SQLAlchemy engine. @@ -437,12 +437,13 @@ class Database(CoreDatabase, AuditMixinNullable, ImportExportMixin): # pylint: # Use the engine manager to get the engine engine_manager = engine_manager_extension.manager - return engine_manager.get_engine( + with engine_manager.get_engine( database=self, catalog=catalog, schema=schema, source=source, - ) + ) as engine: + yield engine def add_database_to_signature( self, diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 02e294a08cf..5a69250b8bb 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -18,17 +18,44 @@ from __future__ import annotations from collections.abc import Hashable, Sequence from datetime import datetime -from typing import Any, Literal, TYPE_CHECKING, TypeAlias, TypedDict +from typing import ( + Any, + Callable, + ContextManager, + Literal, + TYPE_CHECKING, + TypeAlias, + TypedDict, +) +from sqlalchemy.engine.url import URL from sqlalchemy.sql.type_api import TypeEngine from typing_extensions import NotRequired from werkzeug.wrappers import Response if TYPE_CHECKING: - from superset.utils.core import GenericDataType, QueryObjectFilterClause + from superset.utils.core import ( + GenericDataType, + QueryObjectFilterClause, + QuerySource, + ) SQLType: TypeAlias = TypeEngine | type[TypeEngine] +# Type alias for database connection mutator function +DBConnectionMutator: TypeAlias = Callable[ + [URL, dict[str, Any], str | None, Any, "QuerySource | None"], + tuple[URL, dict[str, Any]], +] + +# Type alias for engine context manager +if TYPE_CHECKING: + from superset.models.core import Database + +EngineContextManager: TypeAlias = Callable[ + ["Database", str | None, str | None], ContextManager[None] +] + class LegacyMetric(TypedDict): label: str | None diff --git a/tests/unit_tests/extensions/ssh_test.py b/tests/unit_tests/extensions/ssh_test.py deleted file mode 100644 index a36f0fe03eb..00000000000 --- a/tests/unit_tests/extensions/ssh_test.py +++ /dev/null @@ -1,36 +0,0 @@ -# Licensed to the Apache Software Foundation (ASF) under one -# or more contributor license agreements. See the NOTICE file -# distributed with this work for additional information -# regarding copyright ownership. The ASF licenses this file -# to you under the Apache License, Version 2.0 (the -# "License"); you may not use this file except in compliance -# with the License. You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, -# software distributed under the License is distributed on an -# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY -# KIND, either express or implied. See the License for the -# specific language governing permissions and limitations -# under the License. -from unittest.mock import Mock - -import sshtunnel - -from superset.extensions.ssh import SSHManagerFactory - - -def test_ssh_tunnel_timeout_setting() -> None: - app = Mock() - app.config = { - "SSH_TUNNEL_MAX_RETRIES": 2, - "SSH_TUNNEL_LOCAL_BIND_ADDRESS": "test", - "SSH_TUNNEL_TIMEOUT_SEC": 123.0, - "SSH_TUNNEL_PACKET_TIMEOUT_SEC": 321.0, - "SSH_TUNNEL_MANAGER_CLASS": "superset.extensions.ssh.SSHManager", - } - factory = SSHManagerFactory() - factory.init_app(app) - assert sshtunnel.TUNNEL_TIMEOUT == 123.0 - assert sshtunnel.SSH_TIMEOUT == 321.0
