This is an automated email from the ASF dual-hosted git repository.

beto pushed a commit to branch sl-2-semantic-layer-core
in repository https://gitbox.apache.org/repos/asf/superset.git
commit 7d99d5cef923410971b8ad67075098042b81eebc
Author: Beto Dealmeida <[email protected]>
AuthorDate: Mon Feb 2 17:27:02 2026 -0500

    feat(semantic-layer): add core semantic layer infrastructure

    Add the foundational semantic layer implementation:

    - SemanticLayer and SemanticView SQLAlchemy models
    - Semantic layer registry for plugin-based implementations
    - Query mapper for translating Superset queries to semantic layer format
    - Type definitions for metrics, dimensions, entities, and grains
    - DAO layer for semantic layer CRUD operations
    - Database migration for semantic_layers and semantic_views tables
    - Updated Explorable base class with ColumnMetadata protocol
    - TypedDict updates for API response compatibility
    - Update sql_lab and sqla models for new TypedDict fields

    Co-Authored-By: Claude Opus 4.5 <[email protected]>
---
 superset/connectors/sqla/models.py                 |   6 +-
 superset/daos/semantic_layer.py                    | 152 ++++
 superset/explorables/base.py                       | 128 ++-
 ...6_33d7e0e21daa_add_semantic_layers_and_views.py | 126 +++
 superset/models/sql_lab.py                         |   6 +-
 superset/semantic_layers/__init__.py               |  16 +
 superset/semantic_layers/mapper.py                 | 944 +++++++++++++++++++++
 superset/semantic_layers/models.py                 | 373 ++++++++
 superset/semantic_layers/registry.py               | 132 +++
 superset/semantic_layers/types.py                  | 497 +++++++++++
 superset/superset_typing.py                        |  52 +-
 superset/utils/core.py                             |  43 +-
 12 files changed, 2441 insertions(+), 34 deletions(-)

diff --git a/superset/connectors/sqla/models.py b/superset/connectors/sqla/models.py
index be74a199672..98f122d115b 100644
--- a/superset/connectors/sqla/models.py
+++ b/superset/connectors/sqla/models.py
@@ -107,6 +107,8 @@ from superset.sql.parse import Table
 from superset.superset_typing import (
     AdhocColumn,
     AdhocMetric,
+    DatasetColumnData,
+    DatasetMetricData,
     ExplorableData,
     Metric,
     QueryObjectDict,
@@ -463,8 +465,8 @@ class BaseDatasource(
             # sqla-specific
             "sql": self.sql,
             # one to many
-            "columns": [o.data for o in self.columns],
-            "metrics": [o.data for o in self.metrics],
+            "columns": cast(list[DatasetColumnData], [o.data for o in self.columns]),
+            "metrics": cast(list[DatasetMetricData], [o.data for o in self.metrics]),
             "folders": self.folders,
             # TODO deprecate, move logic to JS
             "order_by_choices": self.order_by_choices,
diff --git a/superset/daos/semantic_layer.py b/superset/daos/semantic_layer.py
new file mode 100644
index 00000000000..9c591e4a7a4
--- /dev/null
+++ b/superset/daos/semantic_layer.py
@@ -0,0 +1,152 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+ +"""DAOs for semantic layer models.""" + +from __future__ import annotations + +from superset.daos.base import BaseDAO +from superset.extensions import db +from superset.semantic_layers.models import SemanticLayer, SemanticView + + +class SemanticLayerDAO(BaseDAO[SemanticLayer]): + """ + Data Access Object for SemanticLayer model. + """ + + @staticmethod + def validate_uniqueness(name: str) -> bool: + """ + Validate that semantic layer name is unique. + + :param name: Semantic layer name + :return: True if name is unique, False otherwise + """ + query = db.session.query(SemanticLayer).filter(SemanticLayer.name == name) + return not db.session.query(query.exists()).scalar() + + @staticmethod + def validate_update_uniqueness(layer_uuid: str, name: str) -> bool: + """ + Validate that semantic layer name is unique for updates. + + :param layer_uuid: UUID of the semantic layer being updated + :param name: New name to validate + :return: True if name is unique, False otherwise + """ + query = db.session.query(SemanticLayer).filter( + SemanticLayer.name == name, + SemanticLayer.uuid != layer_uuid, + ) + return not db.session.query(query.exists()).scalar() + + @staticmethod + def find_by_name(name: str) -> SemanticLayer | None: + """ + Find semantic layer by name. + + :param name: Semantic layer name + :return: SemanticLayer instance or None + """ + return ( + db.session.query(SemanticLayer) + .filter(SemanticLayer.name == name) + .one_or_none() + ) + + @classmethod + def get_semantic_views(cls, layer_uuid: str) -> list[SemanticView]: + """ + Get all semantic views for a semantic layer. + + :param layer_uuid: UUID of the semantic layer + :return: List of SemanticView instances + """ + return ( + db.session.query(SemanticView) + .filter(SemanticView.semantic_layer_uuid == layer_uuid) + .all() + ) + + +class SemanticViewDAO(BaseDAO[SemanticView]): + """Data Access Object for SemanticView model.""" + + @staticmethod + def find_by_semantic_layer(layer_uuid: str) -> list[SemanticView]: + """ + Find all views for a semantic layer. + + :param layer_uuid: UUID of the semantic layer + :return: List of SemanticView instances + """ + return ( + db.session.query(SemanticView) + .filter(SemanticView.semantic_layer_uuid == layer_uuid) + .all() + ) + + @staticmethod + def validate_uniqueness(name: str, layer_uuid: str) -> bool: + """ + Validate that view name is unique within semantic layer. + + :param name: View name + :param layer_uuid: UUID of the semantic layer + :return: True if name is unique within layer, False otherwise + """ + query = db.session.query(SemanticView).filter( + SemanticView.name == name, + SemanticView.semantic_layer_uuid == layer_uuid, + ) + return not db.session.query(query.exists()).scalar() + + @staticmethod + def validate_update_uniqueness(view_uuid: str, name: str, layer_uuid: str) -> bool: + """ + Validate that view name is unique within semantic layer for updates. + + :param view_uuid: UUID of the view being updated + :param name: New name to validate + :param layer_uuid: UUID of the semantic layer + :return: True if name is unique within layer, False otherwise + """ + query = db.session.query(SemanticView).filter( + SemanticView.name == name, + SemanticView.semantic_layer_uuid == layer_uuid, + SemanticView.uuid != view_uuid, + ) + return not db.session.query(query.exists()).scalar() + + @staticmethod + def find_by_name(name: str, layer_uuid: str) -> SemanticView | None: + """ + Find semantic view by name within a semantic layer. 
+ + :param name: View name + :param layer_uuid: UUID of the semantic layer + :return: SemanticView instance or None + """ + return ( + db.session.query(SemanticView) + .filter( + SemanticView.name == name, + SemanticView.semantic_layer_uuid == layer_uuid, + ) + .one_or_none() + ) diff --git a/superset/explorables/base.py b/superset/explorables/base.py index 2d534b72099..de69257a317 100644 --- a/superset/explorables/base.py +++ b/superset/explorables/base.py @@ -53,6 +53,130 @@ class TimeGrainDict(TypedDict): duration: str | None +@runtime_checkable +class MetricMetadata(Protocol): + """ + Protocol for metric metadata objects. + + Represents a metric that's available on an explorable data source. + Metrics contain SQL expressions or references to semantic layer measures. + + Attributes: + metric_name: Unique identifier for the metric + expression: SQL expression or reference for calculating the metric + verbose_name: Human-readable name for display in the UI + description: Description of what the metric represents + d3format: D3 format string for formatting numeric values + currency: Currency configuration for the metric (JSON object) + warning_text: Warning message to display when using this metric + certified_by: Person or entity that certified this metric + certification_details: Details about the certification + """ + + @property + def metric_name(self) -> str: + """Unique identifier for the metric.""" + + @property + def expression(self) -> str: + """SQL expression or reference for calculating the metric.""" + + @property + def verbose_name(self) -> str | None: + """Human-readable name for display in the UI.""" + + @property + def description(self) -> str | None: + """Description of what the metric represents.""" + + @property + def d3format(self) -> str | None: + """D3 format string for formatting numeric values.""" + + @property + def currency(self) -> dict[str, Any] | None: + """Currency configuration for the metric (JSON object).""" + + @property + def warning_text(self) -> str | None: + """Warning message to display when using this metric.""" + + @property + def certified_by(self) -> str | None: + """Person or entity that certified this metric.""" + + @property + def certification_details(self) -> str | None: + """Details about the certification.""" + + +@runtime_checkable +class ColumnMetadata(Protocol): + """ + Protocol for column metadata objects. + + Represents a column/dimension that's available on an explorable data source. + Used for grouping, filtering, and dimension-based analysis. 
+ + Attributes: + column_name: Unique identifier for the column + type: SQL data type of the column (e.g., 'VARCHAR', 'INTEGER', 'DATETIME') + is_dttm: Whether this column represents a date or time value + verbose_name: Human-readable name for display in the UI + description: Description of what the column represents + groupby: Whether this column is allowed for grouping/aggregation + filterable: Whether this column can be used in filters + expression: SQL expression if this is a calculated column + python_date_format: Python datetime format string for temporal columns + advanced_data_type: Advanced data type classification + extra: Additional metadata stored as JSON + """ + + @property + def column_name(self) -> str: + """Unique identifier for the column.""" + + @property + def type(self) -> str: + """SQL data type of the column.""" + + @property + def is_dttm(self) -> bool: + """Whether this column represents a date or time value.""" + + @property + def verbose_name(self) -> str | None: + """Human-readable name for display in the UI.""" + + @property + def description(self) -> str | None: + """Description of what the column represents.""" + + @property + def groupby(self) -> bool: + """Whether this column is allowed for grouping/aggregation.""" + + @property + def filterable(self) -> bool: + """Whether this column can be used in filters.""" + + @property + def expression(self) -> str | None: + """SQL expression if this is a calculated column.""" + + @property + def python_date_format(self) -> str | None: + """Python datetime format string for temporal columns.""" + + @property + def advanced_data_type(self) -> str | None: + """Advanced data type classification.""" + + @property + def extra(self) -> str | None: + """Additional metadata stored as JSON.""" + + @runtime_checkable class Explorable(Protocol): """ @@ -132,7 +256,7 @@ class Explorable(Protocol): """ @property - def metrics(self) -> list[Any]: + def metrics(self) -> list[MetricMetadata]: """ List of metric metadata objects. @@ -147,7 +271,7 @@ class Explorable(Protocol): # TODO: rename to dimensions @property - def columns(self) -> list[Any]: + def columns(self) -> list[ColumnMetadata]: """ List of column metadata objects. diff --git a/superset/migrations/versions/2025-11-04_11-26_33d7e0e21daa_add_semantic_layers_and_views.py b/superset/migrations/versions/2025-11-04_11-26_33d7e0e21daa_add_semantic_layers_and_views.py new file mode 100644 index 00000000000..cd022dfdd62 --- /dev/null +++ b/superset/migrations/versions/2025-11-04_11-26_33d7e0e21daa_add_semantic_layers_and_views.py @@ -0,0 +1,126 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+"""add_semantic_layers_and_views + +Revision ID: 33d7e0e21daa +Revises: 9787190b3d89 +Create Date: 2025-11-04 11:26:00.000000 + +""" + +import uuid + +import sqlalchemy as sa +from sqlalchemy_utils import UUIDType +from sqlalchemy_utils.types.json import JSONType + +from superset.extensions import encrypted_field_factory +from superset.migrations.shared.utils import ( + create_fks_for_table, + create_table, + drop_table, +) + +# revision identifiers, used by Alembic. +revision = "33d7e0e21daa" +down_revision = "9787190b3d89" + + +def upgrade(): + # Create semantic_layers table + create_table( + "semantic_layers", + sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False), + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("name", sa.String(length=250), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column("type", sa.String(length=250), nullable=False), + sa.Column( + "configuration", + encrypted_field_factory.create(JSONType), + nullable=True, + ), + sa.Column("cache_timeout", sa.Integer(), nullable=True), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("uuid"), + ) + + # Create foreign key constraints for semantic_layers + create_fks_for_table( + "fk_semantic_layers_created_by_fk_ab_user", + "semantic_layers", + "ab_user", + ["created_by_fk"], + ["id"], + ) + + create_fks_for_table( + "fk_semantic_layers_changed_by_fk_ab_user", + "semantic_layers", + "ab_user", + ["changed_by_fk"], + ["id"], + ) + + # Create semantic_views table + create_table( + "semantic_views", + sa.Column("uuid", UUIDType(binary=True), default=uuid.uuid4, nullable=False), + sa.Column("created_on", sa.DateTime(), nullable=True), + sa.Column("changed_on", sa.DateTime(), nullable=True), + sa.Column("name", sa.String(length=250), nullable=False), + sa.Column("description", sa.Text(), nullable=True), + sa.Column( + "configuration", + encrypted_field_factory.create(JSONType), + nullable=True, + ), + sa.Column("cache_timeout", sa.Integer(), nullable=True), + sa.Column( + "semantic_layer_uuid", + UUIDType(binary=True), + sa.ForeignKey("semantic_layers.uuid", ondelete="CASCADE"), + nullable=False, + ), + sa.Column("created_by_fk", sa.Integer(), nullable=True), + sa.Column("changed_by_fk", sa.Integer(), nullable=True), + sa.PrimaryKeyConstraint("uuid"), + ) + + # Create foreign key constraints for semantic_views + create_fks_for_table( + "fk_semantic_views_created_by_fk_ab_user", + "semantic_views", + "ab_user", + ["created_by_fk"], + ["id"], + ) + + create_fks_for_table( + "fk_semantic_views_changed_by_fk_ab_user", + "semantic_views", + "ab_user", + ["changed_by_fk"], + ["id"], + ) + + +def downgrade(): + drop_table("semantic_views") + drop_table("semantic_layers") diff --git a/superset/models/sql_lab.py b/superset/models/sql_lab.py index 956d33053bc..e53c3c62687 100644 --- a/superset/models/sql_lab.py +++ b/superset/models/sql_lab.py @@ -22,7 +22,7 @@ import logging import re from collections.abc import Hashable from datetime import datetime -from typing import Any, Optional, TYPE_CHECKING +from typing import Any, cast, Optional, TYPE_CHECKING import sqlalchemy as sqla from flask import current_app as app @@ -64,7 +64,7 @@ from superset.sql.parse import ( Table, ) from superset.sqllab.limiting_factor import LimitingFactor -from superset.superset_typing import ExplorableData, QueryObjectDict +from 
superset.superset_typing import DatasetColumnData, ExplorableData, QueryObjectDict from superset.utils import json from superset.utils.core import ( get_column_name, @@ -258,7 +258,7 @@ class Query( ], "filter_select": True, "name": self.tab_name, - "columns": [o.data for o in self.columns], + "columns": cast(list[DatasetColumnData], [o.data for o in self.columns]), "metrics": [], "id": self.id, "type": self.type, diff --git a/superset/semantic_layers/__init__.py b/superset/semantic_layers/__init__.py new file mode 100644 index 00000000000..13a83393a91 --- /dev/null +++ b/superset/semantic_layers/__init__.py @@ -0,0 +1,16 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. diff --git a/superset/semantic_layers/mapper.py b/superset/semantic_layers/mapper.py new file mode 100644 index 00000000000..31d4e32b1e4 --- /dev/null +++ b/superset/semantic_layers/mapper.py @@ -0,0 +1,944 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Functions for mapping `QueryObject` to semantic layers. + +These functions validate and convert a `QueryObject` into one or more `SemanticQuery`, +which are then passed to semantic layer implementations for execution, returning a +single dataframe. 
+ +""" + +from datetime import datetime, timedelta +from time import time +from typing import Any, cast, Sequence, TYPE_CHECKING, TypeGuard + +if TYPE_CHECKING: + from superset.superset_typing import Column + +import numpy as np + +from superset.common.db_query_status import QueryStatus +from superset.common.query_object import QueryObject +from superset.common.utils.time_range_utils import get_since_until_from_query_object +from superset.connectors.sqla.models import BaseDatasource +from superset.models.helpers import QueryResult +from superset.semantic_layers.types import ( + AdhocExpression, + AdhocFilter, + Day, + Dimension, + Filter, + FilterValues, + Grain, + GroupLimit, + Hour, + Metric, + Minute, + Month, + Operator, + OrderDirection, + OrderTuple, + PredicateType, + Quarter, + Second, + SemanticQuery, + SemanticResult, + SemanticViewFeature, + Week, + Year, +) +from superset.utils.core import ( + FilterOperator, + QueryObjectFilterClause, + TIME_COMPARISON, +) +from superset.utils.date_parser import get_past_or_future + + +class ValidatedQueryObjectFilterClause(QueryObjectFilterClause): + """ + A validated QueryObject filter clause with a string column name. + + The `col` in a `QueryObjectFilterClause` can be either a string (column name) or an + adhoc column, but we only support the former in semantic layers. + """ + + # overwrite to narrow type; mypy complains about more restrictive typed dicts, + # but the alternative would be to redefine the object + col: str # type: ignore[misc] + op: str # type: ignore[misc] + + +class ValidatedQueryObject(QueryObject): + """ + A query object that has a datasource defined. + """ + + datasource: BaseDatasource + + # overwrite to narrow type; mypy complains about the assignment since the base type + # allows adhoc filters, but we only support validated filters here + filter: list[ValidatedQueryObjectFilterClause] # type: ignore[assignment] + series_columns: Sequence[str] # type: ignore[assignment] + series_limit_metric: str | None + + +def get_results(query_object: QueryObject) -> QueryResult: + """ + Run 1+ queries based on `QueryObject` and return the results. + + :param query_object: The QueryObject containing query specifications + :return: QueryResult compatible with Superset's query interface + """ + if not validate_query_object(query_object): + raise ValueError("QueryObject must have a datasource defined.") + + # Track execution time + start_time = time() + + semantic_view = query_object.datasource.implementation + dispatcher = ( + semantic_view.get_row_count + if query_object.is_rowcount + else semantic_view.get_dataframe + ) + + # Step 1: Convert QueryObject to list of SemanticQuery objects + # The first query is the main query, subsequent queries are for time offsets + queries = map_query_object(query_object) + + # Step 2: Execute the main query (first in the list) + main_query = queries[0] + main_result = dispatcher( + metrics=main_query.metrics, + dimensions=main_query.dimensions, + filters=main_query.filters, + order=main_query.order, + limit=main_query.limit, + offset=main_query.offset, + group_limit=main_query.group_limit, + ) + + main_df = main_result.results + + # Collect all requests (SQL queries, HTTP requests, etc.) 
for troubleshooting + all_requests = list(main_result.requests) + + # If no time offsets, return the main result as-is + if not query_object.time_offsets or len(queries) <= 1: + semantic_result = SemanticResult( + requests=all_requests, + results=main_df, + ) + duration = timedelta(seconds=time() - start_time) + return map_semantic_result_to_query_result( + semantic_result, + query_object, + duration, + ) + + # Get metric names from the main query + # These are the columns that will be renamed with offset suffixes + metric_names = [metric.name for metric in main_query.metrics] + + # Join keys are all columns except metrics + # These will be used to match rows between main and offset DataFrames + join_keys = [col for col in main_df.columns if col not in metric_names] + + # Step 3 & 4: Execute each time offset query and join results + for offset_query, time_offset in zip( + queries[1:], + query_object.time_offsets, + strict=False, + ): + # Execute the offset query + result = dispatcher( + metrics=offset_query.metrics, + dimensions=offset_query.dimensions, + filters=offset_query.filters, + order=offset_query.order, + limit=offset_query.limit, + offset=offset_query.offset, + group_limit=offset_query.group_limit, + ) + + # Add this query's requests to the collection + all_requests.extend(result.requests) + + offset_df = result.results + + # Handle empty results - add NaN columns directly instead of merging + # This avoids dtype mismatch issues with empty DataFrames + if offset_df.empty: + # Add offset metric columns with NaN values directly to main_df + for metric in metric_names: + offset_col_name = TIME_COMPARISON.join([metric, time_offset]) + main_df[offset_col_name] = np.nan + else: + # Rename metric columns with time offset suffix + # Format: "{metric_name}__{time_offset}" + # Example: "revenue" -> "revenue__1 week ago" + offset_df = offset_df.rename( + columns={ + metric: TIME_COMPARISON.join([metric, time_offset]) + for metric in metric_names + } + ) + + # Step 5: Perform left join on dimension columns + # This preserves all rows from main_df and adds offset metrics + # where they match + main_df = main_df.merge( + offset_df, + on=join_keys, + how="left", + suffixes=("", "__duplicate"), + ) + + # Clean up any duplicate columns that might have been created + # (shouldn't happen with proper join keys, but defensive programming) + duplicate_cols = [ + col for col in main_df.columns if col.endswith("__duplicate") + ] + if duplicate_cols: + main_df = main_df.drop(columns=duplicate_cols) + + # Convert final result to QueryResult + semantic_result = SemanticResult(requests=all_requests, results=main_df) + duration = timedelta(seconds=time() - start_time) + return map_semantic_result_to_query_result( + semantic_result, + query_object, + duration, + ) + + +def map_semantic_result_to_query_result( + semantic_result: SemanticResult, + query_object: ValidatedQueryObject, + duration: timedelta, +) -> QueryResult: + """ + Convert a SemanticResult to a QueryResult. 
+ + :param semantic_result: Result from the semantic layer + :param query_object: Original QueryObject (for passthrough attributes) + :param duration: Time taken to execute the query + :return: QueryResult compatible with Superset's query interface + """ + # Get the query string from requests (typically one or more SQL queries) + query_str = "" + if semantic_result.requests: + # Join all requests for display (could be multiple for time comparisons) + query_str = "\n\n".join( + f"-- {req.type}\n{req.definition}" for req in semantic_result.requests + ) + + return QueryResult( + # Core data + df=semantic_result.results, + query=query_str, + duration=duration, + # Template filters - not applicable to semantic layers + # (semantic layers don't use Jinja templates) + applied_template_filters=None, + # Filter columns - not applicable to semantic layers + # (semantic layers handle filter validation internally) + applied_filter_columns=None, + rejected_filter_columns=None, + # Status - always success if we got here + # (errors would raise exceptions before reaching this point) + status=QueryStatus.SUCCESS, + error_message=None, + errors=None, + # Time range - pass through from original query_object + from_dttm=query_object.from_dttm, + to_dttm=query_object.to_dttm, + ) + + +def _normalize_column(column: "Column", dimension_names: set[str]) -> str: + """ + Normalize a column to its dimension name. + + Columns can be either: + - A string (dimension name directly) + - A dict with isColumnReference=True and sqlExpression containing the dimension name + """ + if isinstance(column, str): + return column + + if isinstance(column, dict): + # Handle column references (e.g., from time-series charts) + if column.get("isColumnReference") and column.get("sqlExpression"): + sql_expr = column["sqlExpression"] + if sql_expr in dimension_names: + return sql_expr + + raise ValueError("Adhoc dimensions are not supported in Semantic Views.") + + +def map_query_object(query_object: ValidatedQueryObject) -> list[SemanticQuery]: + """ + Convert a `QueryObject` into a list of `SemanticQuery`. + + This function maps the `QueryObject` into query objects that focus less on + visualization and more on semantics. 
+ """ + semantic_view = query_object.datasource.implementation + + all_metrics = {metric.name: metric for metric in semantic_view.metrics} + all_dimensions = { + dimension.name: dimension for dimension in semantic_view.dimensions + } + + # Normalize columns (may be dicts with isColumnReference=True for time-series) + dimension_names = set(all_dimensions.keys()) + normalized_columns = { + _normalize_column(column, dimension_names) for column in query_object.columns + } + + metrics = [all_metrics[metric] for metric in (query_object.metrics or [])] + + grain = ( + _convert_time_grain(query_object.extras["time_grain_sqla"]) + if "time_grain_sqla" in query_object.extras + else None + ) + dimensions = [ + dimension + for dimension in semantic_view.dimensions + if dimension.name in normalized_columns + and ( + # if a grain is specified, only include the time dimension if its grain + # matches the requested grain + grain is None + or dimension.name != query_object.granularity + or dimension.grain == grain + ) + ] + + order = _get_order_from_query_object(query_object, all_metrics, all_dimensions) + limit = query_object.row_limit + offset = query_object.row_offset + + group_limit = _get_group_limit_from_query_object( + query_object, + all_metrics, + all_dimensions, + ) + + queries = [] + for time_offset in [None] + query_object.time_offsets: + filters = _get_filters_from_query_object( + query_object, + time_offset, + all_dimensions, + ) + print(">>", filters) + + queries.append( + SemanticQuery( + metrics=metrics, + dimensions=dimensions, + filters=filters, + order=order, + limit=limit, + offset=offset, + group_limit=group_limit, + ) + ) + + return queries + + +def _get_filters_from_query_object( + query_object: ValidatedQueryObject, + time_offset: str | None, + all_dimensions: dict[str, Dimension], +) -> set[Filter | AdhocFilter]: + """ + Extract all filters from the query object, including time range filters. + + This simplifies the complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm + by converting all time constraints into filters. + """ + filters: set[Filter | AdhocFilter] = set() + + # 1. Add fetch values predicate if present + if ( + query_object.apply_fetch_values_predicate + and query_object.datasource.fetch_values_predicate + ): + filters.add( + AdhocFilter( + type=PredicateType.WHERE, + definition=query_object.datasource.fetch_values_predicate, + ) + ) + + # 2. Add time range filter based on from_dttm/to_dttm + # For time offsets, this automatically calculates the shifted bounds + time_filters = _get_time_filter(query_object, time_offset, all_dimensions) + filters.update(time_filters) + + # 3. Add filters from query_object.extras (WHERE and HAVING clauses) + extras_filters = _get_filters_from_extras(query_object.extras) + filters.update(extras_filters) + + # 4. Add all other filters from query_object.filter + for filter_ in query_object.filter: + # Skip temporal range filters - we're using inner bounds instead + if ( + filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value + and query_object.granularity + ): + continue + + if converted_filters := _convert_query_object_filter(filter_, all_dimensions): + filters.update(converted_filters) + + return filters + + +def _get_filters_from_extras(extras: dict[str, Any]) -> set[AdhocFilter]: + """ + Extract filters from the extras dict. 
+ + The extras dict can contain various keys that affect query behavior: + + Supported keys (converted to filters): + - "where": SQL WHERE clause expression (e.g., "customer_id > 100") + - "having": SQL HAVING clause expression (e.g., "SUM(sales) > 1000") + + Other keys in extras (handled elsewhere in the mapper): + - "time_grain_sqla": Time granularity (e.g., "P1D", "PT1H") + Handled in _convert_time_grain() and used for dimension grain matching + + Note: The WHERE and HAVING clauses from extras are SQL expressions that + are passed through as-is to the semantic layer as AdhocFilter objects. + """ + filters: set[AdhocFilter] = set() + + # Add WHERE clause from extras + if where_clause := extras.get("where"): + filters.add( + AdhocFilter( + type=PredicateType.WHERE, + definition=where_clause, + ) + ) + + # Add HAVING clause from extras + if having_clause := extras.get("having"): + filters.add( + AdhocFilter( + type=PredicateType.HAVING, + definition=having_clause, + ) + ) + + return filters + + +def _get_time_filter( + query_object: ValidatedQueryObject, + time_offset: str | None, + all_dimensions: dict[str, Dimension], +) -> set[Filter]: + """ + Create a time range filter from the query object. + + This handles both regular queries and time offset queries, simplifying the + complexity of from_dttm/to_dttm/inner_from_dttm/inner_to_dttm by using the + same time bounds for both the main query and series limit subqueries. + """ + filters: set[Filter] = set() + + if not query_object.granularity: + return filters + + time_dimension = all_dimensions.get(query_object.granularity) + if not time_dimension: + return filters + + # Get the appropriate time bounds based on whether this is a time offset query + from_dttm, to_dttm = _get_time_bounds(query_object, time_offset) + + if not from_dttm or not to_dttm: + return filters + + # Create a filter with >= and < operators + return { + Filter( + type=PredicateType.WHERE, + column=time_dimension, + operator=Operator.GREATER_THAN_OR_EQUAL, + value=from_dttm, + ), + Filter( + type=PredicateType.WHERE, + column=time_dimension, + operator=Operator.LESS_THAN, + value=to_dttm, + ), + } + + +def _get_time_bounds( + query_object: ValidatedQueryObject, + time_offset: str | None, +) -> tuple[datetime | None, datetime | None]: + """ + Get the appropriate time bounds for the query. + + For regular queries (time_offset is None), returns from_dttm/to_dttm. + For time offset queries, calculates the shifted bounds. + + This simplifies the inner_from_dttm/inner_to_dttm complexity by using + the same bounds for both main queries and series limit subqueries (Option 1). 
+ """ + if time_offset is None: + # Main query: use from_dttm/to_dttm directly + return query_object.from_dttm, query_object.to_dttm + + # Time offset query: calculate shifted bounds + # Use from_dttm/to_dttm if available, otherwise try to get from time_range + outer_from = query_object.from_dttm + outer_to = query_object.to_dttm + + if not outer_from or not outer_to: + # Fall back to parsing time_range if from_dttm/to_dttm not set + outer_from, outer_to = get_since_until_from_query_object(query_object) + + if not outer_from or not outer_to: + return None, None + + # Apply the offset to both bounds + offset_from = get_past_or_future(time_offset, outer_from) + offset_to = get_past_or_future(time_offset, outer_to) + + return offset_from, offset_to + + +def _convert_query_object_filter( + filter_: ValidatedQueryObjectFilterClause, + all_dimensions: dict[str, Dimension], +) -> set[Filter] | None: + """ + Convert a QueryObject filter dict to a semantic layer Filter or AdhocFilter. + """ + operator_str = filter_["op"] + + # Handle simple column filters + col = filter_.get("col") + if col not in all_dimensions: + return None + + dimension = all_dimensions[col] + + val_str = filter_["val"] + value: FilterValues | set[FilterValues] + if val_str is None: + value = None + elif isinstance(val_str, (list, tuple)): + value = set(val_str) + else: + value = val_str + + # Special case for temporal range + if operator_str == FilterOperator.TEMPORAL_RANGE.value: + if not isinstance(value, str): + raise ValueError( + f"Expected string value for temporal range, got {type(value)}" + ) + start, end = value.split(" : ") + return { + Filter( + type=PredicateType.WHERE, + column=dimension, + operator=Operator.GREATER_THAN_OR_EQUAL, + value=start, + ), + Filter( + type=PredicateType.WHERE, + column=dimension, + operator=Operator.LESS_THAN, + value=end, + ), + } + + # Map QueryObject operators to semantic layer operators + operator_mapping = { + FilterOperator.EQUALS.value: Operator.EQUALS, + FilterOperator.NOT_EQUALS.value: Operator.NOT_EQUALS, + FilterOperator.GREATER_THAN.value: Operator.GREATER_THAN, + FilterOperator.LESS_THAN.value: Operator.LESS_THAN, + FilterOperator.GREATER_THAN_OR_EQUALS.value: Operator.GREATER_THAN_OR_EQUAL, + FilterOperator.LESS_THAN_OR_EQUALS.value: Operator.LESS_THAN_OR_EQUAL, + FilterOperator.IN.value: Operator.IN, + FilterOperator.NOT_IN.value: Operator.NOT_IN, + FilterOperator.LIKE.value: Operator.LIKE, + FilterOperator.NOT_LIKE.value: Operator.NOT_LIKE, + FilterOperator.IS_NULL.value: Operator.IS_NULL, + FilterOperator.IS_NOT_NULL.value: Operator.IS_NOT_NULL, + } + + operator = operator_mapping.get(operator_str) + if not operator: + # Unknown operator - create adhoc filter + return None + + return { + Filter( + type=PredicateType.WHERE, + column=dimension, + operator=operator, + value=value, + ) + } + + +def _get_order_from_query_object( + query_object: ValidatedQueryObject, + all_metrics: dict[str, Metric], + all_dimensions: dict[str, Dimension], +) -> list[OrderTuple]: + order: list[OrderTuple] = [] + for element, ascending in query_object.orderby: + direction = OrderDirection.ASC if ascending else OrderDirection.DESC + + # adhoc + if isinstance(element, dict): + if element["sqlExpression"] is not None: + order.append( + ( + AdhocExpression( + id=element["label"] or element["sqlExpression"], + definition=element["sqlExpression"], + ), + direction, + ) + ) + elif element in all_dimensions: + order.append((all_dimensions[element], direction)) + elif element in all_metrics: + 
order.append((all_metrics[element], direction)) + + return order + + +def _get_group_limit_from_query_object( + query_object: ValidatedQueryObject, + all_metrics: dict[str, Metric], + all_dimensions: dict[str, Dimension], +) -> GroupLimit | None: + # no limit + if query_object.series_limit == 0 or not query_object.columns: + return None + + dimensions = [all_dimensions[dim_id] for dim_id in query_object.series_columns] + top = query_object.series_limit + metric = ( + all_metrics[query_object.series_limit_metric] + if query_object.series_limit_metric + else None + ) + direction = OrderDirection.DESC if query_object.order_desc else OrderDirection.ASC + group_others = query_object.group_others_when_limit_reached + + # Check if we need separate filters for the group limit subquery + # This happens when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm + group_limit_filters = _get_group_limit_filters(query_object, all_dimensions) + + return GroupLimit( + dimensions=dimensions, + top=top, + metric=metric, + direction=direction, + group_others=group_others, + filters=group_limit_filters, + ) + + +def _get_group_limit_filters( + query_object: ValidatedQueryObject, + all_dimensions: dict[str, Dimension], +) -> set[Filter | AdhocFilter] | None: + """ + Get separate filters for the group limit subquery if needed. + + This is used when inner_from_dttm/inner_to_dttm differ from from_dttm/to_dttm, + which happens during time comparison queries. The group limit subquery may need + different time bounds to determine the top N groups. + + Returns None if the group limit should use the same filters as the main query. + """ + # Check if inner time bounds are explicitly set and differ from outer bounds + if ( + query_object.inner_from_dttm is None + or query_object.inner_to_dttm is None + or ( + query_object.inner_from_dttm == query_object.from_dttm + and query_object.inner_to_dttm == query_object.to_dttm + ) + ): + # No separate bounds needed - use the same filters as the main query + return None + + # Create separate filters for the group limit subquery + filters: set[Filter | AdhocFilter] = set() + + # Add time range filter using inner bounds + if query_object.granularity: + time_dimension = all_dimensions.get(query_object.granularity) + if ( + time_dimension + and query_object.inner_from_dttm + and query_object.inner_to_dttm + ): + filters.update( + { + Filter( + type=PredicateType.WHERE, + column=time_dimension, + operator=Operator.GREATER_THAN_OR_EQUAL, + value=query_object.inner_from_dttm, + ), + Filter( + type=PredicateType.WHERE, + column=time_dimension, + operator=Operator.LESS_THAN, + value=query_object.inner_to_dttm, + ), + } + ) + + # Add fetch values predicate if present + if ( + query_object.apply_fetch_values_predicate + and query_object.datasource.fetch_values_predicate + ): + filters.add( + AdhocFilter( + type=PredicateType.WHERE, + definition=query_object.datasource.fetch_values_predicate, + ) + ) + + # Add filters from query_object.extras (WHERE and HAVING clauses) + extras_filters = _get_filters_from_extras(query_object.extras) + filters.update(extras_filters) + + # Add all other non-temporal filters from query_object.filter + for filter_ in query_object.filter: + # Skip temporal range filters - we're using inner bounds instead + if ( + filter_.get("op") == FilterOperator.TEMPORAL_RANGE.value + and query_object.granularity + ): + continue + + if converted_filters := _convert_query_object_filter(filter_, all_dimensions): + filters.update(converted_filters) + + return filters if 
filters else None + + +def _convert_time_grain(time_grain: str) -> type[Grain] | None: + """ + Convert a time grain string from the query object to a Grain class. + """ + mapping = { + grain.representation: grain + for grain in [ + Second, + Minute, + Hour, + Day, + Week, + Month, + Quarter, + Year, + ] + } + + return mapping.get(time_grain) + + +def validate_query_object( + query_object: QueryObject, +) -> TypeGuard[ValidatedQueryObject]: + """ + Validate that the `QueryObject` is compatible with the `SemanticView`. + + If some semantic view implementation supports these features we should add an + attribute to the `SemanticViewImplementation` to indicate support for them. + """ + if not query_object.datasource: + return False + + query_object = cast(ValidatedQueryObject, query_object) + + _validate_metrics(query_object) + _validate_dimensions(query_object) + _validate_filters(query_object) + _validate_granularity(query_object) + _validate_group_limit(query_object) + _validate_orderby(query_object) + + return True + + +def _validate_metrics(query_object: ValidatedQueryObject) -> None: + """ + Make sure metrics are defined in the semantic view. + """ + semantic_view = query_object.datasource.implementation + + if any(not isinstance(metric, str) for metric in (query_object.metrics or [])): + raise ValueError("Adhoc metrics are not supported in Semantic Views.") + + metric_names = {metric.name for metric in semantic_view.metrics} + if not set(query_object.metrics or []) <= metric_names: + raise ValueError("All metrics must be defined in the Semantic View.") + + +def _validate_dimensions(query_object: ValidatedQueryObject) -> None: + """ + Make sure all dimensions are defined in the semantic view. + """ + semantic_view = query_object.datasource.implementation + dimension_names = {dimension.name for dimension in semantic_view.dimensions} + + # Normalize all columns to dimension names + normalized_columns = [ + _normalize_column(column, dimension_names) for column in query_object.columns + ] + + if not set(normalized_columns) <= dimension_names: + raise ValueError("All dimensions must be defined in the Semantic View.") + + +def _validate_filters(query_object: ValidatedQueryObject) -> None: + """ + Make sure all filters are valid. + """ + for filter_ in query_object.filter: + if isinstance(filter_["col"], dict): + raise ValueError( + "Adhoc columns are not supported in Semantic View filters." + ) + if not filter_.get("op"): + raise ValueError("All filters must have an operator defined.") + + +def _validate_granularity(query_object: ValidatedQueryObject) -> None: + """ + Make sure time column and time grain are valid. + """ + semantic_view = query_object.datasource.implementation + dimension_names = {dimension.name for dimension in semantic_view.dimensions} + + if time_column := query_object.granularity: + if time_column not in dimension_names: + raise ValueError( + "The time column must be defined in the Semantic View dimensions." + ) + + if time_grain := query_object.extras.get("time_grain_sqla"): + if not time_column: + raise ValueError( + "A time column must be specified when a time grain is provided." + ) + + supported_time_grains = { + dimension.grain + for dimension in semantic_view.dimensions + if dimension.name == time_column and dimension.grain + } + if _convert_time_grain(time_grain) not in supported_time_grains: + raise ValueError( + "The time grain is not supported for the time column in the " + "Semantic View." 
+ ) + + +def _validate_group_limit(query_object: ValidatedQueryObject) -> None: + """ + Validate group limit related features in the query object. + """ + semantic_view = query_object.datasource.implementation + + # no limit + if query_object.series_limit == 0: + return + + if ( + query_object.series_columns + and SemanticViewFeature.GROUP_LIMIT not in semantic_view.features + ): + raise ValueError("Group limit is not supported in this Semantic View.") + + if any(not isinstance(col, str) for col in query_object.series_columns): + raise ValueError("Adhoc dimensions are not supported in series columns.") + + metric_names = {metric.name for metric in semantic_view.metrics} + if query_object.series_limit_metric and ( + not isinstance(query_object.series_limit_metric, str) + or query_object.series_limit_metric not in metric_names + ): + raise ValueError( + "The series limit metric must be defined in the Semantic View." + ) + + dimension_names = {dimension.name for dimension in semantic_view.dimensions} + if not set(query_object.series_columns) <= dimension_names: + raise ValueError("All series columns must be defined in the Semantic View.") + + if ( + query_object.group_others_when_limit_reached + and SemanticViewFeature.GROUP_OTHERS not in semantic_view.features + ): + raise ValueError( + "Grouping others when limit is reached is not supported in this Semantic " + "View." + ) + + +def _validate_orderby(query_object: ValidatedQueryObject) -> None: + """ + Validate order by elements in the query object. + """ + semantic_view = query_object.datasource.implementation + + if ( + any(not isinstance(element, str) for element, _ in query_object.orderby) + and SemanticViewFeature.ADHOC_EXPRESSIONS_IN_ORDERBY + not in semantic_view.features + ): + raise ValueError( + "Adhoc expressions in order by are not supported in this Semantic View." + ) + + elements = {orderby[0] for orderby in query_object.orderby} + metric_names = {metric.name for metric in semantic_view.metrics} + dimension_names = {dimension.name for dimension in semantic_view.dimensions} + if not elements <= metric_names | dimension_names: + raise ValueError("All order by elements must be defined in the Semantic View.") diff --git a/superset/semantic_layers/models.py b/superset/semantic_layers/models.py new file mode 100644 index 00000000000..5615d0f30c5 --- /dev/null +++ b/superset/semantic_layers/models.py @@ -0,0 +1,373 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +"""Semantic layer models.""" + +from __future__ import annotations + +import uuid +from collections.abc import Hashable +from dataclasses import dataclass +from functools import cached_property +from typing import Any, TYPE_CHECKING + +from flask_appbuilder import Model +from sqlalchemy import Column, ForeignKey, Integer, String, Text +from sqlalchemy.orm import relationship +from sqlalchemy_utils import UUIDType +from sqlalchemy_utils.types.json import JSONType + +from superset.common.query_object import QueryObject +from superset.explorables.base import TimeGrainDict +from superset.extensions import encrypted_field_factory +from superset.models.helpers import AuditMixinNullable, QueryResult +from superset.semantic_layers.mapper import get_results +from superset.semantic_layers.registry import get_semantic_layer +from superset.semantic_layers.types import ( + BINARY, + BOOLEAN, + DATE, + DATETIME, + DECIMAL, + INTEGER, + INTERVAL, + NUMBER, + OBJECT, + SemanticLayerImplementation, + SemanticViewImplementation, + STRING, + TIME, + Type, +) +from superset.utils import json +from superset.utils.core import GenericDataType + +if TYPE_CHECKING: + from superset.superset_typing import ExplorableData, QueryObjectDict + + +def get_column_type(semantic_type: type[Type]) -> GenericDataType: + """ + Map semantic layer types to generic data types. + """ + if semantic_type in {DATE, DATETIME, TIME}: + return GenericDataType.TEMPORAL + if semantic_type in {INTEGER, NUMBER, DECIMAL, INTERVAL}: + return GenericDataType.NUMERIC + if semantic_type is BOOLEAN: + return GenericDataType.BOOLEAN + if semantic_type in {STRING, OBJECT, BINARY}: + return GenericDataType.STRING + return GenericDataType.STRING + + +@dataclass(frozen=True) +class MetricMetadata: + metric_name: str + expression: str + verbose_name: str | None = None + description: str | None = None + d3format: str | None = None + currency: dict[str, Any] | None = None + warning_text: str | None = None + certified_by: str | None = None + certification_details: str | None = None + + +@dataclass(frozen=True) +class ColumnMetadata: + column_name: str + type: str + is_dttm: bool + verbose_name: str | None = None + description: str | None = None + groupby: bool = True + filterable: bool = True + expression: str | None = None + python_date_format: str | None = None + advanced_data_type: str | None = None + extra: str | None = None + + +class SemanticLayer(AuditMixinNullable, Model): + """ + Semantic layer model. + + A semantic layer provides an abstraction over data sources, + allowing users to query data through a semantic interface. + """ + + __tablename__ = "semantic_layers" + + uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4) + + # Core fields + name = Column(String(250), nullable=False) + description = Column(Text, nullable=True) + type = Column(String(250), nullable=False) # snowflake, etc + + configuration = Column(encrypted_field_factory.create(JSONType), default=dict) + cache_timeout = Column(Integer, nullable=True) + + # Semantic views relationship + semantic_views: list[SemanticView] = relationship( + "SemanticView", + back_populates="semantic_layer", + cascade="all, delete-orphan", + passive_deletes=True, + ) + + def __repr__(self) -> str: + return self.name or str(self.uuid) + + @cached_property + def implementation( + self, + ) -> SemanticLayerImplementation[Any, SemanticViewImplementation]: + """ + Return semantic layer implementation. 
+ """ + implementation_class = get_semantic_layer(self.type) + + if not issubclass(implementation_class, SemanticLayerImplementation): + raise TypeError( + f"Semantic layer type '{self.type}' " + "must be a subclass of SemanticLayerImplementation" + ) + + return implementation_class.from_configuration(json.loads(self.configuration)) + + +class SemanticView(AuditMixinNullable, Model): + """ + Semantic view model. + + A semantic view represents a queryable view within a semantic layer. + """ + + __tablename__ = "semantic_views" + + uuid = Column(UUIDType(binary=True), primary_key=True, default=uuid.uuid4) + + # Core fields + name = Column(String(250), nullable=False) + description = Column(Text, nullable=True) + + configuration = Column(encrypted_field_factory.create(JSONType), default=dict) + cache_timeout = Column(Integer, nullable=True) + + # Semantic layer relationship + semantic_layer_uuid = Column( + UUIDType(binary=True), + ForeignKey("semantic_layers.uuid", ondelete="CASCADE"), + nullable=False, + ) + semantic_layer: SemanticLayer = relationship( + "SemanticLayer", + back_populates="semantic_views", + foreign_keys=[semantic_layer_uuid], + ) + + def __repr__(self) -> str: + return self.name or str(self.uuid) + + @cached_property + def implementation(self) -> SemanticViewImplementation: + """ + Return semantic view implementation. + """ + return self.semantic_layer.implementation.get_semantic_view( + self.name, + json.loads(self.configuration), + ) + + # ========================================================================= + # Explorable protocol implementation + # ========================================================================= + + def get_query_result(self, query_object: QueryObject) -> QueryResult: + return get_results(query_object) + + def get_query_str(self, query_obj: QueryObjectDict) -> str: + return "Not implemented for semantic layers" + + @property + def uid(self) -> str: + return self.implementation.uid() + + @property + def type(self) -> str: + return "semantic_view" + + @property + def metrics(self) -> list[MetricMetadata]: + return [ + MetricMetadata( + metric_name=metric.name, + expression=metric.definition or "", + description=metric.description, + ) + for metric in self.implementation.get_metrics() + ] + + @property + def columns(self) -> list[ColumnMetadata]: + return [ + ColumnMetadata( + column_name=dimension.name, + type=dimension.type.__name__, + is_dttm=dimension.type in {DATE, TIME, DATETIME}, + description=dimension.description, + expression=dimension.definition, + extra=json.dumps({"grain": dimension.grain}), + ) + for dimension in self.implementation.get_dimensions() + ] + + @property + def column_names(self) -> list[str]: + return [dimension.name for dimension in self.implementation.get_dimensions()] + + @property + def data(self) -> ExplorableData: + return { + # core + "id": self.uuid.hex, + "uid": self.uid, + "type": "semantic_view", + "name": self.name, + "columns": [ + { + "advanced_data_type": None, + "certification_details": None, + "certified_by": None, + "column_name": dimension.name, + "description": dimension.description, + "expression": dimension.definition, + "filterable": True, + "groupby": True, + "id": None, + "uuid": None, + "is_certified": False, + "is_dttm": dimension.type in {DATE, TIME, DATETIME}, + "python_date_format": None, + "type": dimension.type.__name__, + "type_generic": get_column_type(dimension.type), + "verbose_name": None, + "warning_markdown": None, + } + for dimension in self.implementation.get_dimensions() + 
], + "metrics": [ + { + "certification_details": None, + "certified_by": None, + "d3format": None, + "description": metric.description, + "expression": metric.definition, + "id": None, + "uuid": None, + "is_certified": False, + "metric_name": metric.name, + "warning_markdown": None, + "warning_text": None, + "verbose_name": None, + } + for metric in self.implementation.get_metrics() + ], + "database": {}, + # UI features + "verbose_map": {}, + "order_by_choices": [], + "filter_select": True, + "filter_select_enabled": True, + "sql": None, + "select_star": None, + "owners": [], + "description": self.description, + "table_name": self.name, + "column_types": [ + get_column_type(dimension.type) + for dimension in self.implementation.get_dimensions() + ], + "column_names": [ + dimension.name for dimension in self.implementation.get_dimensions() + ], + # rare + "column_formats": {}, + "datasource_name": self.name, + "perm": self.perm, + "offset": None, + "cache_timeout": self.cache_timeout, + "params": None, + # sql-specific + "schema": None, + "catalog": None, + "main_dttm_col": None, + "time_grain_sqla": [], + "granularity_sqla": [], + "fetch_values_predicate": None, + "template_params": None, + "is_sqllab_view": False, + "extra": None, + "always_filter_main_dttm": False, + "normalize_columns": False, + # TODO XXX + # "owners": [owner.id for owner in self.owners], + "edit_url": "", + "default_endpoint": None, + "folders": [], + "health_check_message": None, + } + + def get_extra_cache_keys(self, query_obj: QueryObjectDict) -> list[Hashable]: + return [] + + @property + def perm(self) -> str: + return self.semantic_layer_uuid.hex + "::" + self.uuid.hex + + @property + def offset(self) -> int: + # always return datetime as UTC + return 0 + + @property + def get_time_grains(self) -> list[TimeGrainDict]: + return [ + { + "name": dimension.grain.name, + "function": "", + "duration": dimension.grain.representation, + } + for dimension in self.implementation.get_dimensions() + if dimension.grain + ] + + def has_drill_by_columns(self, column_names: list[str]) -> bool: + dimension_names = { + dimension.name for dimension in self.implementation.get_dimensions() + } + return all(column_name in dimension_names for column_name in column_names) + + @property + def is_rls_supported(self) -> bool: + return False + + @property + def query_language(self) -> str | None: + return None diff --git a/superset/semantic_layers/registry.py b/superset/semantic_layers/registry.py new file mode 100644 index 00000000000..182a7575ff6 --- /dev/null +++ b/superset/semantic_layers/registry.py @@ -0,0 +1,132 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +""" +Semantic layer registry. + +This module provides a registry for semantic layer implementations that can be +populated from: +1. 
Standard Python entry points (for pip-installed packages) +2. Superset extensions (for .supx bundles) +""" + +from __future__ import annotations + +import logging +from importlib.metadata import entry_points +from typing import Any, TYPE_CHECKING + +if TYPE_CHECKING: + from superset.semantic_layers.types import SemanticLayerImplementation + +logger = logging.getLogger(__name__) + +ENTRY_POINT_GROUP = "superset.semantic_layers" + +# Registry mapping semantic layer type names to implementation classes +_semantic_layer_registry: dict[str, type["SemanticLayerImplementation[Any, Any]"]] = {} +_initialized_from_entry_points = False + + +def _init_from_entry_points() -> None: + """ + Pre-populate the registry from installed packages' entry points. + + This is called lazily on first access to ensure all packages are loaded. + """ + global _initialized_from_entry_points + if _initialized_from_entry_points: + return + + for ep in entry_points(group=ENTRY_POINT_GROUP): + if ep.name not in _semantic_layer_registry: + try: + _semantic_layer_registry[ep.name] = ep.load() + logger.info( + "Registered semantic layer '%s' from entry point %s", + ep.name, + ep.value, + ) + except Exception: + logger.exception( + "Failed to load semantic layer '%s' from entry point %s", + ep.name, + ep.value, + ) + + _initialized_from_entry_points = True + + +def register_semantic_layer( + name: str, + cls: "type[SemanticLayerImplementation[Any, Any]]", +) -> None: + """ + Register a semantic layer implementation. + + This is called by extensions to register their semantic layer implementations. + + Args: + name: The type name for the semantic layer (e.g., "snowflake") + cls: The implementation class + """ + if name in _semantic_layer_registry: + logger.warning( + "Semantic layer '%s' already registered, overwriting with %s", + name, + cls, + ) + _semantic_layer_registry[name] = cls + logger.info("Registered semantic layer '%s' from extension: %s", name, cls) + + +def get_semantic_layer(name: str) -> "type[SemanticLayerImplementation[Any, Any]]": + """ + Get a semantic layer implementation by name. + + Args: + name: The type name for the semantic layer (e.g., "snowflake") + + Returns: + The implementation class + + Raises: + KeyError: If no implementation is registered for the given name + """ + _init_from_entry_points() + + if name not in _semantic_layer_registry: + available = ", ".join(_semantic_layer_registry.keys()) or "(none)" + raise KeyError( + f"No semantic layer implementation registered for type '{name}'. " + f"Available types: {available}" + ) + + return _semantic_layer_registry[name] + + +def get_registered_semantic_layers() -> ( + "dict[str, type[SemanticLayerImplementation[Any, Any]]]" +): + """ + Get all registered semantic layer implementations. + + Returns: + A dictionary mapping type names to implementation classes + """ + _init_from_entry_points() + return dict(_semantic_layer_registry) diff --git a/superset/semantic_layers/types.py b/superset/semantic_layers/types.py new file mode 100644 index 00000000000..826ee50666c --- /dev/null +++ b/superset/semantic_layers/types.py @@ -0,0 +1,497 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. 
You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +from __future__ import annotations + +import enum +from dataclasses import dataclass +from datetime import date, datetime, time, timedelta +from functools import total_ordering +from typing import Any, Protocol, runtime_checkable, TypeVar + +from pandas import DataFrame +from pydantic import BaseModel + +__all__ = [ + "BINARY", + "BOOLEAN", + "DATE", + "DATETIME", + "DECIMAL", + "Day", + "Dimension", + "Hour", + "INTEGER", + "INTERVAL", + "Minute", + "Month", + "NUMBER", + "OBJECT", + "Quarter", + "Second", + "STRING", + "TIME", + "Week", + "Year", +] + + +class Type: + """ + Base class for types. + """ + + +class INTEGER(Type): + """ + Represents an integer type. + """ + + +class NUMBER(Type): + """ + Represents a number type. + """ + + +class DECIMAL(Type): + """ + Represents a decimal type. + """ + + +class STRING(Type): + """ + Represents a string type. + """ + + +class BOOLEAN(Type): + """ + Represents a boolean type. + """ + + +class DATE(Type): + """ + Represents a date type. + """ + + +class TIME(Type): + """ + Represents a time type. + """ + + +class DATETIME(DATE, TIME): + """ + Represents a datetime type. + """ + + +class INTERVAL(Type): + """ + Represents an interval type. + """ + + +class OBJECT(Type): + """ + Represents an object type. + """ + + +class BINARY(Type): + """ + Represents a binary type. + """ + + +@dataclass(frozen=True) +@total_ordering +class Grain: + """ + Base class for time and date grains with comparison support. 
+ + Attributes: + name: Human-readable name of the grain (e.g., "Second") + representation: ISO 8601 representation (e.g., "PT1S") + value: Time period as a timedelta + """ + + name: str + representation: str + value: timedelta + + def __eq__(self, other: object) -> bool: + if isinstance(other, Grain): + return self.value == other.value + return NotImplemented + + def __lt__(self, other: object) -> bool: + if isinstance(other, Grain): + return self.value < other.value + return NotImplemented + + def __hash__(self) -> int: + return hash((self.name, self.representation, self.value)) + + +class Second(Grain): + name = "Second" + representation = "PT1S" + value = timedelta(seconds=1) + + +class Minute(Grain): + name = "Minute" + representation = "PT1M" + value = timedelta(minutes=1) + + +class Hour(Grain): + name = "Hour" + representation = "PT1H" + value = timedelta(hours=1) + + +class Day(Grain): + name = "Day" + representation = "P1D" + value = timedelta(days=1) + + +class Week(Grain): + name = "Week" + representation = "P1W" + value = timedelta(weeks=1) + + +class Month(Grain): + name = "Month" + representation = "P1M" + value = timedelta(days=30) + + +class Quarter(Grain): + name = "Quarter" + representation = "P3M" + value = timedelta(days=90) + + +class Year(Grain): + name = "Year" + representation = "P1Y" + value = timedelta(days=365) + + +@dataclass(frozen=True) +class Dimension: + id: str + name: str + type: type[Type] + + definition: str | None = None + description: str | None = None + grain: Grain | None = None + + +@dataclass(frozen=True) +class Metric: + id: str + name: str + type: type[Type] + + definition: str | None + description: str | None = None + + +@dataclass(frozen=True) +class AdhocExpression: + id: str + definition: str + + +class Operator(str, enum.Enum): + EQUALS = "=" + NOT_EQUALS = "!=" + GREATER_THAN = ">" + LESS_THAN = "<" + GREATER_THAN_OR_EQUAL = ">=" + LESS_THAN_OR_EQUAL = "<=" + IN = "IN" + NOT_IN = "NOT IN" + LIKE = "LIKE" + NOT_LIKE = "NOT LIKE" + IS_NULL = "IS NULL" + IS_NOT_NULL = "IS NOT NULL" + + +FilterValues = str | int | float | bool | datetime | date | time | timedelta | None + + +class PredicateType(enum.Enum): + WHERE = "WHERE" + HAVING = "HAVING" + + +@dataclass(frozen=True, order=True) +class Filter: + type: PredicateType + column: Dimension | Metric + operator: Operator + value: FilterValues | set[FilterValues] + + +@dataclass(frozen=True, order=True) +class AdhocFilter: + type: PredicateType + definition: str + + +class OrderDirection(enum.Enum): + ASC = "ASC" + DESC = "DESC" + + +OrderTuple = tuple[Metric | Dimension | AdhocExpression, OrderDirection] + + +@dataclass(frozen=True) +class GroupLimit: + """ + Limit query to top/bottom N combinations of specified dimensions. + + The `filters` parameter allows specifying separate filter constraints for the + group limit subquery. This is useful when you want to determine the top N groups + using different criteria (e.g., a different time range) than the main query. + + For example, you might want to find the top 10 products by sales over the last + 30 days, but then show daily sales for those products over the last 7 days. + """ + + dimensions: list[Dimension] + top: int + metric: Metric | None + direction: OrderDirection = OrderDirection.DESC + group_others: bool = False + filters: set[Filter | AdhocFilter] | None = None + + +@dataclass(frozen=True) +class SemanticRequest: + """ + Represents a request made to obtain semantic results. + + This could be a SQL query, an HTTP request, etc. 
+ """ + + type: str + definition: str + + +@dataclass(frozen=True) +class SemanticResult: + """ + Represents the results of a semantic query. + + This includes any requests (SQL queries, HTTP requests) that were performed in order + to obtain the results, in order to help troubleshooting. + """ + + requests: list[SemanticRequest] + results: DataFrame + + +@dataclass(frozen=True) +class SemanticQuery: + """ + Represents a semantic query. + """ + + metrics: list[Metric] + dimensions: list[Dimension] + filters: set[Filter | AdhocFilter] | None = None + order: list[OrderTuple] | None = None + limit: int | None = None + offset: int | None = None + group_limit: GroupLimit | None = None + + +class SemanticViewFeature(enum.Enum): + """ + Custom features supported by semantic layers. + """ + + ADHOC_EXPRESSIONS_IN_ORDERBY = "ADHOC_EXPRESSIONS_IN_ORDERBY" + GROUP_LIMIT = "GROUP_LIMIT" + GROUP_OTHERS = "GROUP_OTHERS" + + +ConfigT = TypeVar("ConfigT", bound=BaseModel, contravariant=True) +SemanticViewT = TypeVar("SemanticViewT", bound="SemanticViewImplementation") + + +@runtime_checkable +class SemanticLayerImplementation(Protocol[ConfigT, SemanticViewT]): + """ + A protocol for semantic layers. + """ + + @classmethod + def from_configuration( + cls, + configuration: dict[str, Any], + ) -> SemanticLayerImplementation[ConfigT, SemanticViewT]: + """ + Create a semantic layer from its configuration. + """ + + @classmethod + def get_configuration_schema( + cls, + configuration: ConfigT | None = None, + ) -> dict[str, Any]: + """ + Get the JSON schema for the configuration needed to add the semantic layer. + + A partial configuration `configuration` can be sent to improve the schema, + allowing for progressive validation and better UX. For example, a semantic + layer might require: + + - auth information + - a database + + If the user provides the auth information, a client can send the partial + configuration to this method, and the resulting JSON schema would include + the list of databases the user has access to, allowing a dropdown to be + populated. + + The Snowflake semantic layer has an example implementation of this method, where + database and schema names are populated based on the provided connection info. + """ + + @classmethod + def get_runtime_schema( + cls, + configuration: ConfigT, + runtime_data: dict[str, Any] | None = None, + ) -> dict[str, Any]: + """ + Get the JSON schema for the runtime parameters needed to load semantic views. + + This returns the schema needed to connect to a semantic view given the + configuration for the semantic layer. For example, a semantic layer might + be configured by: + + - auth information + - an optional database + + If the user does not provide a database when creating the semantic layer, the + runtime schema would require the database name to be provided before loading any + semantic views. This allows users to create semantic layers that connect to a + specific database (or project, account, etc.), or that allow users to select it + at query time. + + The Snowflake semantic layer has an example implementation of this method, where + database and schema names are required if they were not provided in the initial + configuration. + """ + + def get_semantic_views( + self, + runtime_configuration: dict[str, Any], + ) -> set[SemanticViewT]: + """ + Get the semantic views available in the semantic layer. + + The runtime configuration can provide information like a given project or + schema, used to restrict the semantic views returned. 
+ """ + + def get_semantic_view( + self, + name: str, + additional_configuration: dict[str, Any], + ) -> SemanticViewT: + """ + Get a specific semantic view by its name and additional configuration. + """ + + +@runtime_checkable +class SemanticViewImplementation(Protocol): + """ + A protocol for semantic views. + """ + + features: frozenset[SemanticViewFeature] + + def uid(self) -> str: + """ + Returns a unique identifier for the semantic view. + """ + + def get_dimensions(self) -> set[Dimension]: + """ + Get the dimensions defined in the semantic view. + """ + + def get_metrics(self) -> set[Metric]: + """ + Get the metrics defined in the semantic view. + """ + + def get_values( + self, + dimension: Dimension, + filters: set[Filter | AdhocFilter] | None = None, + ) -> SemanticResult: + """ + Return distinct values for a dimension. + """ + + def get_dataframe( + self, + metrics: list[Metric], + dimensions: list[Dimension], + filters: set[Filter | AdhocFilter] | None = None, + order: list[OrderTuple] | None = None, + limit: int | None = None, + offset: int | None = None, + *, + group_limit: GroupLimit | None = None, + ) -> SemanticResult: + """ + Execute a semantic query and return the results as a DataFrame. + """ + + def get_row_count( + self, + metrics: list[Metric], + dimensions: list[Dimension], + filters: set[Filter | AdhocFilter] | None = None, + order: list[OrderTuple] | None = None, + limit: int | None = None, + offset: int | None = None, + *, + group_limit: GroupLimit | None = None, + ) -> SemanticResult: + """ + Execute a query and return the number of rows the result would have. + """ diff --git a/superset/superset_typing.py b/superset/superset_typing.py index 4d409398d1c..537fc0cff65 100644 --- a/superset/superset_typing.py +++ b/superset/superset_typing.py @@ -57,6 +57,46 @@ class AdhocMetric(TypedDict, total=False): sqlExpression: str | None +class DatasetColumnData(TypedDict, total=False): + """Type for column metadata in ExplorableData datasets.""" + + advanced_data_type: str | None + certification_details: str | None + certified_by: str | None + column_name: str + description: str | None + expression: str | None + filterable: bool + groupby: bool + id: int | None + uuid: str | None + is_certified: bool + is_dttm: bool + python_date_format: str | None + type: str + type_generic: NotRequired["GenericDataType" | None] + verbose_name: str | None + warning_markdown: str | None + + +class DatasetMetricData(TypedDict, total=False): + """Type for metric metadata in ExplorableData datasets.""" + + certification_details: str | None + certified_by: str | None + currency: NotRequired[dict[str, Any]] + d3format: str | None + description: str | None + expression: str | None + id: int | None + uuid: str | None + is_certified: bool + metric_name: str + warning_markdown: str | None + warning_text: str | None + verbose_name: str | None + + class AdhocColumn(TypedDict, total=False): hasCustomLabel: bool | None label: str @@ -254,7 +294,7 @@ class ExplorableData(TypedDict, total=False): """ # Core fields from BaseDatasource.data - id: int + id: int | str # String for UUID-based explorables like SemanticView uid: str column_formats: dict[str, str | None] description: str | None @@ -268,14 +308,14 @@ class ExplorableData(TypedDict, total=False): type: str catalog: str | None schema: str | None - offset: int + offset: int | None cache_timeout: int | None params: str | None perm: str | None edit_url: str sql: str | None - columns: list[dict[str, Any]] - metrics: list[dict[str, Any]] + columns: 
list["DatasetColumnData"] + metrics: list["DatasetMetricData"] folders: Any # JSON field, can be list or dict order_by_choices: list[tuple[str, str]] owners: list[int] | list[dict[str, Any]] # Can be either format @@ -283,8 +323,8 @@ class ExplorableData(TypedDict, total=False): select_star: str | None # Additional fields from SqlaTable and data_for_slices - column_types: list[Any] - column_names: set[str] | set[Any] + column_types: list["GenericDataType"] + column_names: set[str] | list[str] granularity_sqla: list[tuple[Any, Any]] time_grain_sqla: list[tuple[Any, Any]] main_dttm_col: str | None diff --git a/superset/utils/core.py b/superset/utils/core.py index a5c69554559..5bfb867b472 100644 --- a/superset/utils/core.py +++ b/superset/utils/core.py @@ -96,7 +96,7 @@ from superset.exceptions import ( SupersetException, SupersetTimeoutException, ) -from superset.explorables.base import Explorable +from superset.explorables.base import ColumnMetadata, Explorable from superset.sql.parse import sanitize_clause from superset.superset_typing import ( AdhocColumn, @@ -115,7 +115,6 @@ from superset.utils.hashing import hash_from_dict, hash_from_str from superset.utils.pandas import detect_datetime_format if TYPE_CHECKING: - from superset.connectors.sqla.models import TableColumn from superset.models.core import Database logging.getLogger("MARKDOWN").setLevel(logging.INFO) @@ -200,6 +199,7 @@ class DatasourceType(StrEnum): QUERY = "query" SAVEDQUERY = "saved_query" VIEW = "view" + SEMANTIC_VIEW = "semantic_view" class LoggerLevel(StrEnum): @@ -1672,15 +1672,12 @@ def get_metric_type_from_column(column: Any, datasource: Explorable) -> str: :return: The inferred metric type as a string, or an empty string if the column is not a metric or no valid operation is found. 
""" - - from superset.connectors.sqla.models import SqlMetric - - metric: SqlMetric = next( - (metric for metric in datasource.metrics if metric.metric_name == column), - SqlMetric(metric_name=""), + metric = next( + (m for m in datasource.metrics if m.metric_name == column), + None, ) - if metric.metric_name == "": + if metric is None: return "" expression: str = metric.expression @@ -1725,18 +1722,18 @@ def extract_dataframe_dtypes( columns_by_name[column.column_name] = column generic_types: list[GenericDataType] = [] - for column in df.columns: - column_object = columns_by_name.get(column) - series = df[column] + for col_name in df.columns: + column_object = columns_by_name.get(str(col_name)) + series = df[col_name] inferred_type: str = "" if series.isna().all(): sql_type: Optional[str] = "" if datasource and hasattr(datasource, "columns_types"): - if column in datasource.columns_types: - sql_type = datasource.columns_types.get(column) + if col_name in datasource.columns_types: + sql_type = datasource.columns_types.get(col_name) inferred_type = map_sql_type_to_inferred_type(sql_type) else: - inferred_type = get_metric_type_from_column(column, datasource) + inferred_type = get_metric_type_from_column(col_name, datasource) else: inferred_type = infer_dtype(series) if isinstance(column_object, dict): @@ -1756,11 +1753,17 @@ def extract_dataframe_dtypes( return generic_types -def extract_column_dtype(col: TableColumn) -> GenericDataType: - if col.is_temporal: +def extract_column_dtype(col: "ColumnMetadata") -> GenericDataType: + # Check for temporal type + if hasattr(col, "is_temporal") and col.is_temporal: + return GenericDataType.TEMPORAL + if col.is_dttm: return GenericDataType.TEMPORAL - if col.is_numeric: + + # Check for numeric type + if hasattr(col, "is_numeric") and col.is_numeric: return GenericDataType.NUMERIC + # TODO: add check for boolean data type when proper support is added return GenericDataType.STRING @@ -1774,9 +1777,7 @@ def get_time_filter_status( applied_time_extras: dict[str, str], ) -> tuple[list[dict[str, str]], list[dict[str, str]]]: temporal_columns: set[Any] = { - (col.column_name if hasattr(col, "column_name") else col.get("column_name")) - for col in datasource.columns - if (col.is_dttm if hasattr(col, "is_dttm") else col.get("is_dttm")) + col.column_name for col in datasource.columns if col.is_dttm } applied: list[dict[str, str]] = [] rejected: list[dict[str, str]] = []
