This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch engine-manager
in repository https://gitbox.apache.org/repos/asf/superset.git

commit 5753dfbb6ed79f09b07a70b50cf25a21343dba1f
Author: Beto Dealmeida <[email protected]>
AuthorDate: Tue Jul 29 19:01:09 2025 -0400

    feat: engine manager
---
 superset/engines/__init__.py |  16 ++
 superset/engines/manager.py  | 403 +++++++++++++++++++++++++++++++++++++++++++
 2 files changed, 419 insertions(+)

diff --git a/superset/engines/__init__.py b/superset/engines/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/superset/engines/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/superset/engines/manager.py b/superset/engines/manager.py
new file mode 100644
index 00000000000..04ebe48b0b7
--- /dev/null
+++ b/superset/engines/manager.py
@@ -0,0 +1,403 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
import enum
import hashlib
import logging
import threading
from collections import defaultdict
from collections.abc import Iterator
from contextlib import contextmanager
from io import StringIO
from typing import Any, TYPE_CHECKING

from flask import current_app
from paramiko import RSAKey
from sqlalchemy import create_engine, event, pool
from sqlalchemy.engine import Engine
from sqlalchemy.engine.url import URL
from sshtunnel import open_tunnel, SSHTunnelForwarder

from superset import is_feature_enabled
from superset.databases.ssh_tunnel.models import SSHTunnel
from superset.databases.utils import make_url_safe
from superset.extensions import security_manager
from superset.utils.core import get_query_source_from_request, get_user_id, QuerySource
from superset.utils.json import dumps
from superset.utils.oauth2 import check_for_oauth2, get_oauth2_access_token
from superset.utils.ssh_tunnel import get_default_port

if TYPE_CHECKING:
    from superset.models.core import Database
+
+
# Aliases for the cache-key types used by the engine and tunnel caches; both are
# plain strings.
EngineKey = str
TunnelKey = str

logger = logging.getLogger(__name__)
+
+
class EngineModes(enum.Enum):
    """
    Strategies an engine manager can use to obtain SQLAlchemy engines.

    SINGLETON: reuse an existing engine if available, otherwise create a new one.
        Databases used with this mode should have a connection pool configured.
    NEW: always create a new engine for every connection, backed by a `NullPool`.
        This is Superset's default behavior.
    """

    SINGLETON = 1
    NEW = 2
+
+
class EngineManager:
    """
    A manager for SQLAlchemy engines.

    This class handles the creation and management of SQLAlchemy engines, allowing
    them to be configured with connection pools and reused across requests.

    In `NEW` mode (Superset's default behavior) a new engine is created for every
    connection, using a `NullPool`. In `SINGLETON` mode engines are cached and
    reused, and the pool class can be configured through the database settings.
    SSH tunnels are cached by their connection parameters in both modes.
    """

    def __init__(self, mode: EngineModes = EngineModes.NEW) -> None:
        self.mode = mode

        # cached engines (only populated in SINGLETON mode) and one lock per key
        self._engines: dict[EngineKey, Engine] = {}
        self._engine_locks: dict[EngineKey, threading.Lock] = defaultdict(
            threading.Lock
        )

        # cached SSH tunnels and one lock per key
        self._tunnels: dict[TunnelKey, SSHTunnelForwarder] = {}
        self._tunnel_locks: dict[TunnelKey, threading.Lock] = defaultdict(
            threading.Lock
        )

    @contextmanager
    def get_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
    ) -> Iterator[Engine]:
        """
        Context manager yielding a SQLAlchemy engine for ``database``.

        :param database: the database to connect to
        :param catalog: optional catalog the engine should point to
        :param schema: optional schema the engine should point to
        :param source: origin of the query; inferred from the request when None
        :raises: the OAuth2 exception produced by ``check_for_oauth2`` when the
            database requires the user to authenticate first
        """
        # users can wrap the engine in their own context manager for different
        # reasons
        customization = current_app.config["ENGINE_CONTEXT_MANAGER"]
        with customization(database, catalog, schema):
            # we need to check for errors indicating that OAuth2 is needed, and
            # return the proper exception so it starts the authentication flow
            with check_for_oauth2(database):
                yield self._get_engine(database, catalog, schema, source)

    def _get_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
    ) -> Engine:
        """
        Get a specific engine, or create it if none exists.

        In NEW mode a fresh engine is always created. In SINGLETON mode engines are
        cached under the key from ``_get_engine_key``; a per-key lock plus a second
        lookup inside the lock ensures concurrent callers create at most one engine
        per key.
        """
        source = source or get_query_source_from_request()
        user_id = get_user_id()

        if self.mode == EngineModes.NEW:
            return self._create_engine(database, catalog, schema, source, user_id)

        engine_key = self._get_engine_key(
            database,
            catalog,
            schema,
            source,
            user_id,
        )

        # fast path: no locking when the engine is already cached
        engine = self._engines.get(engine_key)
        if engine is not None:
            return engine

        with self._engine_locks[engine_key]:
            # re-check under the lock: another thread may have created it meanwhile
            engine = self._engines.get(engine_key)
            if engine is None:
                engine = self._create_engine(
                    database,
                    catalog,
                    schema,
                    source,
                    user_id,
                )
                self._engines[engine_key] = engine
                self._add_disposal_listener(engine, engine_key)

        # return the local reference: the cache entry may already have been evicted
        # by a concurrent `dispose()`
        return engine

    def _get_engine_key(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> EngineKey:
        """
        Generate a unique key for the engine based on the database and context.

        The key is derived from the final URI and engine kwargs, so two requests
        share an engine only when they would build identical engines. The payload is
        hashed because it contains credentials (e.g. the URI password) and the key is
        written to the logs when the engine is disposed.
        """
        uri, kwargs = self._get_engine_args(
            database,
            catalog,
            schema,
            source,
            user_id,
        )
        payload = {
            **kwargs,
            # render explicitly: the URL object itself is not JSON serializable
            "uri": uri.render_as_string(hide_password=False),
            "source": source,
        }

        return self._hash_payload(payload)

    @staticmethod
    def _hash_payload(payload: dict[str, Any]) -> str:
        """
        Serialize ``payload`` deterministically and return its SHA-256 hex digest.

        Values the JSON encoder cannot handle natively are converted to stable
        strings: enums by name, classes (such as the pool class) by dotted path, and
        anything else by ``repr``.
        """

        def to_serializable(value: Any) -> Any:
            if isinstance(value, enum.Enum):
                return value.name
            if isinstance(value, type):
                return f"{value.__module__}.{value.__qualname__}"
            return repr(value)

        serialized = dumps(payload, default=to_serializable, sort_keys=True)
        return hashlib.sha256(serialized.encode("utf-8")).hexdigest()

    def _get_engine_args(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> tuple[URL, dict[str, Any]]:
        """
        Build the almost final SQLAlchemy URI and engine kwargs.

        "Almost" final because we may still need to mutate the URI if an SSH tunnel
        is needed, since it needs to connect to the tunnel instead of the original
        DB. But that information is only available after the tunnel is created.

        :returns: a ``(uri, kwargs)`` tuple to be passed to ``create_engine``
        """
        uri = make_url_safe(database.sqlalchemy_uri_decrypted)

        extra = database.get_extra(source)
        # copy so that the mutations below never leak back into `extra`
        kwargs = dict(extra.get("engine_params") or {})

        # get pool class
        if self.mode == EngineModes.NEW or "poolclass" not in extra:
            kwargs["poolclass"] = pool.NullPool
        else:
            pools = {
                "queue": pool.QueuePool,
                "singleton": pool.SingletonThreadPool,
                "assertion": pool.AssertionPool,
                "null": pool.NullPool,
                "static": pool.StaticPool,
            }
            kwargs["poolclass"] = pools.get(extra["poolclass"], pool.QueuePool)

        # update URI/connect_args for specific catalog/schema; `connect_args` is an
        # engine param, so it is read from and written back to `kwargs`, otherwise
        # the adjustments would never reach `create_engine`
        connect_args = dict(kwargs.get("connect_args") or {})
        uri, connect_args = database.db_engine_spec.adjust_engine_params(
            uri,
            connect_args,
            catalog,
            schema,
        )
        if connect_args:
            kwargs["connect_args"] = connect_args

        # get effective username
        username = database.get_effective_user(uri)
        if username and is_feature_enabled("IMPERSONATE_WITH_EMAIL_PREFIX"):
            user = security_manager.find_user(username=username)
            if user and user.email and "@" in user.email:
                username = user.email.split("@")[0]

        # update URI/kwargs for user impersonation
        if database.impersonate_user:
            oauth2_config = database.get_oauth2_config()
            access_token = (
                get_oauth2_access_token(
                    oauth2_config,
                    database.id,
                    user_id,
                    database.db_engine_spec,
                )
                if oauth2_config and user_id
                else None
            )

            uri, kwargs = database.db_engine_spec.impersonate_user(
                database,
                username,
                access_token,
                uri,
                kwargs,
            )

        # update kwargs from params stored encrypted at rest
        database.update_params_from_encrypted_extra(kwargs)

        # mutate URI
        if mutator := current_app.config["DB_CONNECTION_MUTATOR"]:
            source = source or get_query_source_from_request()
            uri, kwargs = mutator(
                uri,
                kwargs,
                username,
                security_manager,
                source,
            )

        # validate final URI
        database.db_engine_spec.validate_database_uri(uri)

        return uri, kwargs

    def _create_engine(
        self,
        database: "Database",
        catalog: str | None,
        schema: str | None,
        source: QuerySource | None,
        user_id: int | None,
    ) -> Engine:
        """
        Create the actual engine.

        This should be the only place in Superset where a SQLAlchemy engine is
        created. When the database uses an SSH tunnel, the URI is rewritten to point
        at the tunnel's local end.

        :raises: the engine spec's mapped DB-API exception if engine creation fails
        """
        uri, kwargs = self._get_engine_args(
            database,
            catalog,
            schema,
            source,
            user_id,
        )

        if database.ssh_tunnel:
            tunnel = self._get_tunnel(database.ssh_tunnel, uri)
            uri = uri.set(
                host=tunnel.local_bind_address[0],
                port=tunnel.local_bind_port,
            )

        try:
            engine = create_engine(uri, **kwargs)
        except Exception as ex:
            raise database.db_engine_spec.get_dbapi_mapped_exception(ex) from ex

        return engine

    def _get_tunnel(self, ssh_tunnel: SSHTunnel, uri: URL) -> SSHTunnelForwarder:
        """
        Return an active SSH tunnel for ``ssh_tunnel``/``uri``, creating it if needed.
        """
        tunnel_key = self._get_tunnel_key(ssh_tunnel, uri)

        # fast path: tunnel exists and is healthy (`get` avoids a check-then-index
        # race with a concurrent eviction)
        tunnel = self._tunnels.get(tunnel_key)
        if tunnel is not None and tunnel.is_active:
            return tunnel

        # create or recreate tunnel
        with self._tunnel_locks[tunnel_key]:
            existing_tunnel = self._tunnels.get(tunnel_key)
            if existing_tunnel is not None and existing_tunnel.is_active:
                return existing_tunnel

            # replace inactive or missing tunnel
            return self._replace_tunnel(tunnel_key, ssh_tunnel, uri, existing_tunnel)

    def _replace_tunnel(
        self,
        tunnel_key: TunnelKey,
        ssh_tunnel: SSHTunnel,
        uri: URL,
        old_tunnel: SSHTunnelForwarder | None,
    ) -> SSHTunnelForwarder:
        """
        Replace tunnel with proper cleanup.

        This function assumes caller holds lock.

        :raises: whatever tunnel creation raises; the cache entry is removed first
        """
        if old_tunnel is not None:
            try:
                old_tunnel.stop()
            except Exception:
                logger.exception("Error stopping old tunnel")

        try:
            new_tunnel = self._create_tunnel(ssh_tunnel, uri)
        except Exception:
            # don't leave the stopped (or missing) tunnel in the cache
            self._tunnels.pop(tunnel_key, None)
            logger.exception("Failed to create tunnel")
            raise

        self._tunnels[tunnel_key] = new_tunnel
        return new_tunnel

    def _get_tunnel_key(self, ssh_tunnel: SSHTunnel, uri: URL) -> TunnelKey:
        """
        Build a unique key for the SSH tunnel.

        The key is built from the raw tunnel settings rather than from the forwarder
        kwargs, since those contain a parsed ``RSAKey`` that the JSON encoder cannot
        handle. The payload is hashed because it includes credentials.
        """
        payload = {
            "ssh_address_or_host": [ssh_tunnel.server_address, ssh_tunnel.server_port],
            "ssh_username": ssh_tunnel.username,
            "ssh_password": ssh_tunnel.password,
            "private_key": ssh_tunnel.private_key,
            "private_key_password": ssh_tunnel.private_key_password,
            "remote_bind_address": list(self._get_remote_bind_address(uri)),
            "local_bind_address": self._get_local_bind_address(ssh_tunnel),
        }

        return self._hash_payload(payload)

    def _create_tunnel(self, ssh_tunnel: SSHTunnel, uri: URL) -> SSHTunnelForwarder:
        """
        Create and start a new SSH tunnel.
        """
        kwargs = self._get_tunnel_kwargs(ssh_tunnel, uri)
        # `open_tunnel` consumes `debug_level` (building the tunnel logger from it);
        # `SSHTunnelForwarder.__init__` itself rejects unknown keyword arguments
        tunnel = open_tunnel(**kwargs)
        tunnel.start()

        return tunnel

    @staticmethod
    def _get_remote_bind_address(uri: URL) -> tuple[str | None, Any]:
        """
        Return the ``(host, port)`` of the database the tunnel forwards to.

        Falls back to the backend's default port when the URI has none.
        """
        backend = uri.get_backend_name()
        return (uri.host, uri.port or get_default_port(backend))

    @staticmethod
    def _get_local_bind_address(ssh_tunnel: SSHTunnel) -> str:
        """
        Return the local address the SSH tunnel should bind to.

        NOTE(review): the ``SSHTunnel`` model may not define ``local_bind_address``;
        in that case fall back to the ``SSH_TUNNEL_LOCAL_BIND_ADDRESS`` app setting,
        defaulting to the loopback interface.
        """
        address = getattr(ssh_tunnel, "local_bind_address", None)
        if address:
            return address
        return current_app.config.get("SSH_TUNNEL_LOCAL_BIND_ADDRESS", "127.0.0.1")

    def _get_tunnel_kwargs(self, ssh_tunnel: SSHTunnel, uri: URL) -> dict[str, Any]:
        """
        Build the keyword arguments passed to ``open_tunnel``.

        Password authentication takes precedence over a private key when both are
        set. Private keys are parsed as RSA keys.
        """
        kwargs: dict[str, Any] = {
            "ssh_address_or_host": (ssh_tunnel.server_address, ssh_tunnel.server_port),
            "ssh_username": ssh_tunnel.username,
            "remote_bind_address": self._get_remote_bind_address(uri),
            "local_bind_address": (self._get_local_bind_address(ssh_tunnel),),
            "debug_level": logging.getLogger("flask_appbuilder").level,
        }

        if ssh_tunnel.password:
            kwargs["ssh_password"] = ssh_tunnel.password
        elif ssh_tunnel.private_key:
            private_key_file = StringIO(ssh_tunnel.private_key)
            private_key = RSAKey.from_private_key(
                private_key_file,
                ssh_tunnel.private_key_password,
            )
            kwargs["ssh_pkey"] = private_key

        if self.mode == EngineModes.NEW:
            # a keepalive interval of 0 disables keepalive packets
            kwargs["set_keepalive"] = 0.0

        return kwargs

    def _add_disposal_listener(self, engine: Engine, engine_key: EngineKey) -> None:
        """
        Register a listener that evicts ``engine`` from the cache when it is disposed.
        """

        @event.listens_for(engine, "engine_disposed")
        def on_engine_disposed(engine_instance: Engine) -> None:
            try:
                with self._engine_locks[engine_key]:
                    # only evict if the cache still holds this very engine, so a
                    # replacement created under the same key is left alone; the lock
                    # is kept so that waiting and future callers share the same one
                    removed = self._engines.get(engine_key) is engine
                    if removed:
                        del self._engines[engine_key]
                if removed:
                    logger.info(
                        "Engine disposed and removed from cache: %s",
                        engine_key,
                    )
            except Exception:
                logger.exception(
                    "Error during engine disposal cleanup for %s",
                    engine_key,
                )

Reply via email to