This is an automated email from the ASF dual-hosted git repository. tai pushed a commit to branch feat/starrocks-catalog in repository https://gitbox.apache.org/repos/asf/superset.git
commit 74996ff6a838a680f37661edbe4261e40141fc4d Author: Tai Dupree <[email protected]> AuthorDate: Thu Jan 8 11:55:06 2026 -0800 feat(starrocks): add catalog support for StarRocks database connections --- superset/db_engine_specs/starrocks.py | 152 ++++++++++++++++----- tests/unit_tests/db_engine_specs/test_starrocks.py | 114 +++++++++++++++- 2 files changed, 229 insertions(+), 37 deletions(-) diff --git a/superset/db_engine_specs/starrocks.py b/superset/db_engine_specs/starrocks.py index d3e2172f2b..3777e67149 100644 --- a/superset/db_engine_specs/starrocks.py +++ b/superset/db_engine_specs/starrocks.py @@ -18,11 +18,12 @@ import logging import re from re import Pattern -from typing import Any, Optional, Union +from typing import Any from urllib import parse from flask_babel import gettext as __ from sqlalchemy import Float, Integer, Numeric, types +from sqlalchemy.engine.reflection import Inspector from sqlalchemy.engine.url import URL from sqlalchemy.sql.type_api import TypeEngine @@ -31,6 +32,8 @@ from superset.errors import SupersetErrorType from superset.models.core import Database from superset.utils.core import GenericDataType +DEFAULT_CATALOG = "default_catalog" + # Regular expressions to catch custom errors CONNECTION_ACCESS_DENIED_REGEX = re.compile( "Access denied for user '(?P<username>.*?)'" @@ -68,7 +71,7 @@ class ARRAY(TypeEngine): __visit_name__ = "ARRAY" @property - def python_type(self) -> Optional[type[list[Any]]]: + def python_type(self) -> type[list[Any]] | None: return list @@ -76,7 +79,7 @@ class MAP(TypeEngine): __visit_name__ = "MAP" @property - def python_type(self) -> Optional[type[dict[Any, Any]]]: + def python_type(self) -> type[dict[Any, Any]] | None: return dict @@ -84,7 +87,7 @@ class STRUCT(TypeEngine): __visit_name__ = "STRUCT" @property - def python_type(self) -> Optional[type[Any]]: + def python_type(self) -> type[Any] | None: return None @@ -94,8 +97,11 @@ class StarRocksEngineSpec(MySQLEngineSpec): default_driver = "starrocks" 
sqlalchemy_uri_placeholder = ( - "starrocks://user:password@host:port/catalog.db[?key=value&key=value...]" + "starrocks://user:password@host:port[/catalog.db]" ) + supports_dynamic_schema = True + supports_catalog = supports_dynamic_catalog = True + supports_cross_catalog_queries = True column_type_mappings = ( # type: ignore ( @@ -168,17 +174,39 @@ class StarRocksEngineSpec(MySQLEngineSpec): cls, uri: URL, connect_args: dict[str, Any], - catalog: Optional[str] = None, - schema: Optional[str] = None, + catalog: str | None = None, + schema: str | None = None, ) -> tuple[URL, dict[str, Any]]: - database = uri.database - if schema and database: + """ + Adjust engine parameters for StarRocks catalog and schema support. + + StarRocks uses a "catalog.schema" format in the database field: + - "catalog.schema" - both specified + - "catalog." - catalog only (for browsing schemas) + - None - neither specified + """ + if uri.database and "." in uri.database: + current_catalog, current_schema = uri.database.split(".", 1) + elif uri.database: + current_catalog, current_schema = uri.database, None + else: + current_catalog, current_schema = None, None + + if schema: schema = parse.quote(schema, safe="") - if "." in database: - database = database.split(".")[0] + "." + schema - else: - database = "default_catalog." + schema - uri = uri.set(database=database) + + effective_catalog = catalog or current_catalog or DEFAULT_CATALOG + # only use the schema/db from uri if we're not overriding catalog + effective_schema = schema + if not effective_schema and (not catalog or catalog == current_catalog): + effective_schema = current_schema + + if effective_schema: + adjusted_database = f"{effective_catalog}.{effective_schema}" + else: + adjusted_database = f"{effective_catalog}." 
+ + uri = uri.set(database=adjusted_database) return uri, connect_args @@ -187,21 +215,85 @@ class StarRocksEngineSpec(MySQLEngineSpec): cls, sqlalchemy_uri: URL, connect_args: dict[str, Any], - ) -> Optional[str]: + ) -> str | None: + """ + Extract schema from engine parameters. + + Returns the schema portion from formats like: + - "catalog.schema" -> "schema" + - "schema" -> None (ambiguous - could be catalog or schema) + - "" or None -> None """ - Return the configured schema. + if not sqlalchemy_uri.database: + return None - For StarRocks the SQLAlchemy URI looks like this: + database = sqlalchemy_uri.database.strip("/") + if not database or "." not in database: + return None - starrocks://localhost:9030/catalog.schema + schema = database.split(".")[-1] + return parse.unquote(schema) + @classmethod + def get_default_catalog(cls, database: Database) -> str: """ - database = sqlalchemy_uri.database.strip("/") + Return the default catalog. - if "." not in database: - return None + Extracts catalog from URI (e.g., "iceberg" from "iceberg.schema"), + otherwise returns DEFAULT_CATALOG. + """ + if database.url_object.database and "." in database.url_object.database: + return database.url_object.database.split(".")[0] - return parse.unquote(database.split(".")[1]) + return DEFAULT_CATALOG + + @classmethod + def get_catalog_names( + cls, + database: Database, + inspector: Inspector, + ) -> set[str]: + """ + Get all available catalogs. + + Executes SHOW CATALOGS and extracts catalog names from the result. 
+ The command returns columns: Catalog, Type, Comment + """ + try: + result = inspector.bind.execute("SHOW CATALOGS") + catalogs = set() + + for row in result: + try: + if hasattr(row, "keys") and "Catalog" in row.keys(): + catalogs.add(row["Catalog"]) + elif hasattr(row, "Catalog"): + catalogs.add(row.Catalog) + else: + catalogs.add(row[0]) + except (AttributeError, TypeError, IndexError, KeyError) as ex: + logger.warning("Unable to extract catalog name from row: %s (%s)", row, ex) + continue + + return catalogs + except Exception as ex: # pylint: disable=broad-except + logger.exception("Error fetching catalog names from SHOW CATALOGS: %s", ex) + return set() + + @classmethod + def get_schema_names(cls, inspector: Inspector) -> set[str]: + """ + Get all schemas/databases using SHOW DATABASES. + + The catalog context is set via the database field in the connection URL + (e.g., "catalog." sets the context to that catalog). + """ + try: + result = inspector.bind.execute("SHOW DATABASES") + return {row[0] for row in result} + except Exception as ex: # pylint: disable=broad-except + logger.exception("Error fetching schema names from SHOW DATABASES: %s", ex) + return set() @classmethod def impersonate_user( @@ -225,21 +317,13 @@ class StarRocksEngineSpec(MySQLEngineSpec): def get_prequeries( cls, database: Database, - catalog: Union[str, None] = None, - schema: Union[str, None] = None, + catalog: str | None = None, + schema: str | None = None, ) -> list[str]: """ - Return pre-session queries. - - These are currently used as an alternative to ``adjust_engine_params`` for - databases where the selected schema cannot be specified in the SQLAlchemy URI or - connection arguments. - - For example, in order to specify a default schema in RDS we need to run a query - at the beginning of the session: - - sql> set search_path = my_schema; + Get pre-session queries. + For StarRocks with user impersonation enabled, returns an EXECUTE AS statement. 
""" if database.impersonate_user: username = database.get_effective_user(database.url_object) diff --git a/tests/unit_tests/db_engine_specs/test_starrocks.py b/tests/unit_tests/db_engine_specs/test_starrocks.py index 67016a0801..e37aeab901 100644 --- a/tests/unit_tests/db_engine_specs/test_starrocks.py +++ b/tests/unit_tests/db_engine_specs/test_starrocks.py @@ -79,7 +79,7 @@ def test_get_column_spec( ( "starrocks://user:password@host/db1", {"param1": "some_value"}, - "db1", + "db1.", # Single value is treated as a catalog (no schema selected) {"param1": "some_value"}, ), ( @@ -88,12 +88,18 @@ def test_get_column_spec( "catalog1.db1", {"param1": "some_value"}, ), + ( + "starrocks://user:password@host", + {"param1": "some_value"}, + "default_catalog.", + {"param1": "some_value"}, + ), ], ) def test_adjust_engine_params( sqlalchemy_uri: str, connect_args: dict[str, Any], - return_schema: str, + return_schema: Optional[str], return_connect_args: dict[str, Any], ) -> None: from superset.db_engine_specs.starrocks import StarRocksEngineSpec @@ -112,6 +118,7 @@ def test_get_schema_from_engine_params() -> None: """ from superset.db_engine_specs.starrocks import StarRocksEngineSpec + # With catalog.schema format assert ( StarRocksEngineSpec.get_schema_from_engine_params( make_url("starrocks://localhost:9030/hive.default"), @@ -120,9 +127,19 @@ def test_get_schema_from_engine_params() -> None: == "default" ) + # With only catalog (no schema) - should return None + assert ( + StarRocksEngineSpec.get_schema_from_engine_params( + make_url("starrocks://localhost:9030/sales"), + {}, + ) + is None + ) + + # With no database - should return None assert ( StarRocksEngineSpec.get_schema_from_engine_params( - make_url("starrocks://localhost:9030/hive"), + make_url("starrocks://localhost:9030"), {}, ) is None @@ -173,3 +190,94 @@ def test_impersonation_disabled(mocker: MockerFixture) -> None: ) == (make_url("starrocks://service_user@localhost:9030/hive.default"), {}) assert
StarRocksEngineSpec.get_prequeries(database) == [] + + +def test_get_default_catalog(mocker: MockerFixture) -> None: + """ + Test the ``get_default_catalog`` method. + """ + from superset.db_engine_specs.starrocks import StarRocksEngineSpec + + # Test case 1: Catalog is in the URI + database = mocker.MagicMock() + database.url_object.database = "hive.default" + + assert StarRocksEngineSpec.get_default_catalog(database) == "hive" + + # Test case 2: Catalog is not in the URI, returns default + database = mocker.MagicMock() + database.url_object.database = "default" + + assert StarRocksEngineSpec.get_default_catalog(database) == "default_catalog" + + +def test_get_catalog_names(mocker: MockerFixture) -> None: + """ + Test the ``get_catalog_names`` method. + """ + from superset.db_engine_specs.starrocks import StarRocksEngineSpec + + database = mocker.MagicMock() + inspector = mocker.MagicMock() + + # Mock the actual StarRocks SHOW CATALOGS format + # StarRocks returns rows with keys: ['Catalog', 'Type', 'Comment'] + mock_row_1 = mocker.MagicMock() + mock_row_1.keys.return_value = ["Catalog", "Type", "Comment"] + mock_row_1.__getitem__ = lambda self, key: "default_catalog" if key == "Catalog" else None + + mock_row_2 = mocker.MagicMock() + mock_row_2.keys.return_value = ["Catalog", "Type", "Comment"] + mock_row_2.__getitem__ = lambda self, key: "hive" if key == "Catalog" else None + + mock_row_3 = mocker.MagicMock() + mock_row_3.keys.return_value = ["Catalog", "Type", "Comment"] + mock_row_3.__getitem__ = lambda self, key: "iceberg" if key == "Catalog" else None + + inspector.bind.execute.return_value = [mock_row_1, mock_row_2, mock_row_3] + + catalogs = StarRocksEngineSpec.get_catalog_names(database, inspector) + assert catalogs == {"default_catalog", "hive", "iceberg"} + + [email protected]( + "uri,catalog,schema,expected_database", + [ + # Test with catalog and schema/db in URI + ("starrocks://host/hive.sales", None, None, "hive.sales"), + # Test overriding catalog 
+ ("starrocks://host/hive.sales", "iceberg", None, "iceberg."), + # Test overriding schema/db + ("starrocks://host/hive.sales", None, "marketing", "hive.marketing"), + # Test overriding both + ("starrocks://host/hive.sales", "iceberg", "marketing", "iceberg.marketing"), + # Test with only catalog in URI (no schema/db), add new schema + ("starrocks://host/hive", None, "marketing", "hive.marketing"), + # Test with catalog in URI, override catalog + ("starrocks://host/hive", "iceberg", None, "iceberg."), + # Test with no catalog/database in URI, overriding catalog + ("starrocks://host", "iceberg", None, "iceberg."), + # Test with no catalog/database in URI, catalog and schema/db + ("starrocks://host", "iceberg", "sales", "iceberg.sales"), + # Test with empty database and empty overrides, uses default catalog + ("starrocks://host", None, None, 'default_catalog.'), + # Test schema only (no catalog) when URI has no database, uses default_catalog + ("starrocks://host", None, "sales", "default_catalog.sales"), + ], +) +def test_adjust_engine_params_with_catalog( + uri: str, + catalog: Optional[str], + schema: Optional[str], + expected_database: Optional[str], +) -> None: + """ + Test the ``adjust_engine_params`` method with catalog parameter. + """ + from superset.db_engine_specs.starrocks import StarRocksEngineSpec + + url = make_url(uri) + returned_url, _ = StarRocksEngineSpec.adjust_engine_params( + url, {}, catalog=catalog, schema=schema + ) + assert returned_url.database == expected_database
