This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch engine-manager
in repository https://gitbox.apache.org/repos/asf/superset.git

commit de8c250f86d67a008dfec9d8d3fa72e7e1cd261e
Author: Beto Dealmeida <[email protected]>
AuthorDate: Wed Dec 3 20:00:29 2025 -0500

    Update existing tests
---
 superset/extensions/engine_manager.py              |   6 -
 tests/integration_tests/conftest.py                |   1 -
 .../integration_tests/databases/commands_tests.py  |  22 +-
 tests/unit_tests/engines/manager_test.py           | 232 +++++++++++++++++++++
 tests/unit_tests/initialization_test.py            |   2 +-
 tests/unit_tests/models/core_test.py               | 167 ---------------
 6 files changed, 244 insertions(+), 186 deletions(-)

diff --git a/superset/extensions/engine_manager.py b/superset/extensions/engine_manager.py
index df391ad5d8e..e15ead09b43 100644
--- a/superset/extensions/engine_manager.py
+++ b/superset/extensions/engine_manager.py
@@ -70,12 +70,6 @@ class EngineManagerExtension:
         def shutdown_engine_manager() -> None:
             if self.engine_manager:
                 self.engine_manager.stop_cleanup_thread()
-                # Use a try-except to handle closed log file handlers during tests
-                try:
-                    logger.info("Stopped EngineManager cleanup thread")
-                except ValueError:
-                    # Ignore logging errors during test shutdown when file handles are closed
-                    pass
 
         app.teardown_appcontext_funcs.append(lambda exc: None)
 
diff --git a/tests/integration_tests/conftest.py b/tests/integration_tests/conftest.py
index 4f6ce10b0f9..95f1015e85f 100644
--- a/tests/integration_tests/conftest.py
+++ b/tests/integration_tests/conftest.py
@@ -170,7 +170,6 @@ def example_db_provider() -> Callable[[], Database]:
             return self._db
 
         def _load_lazy_data_to_decouple_from_session(self) -> None:
-            self._db._get_sqla_engine()  # type: ignore
             self._db.backend  # type: ignore  # noqa: B018
 
         def remove(self) -> None:
diff --git a/tests/integration_tests/databases/commands_tests.py b/tests/integration_tests/databases/commands_tests.py
index 27c1ce56542..2d43ee14d0f 100644
--- a/tests/integration_tests/databases/commands_tests.py
+++ b/tests/integration_tests/databases/commands_tests.py
@@ -897,7 +897,7 @@ class TestImportDatabasesCommand(SupersetTestCase):
 
 
 class TestTestConnectionDatabaseCommand(SupersetTestCase):
-    @patch("superset.models.core.Database._get_sqla_engine")
+    @patch("superset.models.core.Database.get_sqla_engine")
     @patch("superset.commands.database.test_connection.event_logger.log_with_context")
     @patch("superset.utils.core.g")
     def test_connection_db_exception(
@@ -906,19 +906,19 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
         """Test to make sure event_logger is called when an exception is raised"""
         database = get_example_database()
         mock_g.user = security_manager.find_user("admin")
-        mock_get_sqla_engine.side_effect = Exception("An error has occurred!")
+        mock_get_sqla_engine.__enter__.side_effect = Exception("An error has occurred!")
         db_uri = database.sqlalchemy_uri_decrypted
         json_payload = {"sqlalchemy_uri": db_uri}
         command_without_db_name = TestConnectionDatabaseCommand(json_payload)
 
         with pytest.raises(DatabaseTestConnectionUnexpectedError) as excinfo:  # noqa: PT012
             command_without_db_name.run()
-            assert str(excinfo.value) == (
-                "Unexpected error occurred, please check your logs for details"
-            )
+        assert str(excinfo.value) == (
+            "Unexpected error occurred, please check your logs for details"
+        )
         mock_event_logger.assert_called()
 
-    @patch("superset.models.core.Database._get_sqla_engine")
+    @patch("superset.models.core.Database.get_sqla_engine")
     @patch("superset.commands.database.test_connection.event_logger.log_with_context")
     @patch("superset.utils.core.g")
     def test_connection_do_ping_exception(
@@ -927,7 +927,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
         """Test to make sure do_ping exceptions gets captured"""
         database = get_example_database()
         mock_g.user = security_manager.find_user("admin")
-        mock_get_sqla_engine.return_value.dialect.do_ping.side_effect = Exception(
+        mock_get_sqla_engine.__enter__().dialect.do_ping.side_effect = Exception(
             "An error has occurred!"
         )
         db_uri = database.sqlalchemy_uri_decrypted
@@ -967,7 +967,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
             == SupersetErrorType.CONNECTION_DATABASE_TIMEOUT
         )
 
-    @patch("superset.models.core.Database._get_sqla_engine")
+    @patch("superset.models.core.Database.get_sqla_engine")
     @patch("superset.commands.database.test_connection.event_logger.log_with_context")
     @patch("superset.utils.core.g")
     def test_connection_superset_security_connection(
@@ -977,7 +977,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
         connection exc is raised"""
         database = get_example_database()
         mock_g.user = security_manager.find_user("admin")
-        mock_get_sqla_engine.side_effect = SupersetSecurityException(
+        mock_get_sqla_engine.__enter__.side_effect = SupersetSecurityException(
             SupersetError(error_type=500, message="test", level="info")
         )
         db_uri = database.sqlalchemy_uri_decrypted
@@ -990,7 +990,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
 
         mock_event_logger.assert_called()
 
-    @patch("superset.models.core.Database._get_sqla_engine")
+    @patch("superset.models.core.Database.get_sqla_engine")
     @patch("superset.commands.database.test_connection.event_logger.log_with_context")
     @patch("superset.utils.core.g")
     def test_connection_db_api_exc(
@@ -999,7 +999,7 @@ class TestTestConnectionDatabaseCommand(SupersetTestCase):
         """Test to make sure event_logger is called when DBAPIError is raised"""
         database = get_example_database()
         mock_g.user = security_manager.find_user("admin")
-        mock_get_sqla_engine.side_effect = DBAPIError(
+        mock_get_sqla_engine.__enter__.side_effect = DBAPIError(
             statement="error", params={}, orig={}
         )
         db_uri = database.sqlalchemy_uri_decrypted
diff --git a/tests/unit_tests/engines/manager_test.py b/tests/unit_tests/engines/manager_test.py
index 287820eaf1b..871624f2a42 100644
--- a/tests/unit_tests/engines/manager_test.py
+++ b/tests/unit_tests/engines/manager_test.py
@@ -96,7 +96,9 @@ class TestEngineManager:
     @pytest.fixture
     def engine_manager(self):
         """Create a mock EngineManager instance."""
+        from contextlib import contextmanager
 
+        @contextmanager
         def dummy_context_manager(
             database: MagicMock, catalog: str | None, schema: str | None
         ) -> Iterator[None]:
@@ -293,3 +295,233 @@ class TestEngineManager:
         result2 = engine_manager._get_tunnel(ssh_tunnel, uri)
         assert result2 is active_tunnel
         assert mock_tunnel_class.call_count == 2
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_get_engine_args_basic(
+        self, mock_make_url, mock_create_engine, engine_manager
+    ):
+        """Test _get_engine_args returns correct URI and kwargs."""
+        from sqlalchemy.engine.url import make_url
+
+        from superset.engines.manager import EngineModes
+
+        engine_manager.mode = EngineModes.NEW
+
+        mock_uri = make_url("trino://")
+        mock_make_url.return_value = mock_uri
+
+        database = MagicMock()
+        database.id = 1
+        database.sqlalchemy_uri_decrypted = "trino://"
+        database.get_extra.return_value = {
+            "engine_params": {},
+            "connect_args": {"source": "Apache Superset"},
+        }
+        database.get_effective_user.return_value = "alice"
+        database.impersonate_user = False
+        database.update_params_from_encrypted_extra = MagicMock()
+        database.db_engine_spec = MagicMock()
+        database.db_engine_spec.adjust_engine_params.return_value = (
+            mock_uri,
+            {"source": "Apache Superset"},
+        )
+        database.db_engine_spec.validate_database_uri = MagicMock()
+
+        uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
+
+        assert str(uri) == "trino://"
+        assert "connect_args" in database.get_extra.return_value
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_get_engine_args_user_impersonation(
+        self, mock_make_url, mock_create_engine, engine_manager
+    ):
+        """Test user impersonation in _get_engine_args."""
+        from sqlalchemy.engine.url import make_url
+
+        from superset.engines.manager import EngineModes
+
+        engine_manager.mode = EngineModes.NEW
+
+        mock_uri = make_url("trino://")
+        mock_make_url.return_value = mock_uri
+
+        database = MagicMock()
+        database.id = 1
+        database.sqlalchemy_uri_decrypted = "trino://"
+        database.get_extra.return_value = {
+            "engine_params": {},
+            "connect_args": {"source": "Apache Superset"},
+        }
+        database.get_effective_user.return_value = "alice"
+        database.impersonate_user = True
+        database.get_oauth2_config.return_value = None
+        database.update_params_from_encrypted_extra = MagicMock()
+        database.db_engine_spec = MagicMock()
+        database.db_engine_spec.adjust_engine_params.return_value = (
+            mock_uri,
+            {"source": "Apache Superset"},
+        )
+        database.db_engine_spec.impersonate_user.return_value = (
+            mock_uri,
+            {"connect_args": {"user": "alice", "source": "Apache Superset"}},
+        )
+        database.db_engine_spec.validate_database_uri = MagicMock()
+
+        uri, kwargs = engine_manager._get_engine_args(database, None, None, None, None)
+
+        # Verify impersonate_user was called
+        database.db_engine_spec.impersonate_user.assert_called_once()
+        call_args = database.db_engine_spec.impersonate_user.call_args
+        assert call_args[0][0] is database  # database
+        assert call_args[0][1] == "alice"  # username
+        assert call_args[0][2] is None  # access_token (no OAuth2)
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_get_engine_args_user_impersonation_email_prefix(
+        self,
+        mock_make_url,
+        mock_create_engine,
+        engine_manager,
+    ):
+        """Test user impersonation with IMPERSONATE_WITH_EMAIL_PREFIX feature flag."""
+        from sqlalchemy.engine.url import make_url
+
+        from superset.engines.manager import EngineModes
+
+        engine_manager.mode = EngineModes.NEW
+
+        mock_uri = make_url("trino://")
+        mock_make_url.return_value = mock_uri
+
+        # Mock user with email
+        mock_user = MagicMock()
+        mock_user.email = "[email protected]"
+
+        database = MagicMock()
+        database.id = 1
+        database.sqlalchemy_uri_decrypted = "trino://"
+        database.get_extra.return_value = {
+            "engine_params": {},
+            "connect_args": {"source": "Apache Superset"},
+        }
+        database.get_effective_user.return_value = "alice"
+        database.impersonate_user = True
+        database.get_oauth2_config.return_value = None
+        database.update_params_from_encrypted_extra = MagicMock()
+        database.db_engine_spec = MagicMock()
+        database.db_engine_spec.adjust_engine_params.return_value = (
+            mock_uri,
+            {"source": "Apache Superset"},
+        )
+        database.db_engine_spec.impersonate_user.return_value = (
+            mock_uri,
+            {"connect_args": {"user": "alice.doe", "source": "Apache Superset"}},
+        )
+        database.db_engine_spec.validate_database_uri = MagicMock()
+
+        with (
+            patch(
+                "superset.utils.feature_flag_manager.FeatureFlagManager.is_feature_enabled",
+                return_value=True,
+            ),
+            patch(
+                "superset.extensions.security_manager.find_user",
+                return_value=mock_user,
+            ),
+        ):
+            uri, kwargs = engine_manager._get_engine_args(
+                database, None, None, None, None
+            )
+
+        # Verify impersonate_user was called with the email prefix
+        database.db_engine_spec.impersonate_user.assert_called_once()
+        call_args = database.db_engine_spec.impersonate_user.call_args
+        assert call_args[0][1] == "alice.doe"  # username from email prefix
+
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_engine_context_manager_called(
+        self, mock_make_url, mock_create_engine, engine_manager, mock_database
+    ):
+        """Test that the engine context manager is properly called."""
+        from sqlalchemy.engine.url import make_url
+
+        mock_uri = make_url("trino://")
+        mock_make_url.return_value = mock_uri
+        mock_engine = MagicMock()
+        mock_create_engine.return_value = mock_engine
+
+        # Track context manager calls
+        context_manager_calls = []
+
+        def tracking_context_manager(database, catalog, schema):
+            from contextlib import contextmanager
+
+            @contextmanager
+            def inner():
+                context_manager_calls.append(("enter", database, catalog, schema))
+                yield
+                context_manager_calls.append(("exit", database, catalog, schema))
+
+            return inner()
+
+        engine_manager.engine_context_manager = tracking_context_manager
+
+        with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
+            pass
+
+        assert len(context_manager_calls) == 2
+        assert context_manager_calls[0][0] == "enter"
+        assert context_manager_calls[0][1] is mock_database
+        assert context_manager_calls[0][2] == "catalog1"
+        assert context_manager_calls[0][3] == "schema1"
+        assert context_manager_calls[1][0] == "exit"
+
+    @patch("superset.utils.oauth2.check_for_oauth2")
+    @patch("superset.engines.manager.create_engine")
+    @patch("superset.engines.manager.make_url_safe")
+    def test_engine_oauth2_error_handling(
+        self,
+        mock_make_url,
+        mock_create_engine,
+        mock_check_for_oauth2,
+        engine_manager,
+        mock_database,
+    ):
+        """Test that OAuth2 errors are properly propagated from get_engine."""
+        from contextlib import contextmanager
+
+        from sqlalchemy.engine.url import make_url
+
+        mock_uri = make_url("trino://")
+        mock_make_url.return_value = mock_uri
+
+        # Simulate OAuth2 error during engine creation
+        class OAuth2TestError(Exception):
+            pass
+
+        oauth_error = OAuth2TestError("OAuth2 required")
+        mock_create_engine.side_effect = oauth_error
+
+        # Make get_dbapi_mapped_exception return the original exception
+        mock_database.db_engine_spec.get_dbapi_mapped_exception.return_value = (
+            oauth_error
+        )
+
+        # Mock check_for_oauth2 to re-raise the exception
+        @contextmanager
+        def mock_oauth2_context(database):
+            try:
+                yield
+            except OAuth2TestError:
+                raise
+
+        mock_check_for_oauth2.return_value = mock_oauth2_context(mock_database)
+
+        with pytest.raises(OAuth2TestError, match="OAuth2 required"):
+            with engine_manager.get_engine(mock_database, "catalog1", "schema1", None):
+                pass
diff --git a/tests/unit_tests/initialization_test.py b/tests/unit_tests/initialization_test.py
index 01fde0967c9..93fdf4d352e 100644
--- a/tests/unit_tests/initialization_test.py
+++ b/tests/unit_tests/initialization_test.py
@@ -123,7 +123,7 @@ class TestSupersetAppInitializer:
             patch.object(app_initializer, "configure_data_sources"),
             patch.object(app_initializer, "configure_auth_provider"),
             patch.object(app_initializer, "configure_async_queries"),
-            patch.object(app_initializer, "configure_ssh_manager"),
+            patch.object(app_initializer, "configure_engine_manager"),
             patch.object(app_initializer, "configure_stats_manager"),
             patch.object(app_initializer, "init_views"),
         ):
diff --git a/tests/unit_tests/models/core_test.py b/tests/unit_tests/models/core_test.py
index 7d7aa96ea19..b2a48df0592 100644
--- a/tests/unit_tests/models/core_test.py
+++ b/tests/unit_tests/models/core_test.py
@@ -19,7 +19,6 @@
 from datetime import datetime
 
 import pytest
-from flask import current_app
 from pytest_mock import MockerFixture
 from sqlalchemy import (
     Column,
@@ -29,7 +28,6 @@ from sqlalchemy import (
     Table as SqlalchemyTable,
 )
 from sqlalchemy.engine.reflection import Inspector
-from sqlalchemy.engine.url import make_url
 from sqlalchemy.orm.session import Session
 from sqlalchemy.sql import Select
 
@@ -525,60 +523,6 @@ def test_get_all_materialized_view_names_in_schema_needs_oauth2(
     assert excinfo.value.error.error_type == SupersetErrorType.OAUTH2_REDIRECT
 
 
-def test_get_sqla_engine(mocker: MockerFixture) -> None:
-    """
-    Test `_get_sqla_engine`.
-    """
-    from superset.models.core import Database
-
-    user = mocker.MagicMock()
-    user.email = "[email protected]"
-    mocker.patch(
-        "superset.models.core.security_manager.find_user",
-        return_value=user,
-    )
-    mocker.patch("superset.models.core.get_username", return_value="alice")
-
-    create_engine = mocker.patch("superset.models.core.create_engine")
-
-    database = Database(database_name="my_db", sqlalchemy_uri="trino://")
-    database._get_sqla_engine(nullpool=False)
-
-    create_engine.assert_called_with(
-        make_url("trino:///"),
-        connect_args={"source": "Apache Superset"},
-    )
-
-
-def test_get_sqla_engine_user_impersonation(mocker: MockerFixture) -> None:
-    """
-    Test user impersonation in `_get_sqla_engine`.
-    """
-    from superset.models.core import Database
-
-    user = mocker.MagicMock()
-    user.email = "[email protected]"
-    mocker.patch(
-        "superset.models.core.security_manager.find_user",
-        return_value=user,
-    )
-    mocker.patch("superset.models.core.get_username", return_value="alice")
-
-    create_engine = mocker.patch("superset.models.core.create_engine")
-
-    database = Database(
-        database_name="my_db",
-        sqlalchemy_uri="trino://",
-        impersonate_user=True,
-    )
-    database._get_sqla_engine(nullpool=False)
-
-    create_engine.assert_called_with(
-        make_url("trino:///"),
-        connect_args={"user": "alice", "source": "Apache Superset"},
-    )
-
-
 def test_add_database_to_signature():
     args = ["param1", "param2"]
 
@@ -604,36 +548,6 @@ def test_add_database_to_signature():
     assert args3 == ["param1", "param2", database]
 
 
-@with_feature_flags(IMPERSONATE_WITH_EMAIL_PREFIX=True)
-def test_get_sqla_engine_user_impersonation_email(mocker: MockerFixture) -> None:
-    """
-    Test user impersonation in `_get_sqla_engine` with `username_from_email`.
-    """
-    from superset.models.core import Database
-
-    user = mocker.MagicMock()
-    user.email = "[email protected]"
-    mocker.patch(
-        "superset.models.core.security_manager.find_user",
-        return_value=user,
-    )
-    mocker.patch("superset.models.core.get_username", return_value="alice")
-
-    create_engine = mocker.patch("superset.models.core.create_engine")
-
-    database = Database(
-        database_name="my_db",
-        sqlalchemy_uri="trino://",
-        impersonate_user=True,
-    )
-    database._get_sqla_engine(nullpool=False)
-
-    create_engine.assert_called_with(
-        make_url("trino:///"),
-        connect_args={"user": "alice.doe", "source": "Apache Superset"},
-    )
-
-
 def test_is_oauth2_enabled() -> None:
     """
     Test the `is_oauth2_enabled` method.
@@ -753,37 +667,6 @@ def test_get_oauth2_config_redirect_uri_from_config(
     assert config["redirect_uri"] == custom_redirect_uri
 
 
-def test_raw_connection_oauth_engine(mocker: MockerFixture) -> None:
-    """
-    Test that we can start OAuth2 from `raw_connection()` errors.
-
-    With OAuth2, some databases will raise an exception when the engine is first created
-    (eg, BigQuery). Others, like, Snowflake, when the connection is created. And
-    finally, GSheets will raise an exception when the query is executed.
-
-    This tests verifies that when calling `raw_connection()` the OAuth2 flow is
-    triggered when the engine is created.
-    """
-    g = mocker.patch("superset.db_engine_specs.base.g")
-    g.user = mocker.MagicMock()
-    g.user.id = 42
-
-    database = Database(
-        id=1,
-        database_name="my_db",
-        sqlalchemy_uri="sqlite://",
-        encrypted_extra=json.dumps(oauth2_client_info),
-    )
-    database.db_engine_spec.oauth2_exception = OAuth2Error  # type: ignore
-    _get_sqla_engine = mocker.patch.object(database, "_get_sqla_engine")
-    _get_sqla_engine.side_effect = OAuth2Error("OAuth2 required")
-
-    with pytest.raises(OAuth2RedirectError) as excinfo:
-        with database.get_raw_connection() as conn:
-            conn.cursor()
-    assert str(excinfo.value) == "You don't have permission to access the data."
-
-
 def test_raw_connection_oauth_connection(mocker: MockerFixture) -> None:
     """
     Test that we can start OAuth2 from `raw_connection()` errors.
@@ -879,56 +762,6 @@ def test_get_schema_access_for_file_upload() -> None:
     assert database.get_schema_access_for_file_upload() == {"public"}
 
 
-def test_engine_context_manager(mocker: MockerFixture, app_context: None) -> None:
-    """
-    Test the engine context manager.
-    """
-    from unittest.mock import MagicMock
-
-    engine_context_manager = MagicMock()
-    mocker.patch.dict(
-        current_app.config,
-        {"ENGINE_CONTEXT_MANAGER": engine_context_manager},
-    )
-    _get_sqla_engine = mocker.patch.object(Database, "_get_sqla_engine")
-
-    database = Database(database_name="my_db", sqlalchemy_uri="trino://")
-    with database.get_sqla_engine("catalog", "schema"):
-        pass
-
-    engine_context_manager.assert_called_once_with(database, "catalog", "schema")
-    engine_context_manager().__enter__.assert_called_once()
-    engine_context_manager().__exit__.assert_called_once_with(None, None, None)
-    _get_sqla_engine.assert_called_once_with(
-        catalog="catalog",
-        schema="schema",
-        nullpool=True,
-        source=None,
-        sqlalchemy_uri="trino://",
-    )
-
-
-def test_engine_oauth2(mocker: MockerFixture) -> None:
-    """
-    Test that we handle OAuth2 when `create_engine` fails.
-    """
-    database = Database(database_name="my_db", sqlalchemy_uri="trino://")
-    mocker.patch.object(database, "_get_sqla_engine", side_effect=Exception)
-    mocker.patch.object(database, "is_oauth2_enabled", return_value=True)
-    mocker.patch.object(database.db_engine_spec, "needs_oauth2", return_value=True)
-    start_oauth2_dance = mocker.patch.object(
-        database.db_engine_spec,
-        "start_oauth2_dance",
-        side_effect=OAuth2Error("OAuth2 required"),
-    )
-
-    with pytest.raises(OAuth2Error):
-        with database.get_sqla_engine("catalog", "schema"):
-            pass
-
-    start_oauth2_dance.assert_called_with(database)
-
-
 def test_purge_oauth2_tokens(session: Session) -> None:
     """
     Test the `purge_oauth2_tokens` method.

Reply via email to