dpgaspar commented on code in PR #35018:
URL: https://github.com/apache/superset/pull/35018#discussion_r2332753572
##########
superset/daos/base.py:
##########
@@ -251,3 +455,205 @@ def filter_by(cls, **filter_by: Any) -> list[T]:
cls.id_column_name, data_model
).apply(query, None)
return query.filter_by(**filter_by).all()
+
+ @classmethod
+ def apply_column_operators(
+ cls, query: Any, column_operators: Optional[List[ColumnOperator]] =
None
+ ) -> Any:
+ """
+ Apply column operators (list of ColumnOperator) to the query using
+ ColumnOperatorEnum logic. Raises ValueError if a filter references a
+ non-existent column.
+ """
+ if not column_operators:
+ return query
+ for c in column_operators:
+ if not isinstance(c, ColumnOperator):
+ continue
+ col = c.col
+ opr = c.opr
+ value = c.value
+ if not col or not hasattr(cls.model_cls, col):
+ model_name = cls.model_cls.__name__ if cls.model_cls else
"Unknown"
+ logging.error(
+ f"Invalid filter: column '{col}' does not exist on
{model_name}"
Review Comment:
nit: we should not use f-strings when logging, for performance reasons.
The f-string is evaluated immediately, even when the message's level is below
the configured log level (e.g., the logger is set to INFO or WARNING, so a DEBUG
message will never be emitted). This means you pay the cost of:
- building the string
- evaluating any expressions inside `{}`

even if the message is dropped due to the log level. Let's use:
``` python
logging.error("Invalid filter: column '%s' does not exist on %s", col,
model_name)
```
##########
superset/daos/base.py:
##########
@@ -251,3 +455,205 @@ def filter_by(cls, **filter_by: Any) -> list[T]:
cls.id_column_name, data_model
).apply(query, None)
return query.filter_by(**filter_by).all()
+
+ @classmethod
+ def apply_column_operators(
+ cls, query: Any, column_operators: Optional[List[ColumnOperator]] =
None
+ ) -> Any:
+ """
+ Apply column operators (list of ColumnOperator) to the query using
+ ColumnOperatorEnum logic. Raises ValueError if a filter references a
+ non-existent column.
+ """
+ if not column_operators:
+ return query
+ for c in column_operators:
+ if not isinstance(c, ColumnOperator):
+ continue
+ col = c.col
+ opr = c.opr
+ value = c.value
+ if not col or not hasattr(cls.model_cls, col):
+ model_name = cls.model_cls.__name__ if cls.model_cls else
"Unknown"
+ logging.error(
+ f"Invalid filter: column '{col}' does not exist on
{model_name}"
+ )
+ raise ValueError(
+ f"Invalid filter: column '{col}' does not exist on
{model_name}"
+ )
+ column = getattr(cls.model_cls, col)
+ try:
+ # Always use ColumnOperatorEnum's apply method
+ operator_enum = ColumnOperatorEnum(opr)
+ query = query.filter(operator_enum.apply(column, value))
+ except Exception as e:
+ logging.error(f"Error applying filter on column '{col}': {e}")
+ raise
+ return query
+
+ @classmethod
+ def get_filterable_columns_and_operators(cls) -> Dict[str, List[str]]:
+ """
+ Returns a dict mapping filterable columns (including hybrid/computed
fields if
+ present) to their supported operators. Used by MCP tools to
dynamically expose
+ filter options. Custom fields supported by the DAO but not present on
the model
+ should be documented here.
+ """
+
+ mapper = inspect(cls.model_cls)
+ columns = {c.key: c for c in mapper.columns}
+ # Add hybrid properties
+ hybrids = {
+ name: attr
+ for name, attr in vars(cls.model_cls).items()
+ if isinstance(attr, hybrid_property)
+ }
+ # You may add custom fields here, e.g.:
+ # custom_fields = {"tags": ["eq", "in_", "like"], ...}
+ custom_fields: Dict[str, List[str]] = {}
+
+ filterable = {}
+ for name, col in columns.items():
+ if isinstance(col.type, (sa.String, sa.Text)):
+ filterable[name] = TYPE_OPERATOR_MAP["string"]
+ elif isinstance(col.type, (sa.Boolean,)):
+ filterable[name] = TYPE_OPERATOR_MAP["boolean"]
+ elif isinstance(col.type, (sa.Integer, sa.Float, sa.Numeric)):
+ filterable[name] = TYPE_OPERATOR_MAP["number"]
+ elif isinstance(col.type, (sa.DateTime, sa.Date, sa.Time)):
+ filterable[name] = TYPE_OPERATOR_MAP["datetime"]
+ else:
+ # Fallback to eq/ne/null
+ filterable[name] = ["eq", "ne", "is_null", "is_not_null"]
+ # Add hybrid properties as string fields by default
+ for name in hybrids:
+ filterable[name] = TYPE_OPERATOR_MAP["string"]
+ # Add custom fields
+ filterable.update(custom_fields)
+ return filterable
+
+ @classmethod
+ def _build_query(
+ cls,
+ column_operators: Optional[List[ColumnOperator]] = None,
+ search: Optional[str] = None,
+ search_columns: Optional[List[str]] = None,
+ custom_filters: Optional[Dict[str, BaseFilter]] = None,
+ skip_base_filter: bool = False,
+ data_model: Optional[SQLAInterface] = None,
+ ) -> Any:
+ """
+ Build a SQLAlchemy query with base filter, column operators, search,
and
+ custom filters.
+ """
+ if data_model is None:
+ data_model = SQLAInterface(cls.model_cls, db.session)
+ query = data_model.session.query(cls.model_cls)
+ query = cls._apply_base_filter(
+ query, skip_base_filter=skip_base_filter, data_model=data_model
+ )
+ if search and search_columns:
+ search_filters = []
+ for column_name in search_columns:
+ if hasattr(cls.model_cls, column_name):
+ column = getattr(cls.model_cls, column_name)
+ search_filters.append(cast(column,
Text).ilike(f"%{search}%"))
+ if search_filters:
+ query = query.filter(or_(*search_filters))
+ if custom_filters:
+ for filter_class in custom_filters.values():
+ query = filter_class.apply(query, None)
+ if column_operators:
+ query = cls.apply_column_operators(query, column_operators)
+ return query
+
+ @classmethod
+ def list( # noqa: C901
+ cls,
+ column_operators: Optional[List[ColumnOperator]] = None,
+ order_column: str = "changed_on",
+ order_direction: str = "desc",
+ page: int = 0,
+ page_size: int = 100,
+ search: Optional[str] = None,
+ search_columns: Optional[List[str]] = None,
+ custom_filters: Optional[Dict[str, BaseFilter]] = None,
+ columns: Optional[List[str]] = None,
+ ) -> Tuple[List[Any], int]:
+ """
+ Generic list method for filtered, sorted, and paginated results.
+ If columns is specified, returns a list of tuples (one per row),
+ otherwise returns model instances.
+ """
+ data_model = SQLAInterface(cls.model_cls, db.session)
+
+ column_attrs = []
+ relationship_loads = []
+ if columns is None:
+ columns = []
+ for name in columns:
+ attr = getattr(cls.model_cls, name, None)
+ if attr is None:
+ continue
+ prop = getattr(attr, "property", None)
+ if isinstance(prop, ColumnProperty):
+ column_attrs.append(attr)
+ elif isinstance(prop, RelationshipProperty):
+ relationship_loads.append(joinedload(attr))
Review Comment:
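It seems we will support many-to-many relations here; please double-check that
pagination works as expected in these cases, since a JOIN against a collection
yields one SQL row per related item rather than one per parent.
Do we have tests for relations?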
##########
superset/daos/base.py:
##########
@@ -251,3 +455,205 @@ def filter_by(cls, **filter_by: Any) -> list[T]:
cls.id_column_name, data_model
).apply(query, None)
return query.filter_by(**filter_by).all()
+
+ @classmethod
+ def apply_column_operators(
+ cls, query: Any, column_operators: Optional[List[ColumnOperator]] =
None
+ ) -> Any:
+ """
+ Apply column operators (list of ColumnOperator) to the query using
+ ColumnOperatorEnum logic. Raises ValueError if a filter references a
+ non-existent column.
+ """
+ if not column_operators:
+ return query
+ for c in column_operators:
+ if not isinstance(c, ColumnOperator):
+ continue
+ col = c.col
+ opr = c.opr
+ value = c.value
+ if not col or not hasattr(cls.model_cls, col):
+ model_name = cls.model_cls.__name__ if cls.model_cls else
"Unknown"
+ logging.error(
+ f"Invalid filter: column '{col}' does not exist on
{model_name}"
+ )
+ raise ValueError(
+ f"Invalid filter: column '{col}' does not exist on
{model_name}"
+ )
+ column = getattr(cls.model_cls, col)
+ try:
+ # Always use ColumnOperatorEnum's apply method
+ operator_enum = ColumnOperatorEnum(opr)
+ query = query.filter(operator_enum.apply(column, value))
+ except Exception as e:
+ logging.error(f"Error applying filter on column '{col}': {e}")
Review Comment:
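nit: let's use lazy (%-style) logging formatting here too, e.g. `logging.error("Error applying filter on column '%s': %s", col, e)`.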
##########
tests/integration_tests/dao/base_dao_test.py:
##########
@@ -0,0 +1,1397 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""
+Integration tests for BaseDAO functionality.
+
+This module contains comprehensive integration tests for the BaseDAO class and
its
+subclasses, covering database operations, CRUD methods, flexible column
support,
+column operators, and error handling.
+
+Tests use an in-memory SQLite database for isolation and to replicate the unit
test
+environment behavior. User model deletions are avoided due to circular
dependency
+constraints with self-referential foreign keys.
+"""
+
+import datetime
+import time
+import uuid
+
+import pytest
+from flask_appbuilder.models.filters import BaseFilter
+from flask_appbuilder.security.sqla.models import User
+from sqlalchemy import Column, DateTime, Integer, String
+from sqlalchemy.ext.declarative import declarative_base
+from sqlalchemy.orm.session import Session
+
+from superset.daos.base import BaseDAO, ColumnOperator, ColumnOperatorEnum
+from superset.daos.chart import ChartDAO
+from superset.daos.dashboard import DashboardDAO
+from superset.daos.database import DatabaseDAO
+from superset.daos.user import UserDAO
+from superset.extensions import db
+from superset.models.core import Database
+from superset.models.dashboard import Dashboard
+from superset.models.slice import Slice
+
+# Create a test model for comprehensive testing
+Base = declarative_base()
+
+
+class ExampleModel(Base): # type: ignore
+ __tablename__ = "example_model"
+ id = Column(Integer, primary_key=True)
+ uuid = Column(String(36), unique=True, nullable=False)
+ slug = Column(String(100), unique=True)
+ name = Column(String(100))
+ code = Column(String(50), unique=True)
+ created_on = Column(DateTime, default=datetime.datetime.utcnow)
+
+
+class ExampleModelDAO(BaseDAO[ExampleModel]):
+ model_cls = ExampleModel
+ id_column_name = "id"
+ base_filter = None
+
+
+class MockModel:
+ def __init__(self, id=1, name="test"):
+ self.id = id
+ self.name = name
+
+
+class TestDAO(BaseDAO[MockModel]):
+ model_cls = MockModel
+
+
+@pytest.fixture(autouse=True)
+def mock_g_user(app_context):
+ """Mock the flask g.user for security context."""
+ # Within app context, we can safely mock g
+ from flask import g
+
+ mock_user = User()
+ mock_user.id = 1
+ mock_user.username = "test_user"
+
+ # Set g.user directly instead of patching
+ g.user = mock_user
+ yield
+
+ # Clean up
+ if hasattr(g, "user"):
+ delattr(g, "user")
+
+
+# =============================================================================
+# Integration Tests - These tests use the actual database
+# =============================================================================
+
+
+def test_column_operator_enum_complete_coverage(user_with_data: Session) ->
None:
+ """
+ Test that every single ColumnOperatorEnum operator is covered by tests.
+ This ensures we have comprehensive test coverage for all operators.
+ """
+ # Simply verify that we can create queries with all operators
+ for operator in ColumnOperatorEnum:
+ column_operator = ColumnOperator(
+ col="username", opr=operator, value="test_value"
+ )
+ # Just check it doesn't raise an error
+ assert column_operator.opr == operator
+
+
+def test_find_by_id_with_default_column(app_context: Session) -> None:
+ """Test find_by_id with default 'id' column."""
+ # Create a user to test with
+ user = User(
+ username="test_find_by_id",
+ first_name="Test",
+ last_name="User",
+ email="[email protected]",
+ active=True,
+ )
+ db.session.add(user)
+ db.session.commit()
+
+ # Find by numeric id
+ found = UserDAO.find_by_id(user.id, skip_base_filter=True)
+ assert found is not None
+ assert found.id == user.id
+ assert found.username == "test_find_by_id"
+
+ # Test with non-existent id
+ not_found = UserDAO.find_by_id(999999, skip_base_filter=True)
+ assert not_found is None
+
+
+def test_find_by_id_with_uuid_column(app_context: Session) -> None:
+ """Test find_by_id with custom uuid column."""
+ # Create a dashboard with uuid
+ dashboard = Dashboard(
+ dashboard_title="Test UUID Dashboard",
+ slug="test-uuid-dashboard",
+ published=True,
+ )
+ db.session.add(dashboard)
+ db.session.commit()
+
+ # Find by uuid string using the uuid column
+ found = DashboardDAO.find_by_id(
+ str(dashboard.uuid), id_column="uuid", skip_base_filter=True
+ )
+ assert found is not None
+ assert found.uuid == dashboard.uuid
+ assert found.dashboard_title == "Test UUID Dashboard"
+
+ # Find by numeric id (should still work)
+ found_by_id = DashboardDAO.find_by_id(dashboard.id, skip_base_filter=True)
+ assert found_by_id is not None
+ assert found_by_id.id == dashboard.id
+
+ # Test with non-existent uuid
+ not_found = DashboardDAO.find_by_id(str(uuid.uuid4()),
skip_base_filter=True)
+ assert not_found is None
+
+
+def test_find_by_id_with_slug_column(app_context: Session) -> None:
+ """Test find_by_id with slug column fallback."""
+ # Create a dashboard with slug
+ dashboard = Dashboard(
+ dashboard_title="Test Slug Dashboard",
+ slug="test-slug-dashboard",
+ published=True,
+ )
+ db.session.add(dashboard)
+ db.session.commit()
+
+ # Find by slug using the slug column
+ found = DashboardDAO.find_by_id(
+ "test-slug-dashboard", id_column="slug", skip_base_filter=True
+ )
+ assert found is not None
+ assert found.slug == "test-slug-dashboard"
+ assert found.dashboard_title == "Test Slug Dashboard"
+
+ # Test with non-existent slug
+ not_found = DashboardDAO.find_by_id("non-existent-slug",
skip_base_filter=True)
+ assert not_found is None
+
+
+def test_find_by_id_with_invalid_column(app_context: Session) -> None:
+ """Test find_by_id returns None when column doesn't exist."""
+ # This should return None gracefully
+ result = UserDAO.find_by_id("not_a_valid_id", skip_base_filter=True)
+ assert result is None
+
+
+def test_find_by_id_skip_base_filter(app_context: Session) -> None:
+ """Test find_by_id with skip_base_filter parameter."""
+ # Create users with different active states
+ active_user = User(
+ username="active_user",
+ first_name="Active",
+ last_name="User",
+ email="[email protected]",
+ active=True,
+ )
+ inactive_user = User(
+ username="inactive_user",
+ first_name="Inactive",
+ last_name="User",
+ email="[email protected]",
+ active=False,
+ )
+ db.session.add_all([active_user, inactive_user])
+ db.session.commit()
+
+ # Without skipping base filter (if one exists)
+ found_active = UserDAO.find_by_id(active_user.id, skip_base_filter=False)
+ assert found_active is not None
+
+ # With skipping base filter
+ found_active_skip = UserDAO.find_by_id(active_user.id,
skip_base_filter=True)
+ assert found_active_skip is not None
+
+ # Both should find the user since UserDAO might not have a base filter
+ assert found_active.id == active_user.id
+ assert found_active_skip.id == active_user.id
+
+
+def test_find_by_ids_with_default_column(app_context: Session) -> None:
+ """Test find_by_ids with default 'id' column."""
+ # Create multiple users
+ users = []
+ for i in range(3):
+ user = User(
+ username=f"test_find_by_ids_{i}",
+ first_name=f"Test{i}",
+ last_name="User",
+ email=f"test{i}@example.com",
+ active=True,
+ )
+ users.append(user)
+ db.session.add(user)
+ db.session.commit()
+
+ # Find by multiple ids
+ ids = [user.id for user in users]
+ found = UserDAO.find_by_ids(ids, skip_base_filter=True)
+ assert len(found) == 3
+ found_ids = [u.id for u in found]
+ assert set(found_ids) == set(ids)
+
+ # Test with mix of existent and non-existent ids
+ mixed_ids = [users[0].id, 999999, users[1].id]
+ found_mixed = UserDAO.find_by_ids(mixed_ids, skip_base_filter=True)
+ assert len(found_mixed) == 2
+
+ # Test with empty list
+ found_empty = UserDAO.find_by_ids([], skip_base_filter=True)
+ assert found_empty == []
+
+
+def test_find_by_ids_with_uuid_column(app_context: Session) -> None:
+ """Test find_by_ids with uuid column."""
+ # Create multiple dashboards
+ dashboards = []
+ for i in range(3):
+ dashboard = Dashboard(
+ dashboard_title=f"Test UUID Dashboard {i}",
+ slug=f"test-uuid-dashboard-{i}",
+ published=True,
+ )
+ dashboards.append(dashboard)
+ db.session.add(dashboard)
+ db.session.commit()
+
+ # Find by multiple uuids
+ uuids = [str(dashboard.uuid) for dashboard in dashboards]
+ found = DashboardDAO.find_by_ids(uuids, id_column="uuid",
skip_base_filter=True)
+ assert len(found) == 3
+ found_uuids = [str(d.uuid) for d in found]
+ assert set(found_uuids) == set(uuids)
+
+ # Test with mix of ids and uuids - search separately by column
+ found_by_id = DashboardDAO.find_by_ids([dashboards[0].id],
skip_base_filter=True)
+ found_by_uuid = DashboardDAO.find_by_ids(
+ [str(dashboards[1].uuid)], id_column="uuid", skip_base_filter=True
+ )
+ assert len(found_by_id) == 1
+ assert len(found_by_uuid) == 1
+
+
+def test_find_by_ids_with_slug_column(app_context: Session) -> None:
+ """Test find_by_ids with slug column."""
+ # Create multiple dashboards
+ dashboards = []
+ for i in range(3):
+ dashboard = Dashboard(
+ dashboard_title=f"Test Slug Dashboard {i}",
+ slug=f"test-slug-dashboard-{i}",
+ published=True,
+ )
+ dashboards.append(dashboard)
+ db.session.add(dashboard)
+ db.session.commit()
+
+ # Find by multiple slugs
+ slugs = [dashboard.slug for dashboard in dashboards]
+ found = DashboardDAO.find_by_ids(slugs, id_column="slug",
skip_base_filter=True)
+ assert len(found) == 3
+ found_slugs = [d.slug for d in found]
+ assert set(found_slugs) == set(slugs)
+
+
+def test_find_by_ids_with_invalid_column(app_context: Session) -> None:
+ """Test find_by_ids returns empty list when column doesn't exist."""
+ # This should return empty list gracefully
+ result = UserDAO.find_by_ids(["not_a_valid_id"], skip_base_filter=True)
+ assert result == []
+
+
+def test_find_by_ids_skip_base_filter(app_context: Session) -> None:
+ """Test find_by_ids with skip_base_filter parameter."""
+ # Create users
+ users = []
+ for i in range(3):
+ user = User(
+ username=f"test_skip_filter_{i}",
+ first_name=f"Test{i}",
+ last_name="User",
+ email=f"test{i}@example.com",
+ active=True,
+ )
+ users.append(user)
+ db.session.add(user)
+ db.session.commit()
+
+ ids = [user.id for user in users]
+
+ # Without skipping base filter
+ found_no_skip = UserDAO.find_by_ids(ids, skip_base_filter=False)
+ assert len(found_no_skip) == 3
+
+ # With skipping base filter
+ found_skip = UserDAO.find_by_ids(ids, skip_base_filter=True)
+ assert len(found_skip) == 3
+
+
+def test_base_dao_create_with_item(app_context: Session) -> None:
+ """Test BaseDAO.create with an item parameter."""
+ # Create a user item
+ user = User(
+ username="created_with_item",
+ first_name="Created",
+ last_name="Item",
+ email="[email protected]",
+ active=True,
+ )
+
+ # Create using the item
+ created = UserDAO.create(item=user)
+ assert created is not None
+ assert created.username == "created_with_item"
+ assert created.first_name == "Created"
+
+ # Verify it's in the session
+ assert created in db.session
+
+ # Commit and verify it persists
+ db.session.commit()
+
+ # Find it again to ensure it was saved
+ found = UserDAO.find_by_id(created.id, skip_base_filter=True)
+ assert found is not None
+ assert found.username == "created_with_item"
+
+
+def test_base_dao_create_with_attributes(app_context: Session) -> None:
+ """Test BaseDAO.create with attributes parameter."""
+ # Create using attributes dict
+ attributes = {
+ "username": "created_with_attrs",
+ "first_name": "Created",
+ "last_name": "Attrs",
+ "email": "[email protected]",
+ "active": True,
+ }
+
+ created = UserDAO.create(attributes=attributes)
+ assert created is not None
+ assert created.username == "created_with_attrs"
+ assert created.email == "[email protected]"
+
+ # Commit and verify
+ db.session.commit()
+ found = UserDAO.find_by_id(created.id, skip_base_filter=True)
+ assert found is not None
+ assert found.username == "created_with_attrs"
+
+
+def test_base_dao_create_with_both_item_and_attributes(app_context: Session)
-> None:
+ """Test BaseDAO.create with both item and attributes (override
behavior)."""
+ # Create a user item
+ user = User(
+ username="item_username",
+ first_name="Item",
+ last_name="User",
+ email="[email protected]",
+ active=False,
+ )
+
+ # Override some attributes
+ attributes = {
+ "username": "override_username",
+ "active": True,
+ }
+
+ created = UserDAO.create(item=user, attributes=attributes)
+ assert created is not None
+ assert created.username == "override_username" # Should be overridden
+ assert created.active is True # Should be overridden
+ assert created.first_name == "Item" # Should keep original
+ assert created.last_name == "User" # Should keep original
+
+ db.session.commit()
+
+
+def test_base_dao_update_with_item(app_context: Session) -> None:
+ """Test BaseDAO.update with an item parameter."""
+ # Create a user first
+ user = User(
+ username="update_test",
+ first_name="Original",
+ last_name="User",
+ email="[email protected]",
+ active=True,
+ )
+ db.session.add(user)
+ db.session.commit()
+
+ # Update the user
+ user.first_name = "Updated"
+ updated = UserDAO.update(item=user)
+ assert updated is not None
+ assert updated.first_name == "Updated"
+
+ db.session.commit()
+
+ # Verify the update persisted
+ found = UserDAO.find_by_id(user.id, skip_base_filter=True)
+ assert found is not None
+ assert found.first_name == "Updated"
+
+
+def test_base_dao_update_with_attributes(app_context: Session) -> None:
+ """Test BaseDAO.update with attributes parameter."""
+ # Create a user first
+ user = User(
+ username="update_attrs_test",
+ first_name="Original",
+ last_name="User",
+ email="[email protected]",
+ active=True,
+ )
+ db.session.add(user)
+ db.session.commit()
+
+ # Update using attributes
+ attributes = {"first_name": "Updated", "last_name": "Attr User"}
+ updated = UserDAO.update(item=user, attributes=attributes)
+ assert updated is not None
+ assert updated.first_name == "Updated"
+ assert updated.last_name == "Attr User"
+
+ db.session.commit()
+
+
+def test_base_dao_update_detached_item(app_context: Session) -> None:
+ """Test BaseDAO.update with a detached item."""
+ # Create a user first
+ user = User(
+ username="detached_test",
+ first_name="Original",
+ last_name="User",
+ email="[email protected]",
+ active=True,
+ )
+ db.session.add(user)
+ db.session.commit()
+
+ user_id = user.id
+
+ # Expunge to detach from session
+ db.session.expunge(user)
Review Comment:
interesting test!
##########
superset/daos/base.py:
##########
@@ -32,6 +51,100 @@
T = TypeVar("T", bound=Model)
+class ColumnOperatorEnum(str, Enum):
+ eq = "eq"
+ ne = "ne"
+ sw = "sw"
+ ew = "ew"
+ in_ = "in"
+ nin = "nin"
+ gt = "gt"
+ gte = "gte"
+ lt = "lt"
+ lte = "lte"
+ like = "like"
+ ilike = "ilike"
+ is_null = "is_null"
+ is_not_null = "is_not_null"
+
+ @classmethod
+ def operator_map(cls) -> Dict[ColumnOperatorEnum, Any]:
+ return {
+ cls.eq: lambda col, val: col == val,
+ cls.ne: lambda col, val: col != val,
+ cls.sw: lambda col, val: col.like(f"{val}%"),
+ cls.ew: lambda col, val: col.like(f"%{val}"),
+ cls.in_: lambda col, val: col.in_(
+ val if isinstance(val, (list, tuple)) else [val]
+ ),
+ cls.nin: lambda col, val: ~col.in_(
+ val if isinstance(val, (list, tuple)) else [val]
+ ),
+ cls.gt: lambda col, val: col > val,
+ cls.gte: lambda col, val: col >= val,
+ cls.lt: lambda col, val: col < val,
+ cls.lte: lambda col, val: col <= val,
+ cls.like: lambda col, val: col.like(f"%{val}%"),
+ cls.ilike: lambda col, val: col.ilike(f"%{val}%"),
+ cls.is_null: lambda col, _: col.is_(None),
+ cls.is_not_null: lambda col, _: col.isnot(None),
+ }
Review Comment:
WDYT about just making this a plain dict? It would be slightly faster, since we
would avoid a function call.
##########
superset/daos/base.py:
##########
@@ -83,39 +196,138 @@ def find_by_id_or_uuid(
return None
@classmethod
- def find_by_id(
- cls,
- model_id: str | int,
- skip_base_filter: bool = False,
- ) -> T | None:
+ def _apply_base_filter(
+ cls, query: Any, skip_base_filter: bool = False, data_model: Any = None
+ ) -> Any:
"""
- Find a model by id, if defined applies `base_filter`
+ Apply the base_filter to the query if it exists and skip_base_filter
is False.
"""
- query = db.session.query(cls.model_cls)
if cls.base_filter and not skip_base_filter:
- data_model = SQLAInterface(cls.model_cls, db.session)
+ if data_model is None:
+ data_model = SQLAInterface(cls.model_cls, db.session)
query = cls.base_filter( # pylint: disable=not-callable
cls.id_column_name, data_model
).apply(query, None)
- id_column = getattr(cls.model_cls, cls.id_column_name)
+ return query
+
+ @classmethod
+ def _convert_value_for_column(cls, column: Any, value: Any) -> Any:
+ """
+ Convert a value to the appropriate type for a given SQLAlchemy column.
+
+ Args:
+ column: SQLAlchemy column object
+ value: Value to convert
+
+ Returns:
+ Converted value or None if conversion fails
+ """
+ if (
+ hasattr(column.type, "python_type")
+ and column.type.python_type == uuid_lib.UUID
+ ):
+ if isinstance(value, str):
+ try:
+ return uuid_lib.UUID(value)
+ except (ValueError, AttributeError):
+ return None
+ return value
+
+ @classmethod
+ def _find_by_column(
+ cls,
+ column_name: str,
+ value: str | int,
+ skip_base_filter: bool = False,
+ ) -> T | None:
+ """
+ Private method to find a model by any column value.
+
+ Args:
+ column_name: Name of the column to search by
+ value: Value to search for
+ skip_base_filter: Whether to skip base filtering
+
+ Returns:
+ Model instance or None if not found
+ """
+ query = db.session.query(cls.model_cls)
+ query = cls._apply_base_filter(query, skip_base_filter)
+
+ if not hasattr(cls.model_cls, column_name):
+ return None
+
+ column = getattr(cls.model_cls, column_name)
+ converted_value = cls._convert_value_for_column(column, value)
+ if converted_value is None:
+ return None
+
try:
- return query.filter(id_column == model_id).one_or_none()
+ return query.filter(column == converted_value).one_or_none()
except StatementError:
# can happen if int is passed instead of a string or similar
return None
+ @classmethod
+ def find_by_id(
+ cls,
+ model_id: str | int,
+ skip_base_filter: bool = False,
+ id_column: str | None = None,
+ ) -> T | None:
+ """
+ Find a model by ID using specified or default ID column.
+
+ Args:
+ model_id: ID value to search for
+ skip_base_filter: Whether to skip base filtering
+ id_column: Column name to use (defaults to cls.id_column_name)
+
+ Returns:
+ Model instance or None if not found
+ """
+ column = id_column or cls.id_column_name
+ return cls._find_by_column(column, model_id, skip_base_filter)
+
@classmethod
def find_by_ids(
cls,
- model_ids: list[str] | list[int],
+ model_ids: Sequence[str | int],
skip_base_filter: bool = False,
+ id_column: str | None = None,
) -> list[T]:
"""
Find a List of models by a list of ids, if defined applies
`base_filter`
+
+ :param model_ids: List of IDs to find
+ :param skip_base_filter: If true, skip applying the base filter
+ :param id_column: Optional column name to use for ID lookup
+ (defaults to id_column_name)
"""
- id_col = getattr(cls.model_cls, cls.id_column_name, None)
+ column = id_column or cls.id_column_name
+ id_col = getattr(cls.model_cls, column, None)
if id_col is None or not model_ids:
return []
+
+ # Convert IDs to appropriate types based on column type
+ converted_ids: list[str | int | uuid_lib.UUID] = []
+ for id_val in model_ids:
+ converted_value = cls._convert_value_for_column(id_col, id_val)
+ if converted_value is not None:
+ converted_ids.append(converted_value)
Review Comment:
nit: it seems we are assuming that all values are of the same type; should we
make sure this is always true?
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]
---------------------------------------------------------------------
To unsubscribe, e-mail: [email protected]
For additional commands, e-mail: [email protected]