This is an automated email from the ASF dual-hosted git repository.
michaelsmolina pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new 6b25d0663e refactor: Migrates the MCP `execute_sql` tool to use the
SQL execution API (#36739)
6b25d0663e is described below
commit 6b25d0663ef2cec6a3b16ac9dbb56bf5888114cf
Author: Michael S. Molina <[email protected]>
AuthorDate: Mon Dec 22 09:48:28 2025 -0300
refactor: Migrates the MCP `execute_sql` tool to use the SQL execution API
(#36739)
Co-authored-by: codeant-ai-for-open-source[bot]
<244253245+codeant-ai-for-open-source[bot]@users.noreply.github.com>
---
superset-core/src/superset_core/api/models.py | 12 +-
superset/mcp_service/sql_lab/execute_sql_core.py | 221 ----------
superset/mcp_service/sql_lab/schemas.py | 42 +-
superset/mcp_service/sql_lab/sql_lab_utils.py | 247 -----------
superset/mcp_service/sql_lab/tool/execute_sql.py | 136 +++++-
superset/mcp_service/utils/schema_utils.py | 21 +-
.../mcp_service/sql_lab/test_sql_lab_utils.py | 137 ------
.../mcp_service/sql_lab/tool/test_execute_sql.py | 490 +++++++++++++--------
8 files changed, 480 insertions(+), 826 deletions(-)
diff --git a/superset-core/src/superset_core/api/models.py
b/superset-core/src/superset_core/api/models.py
index 59cb07dc38..346e8392f1 100644
--- a/superset-core/src/superset_core/api/models.py
+++ b/superset-core/src/superset_core/api/models.py
@@ -92,7 +92,11 @@ class Database(CoreModel):
"""
Execute SQL synchronously.
- :param sql: SQL query to execute
+ The SQL must be written in the dialect of the target database (e.g.,
+ PostgreSQL syntax for PostgreSQL databases, Snowflake syntax for
+ Snowflake, etc.). No automatic cross-dialect translation is performed.
+
+ :param sql: SQL query to execute (in the target database's dialect)
:param options: Query execution options (see `QueryOptions`).
If not provided, defaults are used.
:returns: QueryResult with status, data (DataFrame), and metadata
@@ -139,7 +143,11 @@ class Database(CoreModel):
Returns immediately with a handle for tracking progress and retrieving
results from the background worker.
- :param sql: SQL query to execute
+ The SQL must be written in the dialect of the target database (e.g.,
+ PostgreSQL syntax for PostgreSQL databases, Snowflake syntax for
+ Snowflake, etc.). No automatic cross-dialect translation is performed.
+
+ :param sql: SQL query to execute (in the target database's dialect)
:param options: Query execution options (see `QueryOptions`).
If not provided, defaults are used.
:returns: AsyncQueryHandle for tracking the query
diff --git a/superset/mcp_service/sql_lab/execute_sql_core.py
b/superset/mcp_service/sql_lab/execute_sql_core.py
deleted file mode 100644
index 263e012338..0000000000
--- a/superset/mcp_service/sql_lab/execute_sql_core.py
+++ /dev/null
@@ -1,221 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-Generic SQL execution core for MCP service.
-"""
-
-import logging
-from typing import Any
-
-from superset.mcp_service.mcp_core import BaseCore
-from superset.mcp_service.sql_lab.schemas import (
- ExecuteSqlRequest,
- ExecuteSqlResponse,
-)
-
-
-class ExecuteSqlCore(BaseCore):
- """
- Generic tool for executing SQL queries with security validation.
-
- This tool provides a high-level interface for SQL execution that can be
used
- by different MCP tools or other components. It handles:
- - Database access validation
- - SQL query validation (DML permissions, disallowed functions)
- - Parameter substitution
- - Query execution with timeout
- - Result formatting
-
- The tool can work in two modes:
- 1. Simple mode: Direct SQL execution using sql_lab_utils (default)
- 2. Command mode: Using ExecuteSqlCommand for full SQL Lab features
- """
-
- def __init__(
- self,
- use_command_mode: bool = False,
- logger: logging.Logger | None = None,
- ) -> None:
- super().__init__(logger)
- self.use_command_mode = use_command_mode
-
- def run_tool(self, request: ExecuteSqlRequest) -> ExecuteSqlResponse:
- """
- Execute SQL query and return results.
-
- Args:
- request: ExecuteSqlRequest with database_id, sql, and optional
parameters
-
- Returns:
- ExecuteSqlResponse with success status, results, or error
information
- """
- try:
- # Import inside method to avoid initialization issues
- from superset.mcp_service.sql_lab.sql_lab_utils import check_database_access
-
- # Check database access
- database = check_database_access(request.database_id)
-
- if self.use_command_mode:
- # Use full SQL Lab command for complex queries
- return self._execute_with_command(request, database)
- else:
- # Use simplified execution for basic queries
- return self._execute_simple(request, database)
-
- except Exception as e:
- # Handle errors and return error response with proper error types
- self._log_error(e, "executing SQL")
- return self._handle_execution_error(e)
-
- def _execute_simple(
- self,
- request: ExecuteSqlRequest,
- database: Any,
- ) -> ExecuteSqlResponse:
- """Execute SQL using simplified sql_lab_utils."""
- # Import inside method to avoid initialization issues
- from superset.mcp_service.sql_lab.sql_lab_utils import execute_sql_query
-
- results = execute_sql_query(
- database=database,
- sql=request.sql,
- schema=request.schema_name,
- limit=request.limit,
- timeout=request.timeout,
- parameters=request.parameters,
- )
-
- return ExecuteSqlResponse(
- success=True,
- rows=results.get("rows"),
- columns=results.get("columns"),
- row_count=results.get("row_count"),
- affected_rows=results.get("affected_rows"),
- query_id=None, # Not available in simple mode
- execution_time=results.get("execution_time"),
- error=None,
- error_type=None,
- )
-
- def _execute_with_command(
- self,
- request: ExecuteSqlRequest,
- database: Any,
- ) -> ExecuteSqlResponse:
- """Execute SQL using full SQL Lab command (not implemented yet)."""
- # This would use ExecuteSqlCommand for full SQL Lab features
- # Including query caching, async execution, complex parsing, etc.
- # For now, we'll fall back to simple execution
- self._log_info("Command mode not fully implemented, using simple mode")
- return self._execute_simple(request, database)
-
- # Future implementation would look like:
- # context = SqlJsonExecutionContext(
- # database_id=request.database_id,
- # sql=request.sql,
- # schema=request.schema_name,
- # limit=request.limit,
- # # ... other context fields
- # )
- #
- # command = ExecuteSqlCommand(
- # execution_context=context,
- # query_dao=QueryDAO(),
- # database_dao=DatabaseDAO(),
- # # ... other dependencies
- # )
- #
- # result = command.run()
- # return self._format_command_result(result)
-
- def _handle_execution_error(self, e: Exception) -> ExecuteSqlResponse:
- """Map exceptions to error responses."""
- error_type = self._get_error_type(e)
- return ExecuteSqlResponse(
- success=False,
- error=str(e),
- error_type=error_type,
- rows=None,
- columns=None,
- row_count=None,
- affected_rows=None,
- query_id=None,
- execution_time=None,
- )
-
- def _get_error_type(self, e: Exception) -> str:
- """Determine error type from exception."""
- # Import inside method to avoid initialization issues
- from superset.exceptions import (
- SupersetDisallowedSQLFunctionException,
- SupersetDMLNotAllowedException,
- SupersetErrorException,
- SupersetSecurityException,
- SupersetTimeoutException,
- )
-
- if isinstance(e, SupersetSecurityException):
- return "SECURITY_ERROR"
- elif isinstance(e, SupersetTimeoutException):
- return "TIMEOUT"
- elif isinstance(e, SupersetDMLNotAllowedException):
- return "DML_NOT_ALLOWED"
- elif isinstance(e, SupersetDisallowedSQLFunctionException):
- return "DISALLOWED_FUNCTION"
- elif isinstance(e, SupersetErrorException):
- return self._extract_superset_error_type(e)
- else:
- return "EXECUTION_ERROR"
-
- def _extract_superset_error_type(self, e: Exception) -> str:
- """Extract error type from SupersetErrorException."""
- if hasattr(e, "error") and hasattr(e.error, "error_type"):
- error_type_name = e.error.error_type.name
- # Map common error type patterns
- if "INVALID_PAYLOAD" in error_type_name:
- return "INVALID_PAYLOAD_FORMAT_ERROR"
- elif "DATABASE_NOT_FOUND" in error_type_name:
- return "DATABASE_NOT_FOUND_ERROR"
- elif "SECURITY" in error_type_name:
- return "SECURITY_ERROR"
- elif "TIMEOUT" in error_type_name:
- return "TIMEOUT"
- elif "DML_NOT_ALLOWED" in error_type_name:
- return "DML_NOT_ALLOWED"
- else:
- return error_type_name
- return "EXECUTION_ERROR"
-
- def _format_command_result(
- self, command_result: dict[str, Any]
- ) -> ExecuteSqlResponse:
- """Format ExecuteSqlCommand result into ExecuteSqlResponse."""
- # This would extract relevant fields from command result
- # Placeholder implementation for future use
- return ExecuteSqlResponse(
- success=command_result.get("success", False),
- rows=command_result.get("data"),
- columns=command_result.get("columns"),
- row_count=command_result.get("row_count"),
- affected_rows=command_result.get("affected_rows"),
- query_id=command_result.get("query_id"),
- execution_time=command_result.get("execution_time"),
- error=command_result.get("error"),
- error_type=command_result.get("error_type"),
- )
diff --git a/superset/mcp_service/sql_lab/schemas.py
b/superset/mcp_service/sql_lab/schemas.py
index fcfe7cb62b..571abd2a23 100644
--- a/superset/mcp_service/sql_lab/schemas.py
+++ b/superset/mcp_service/sql_lab/schemas.py
@@ -28,10 +28,14 @@ class ExecuteSqlRequest(BaseModel):
database_id: int = Field(
..., description="Database connection ID to execute query against"
)
- sql: str = Field(..., description="SQL query to execute")
+ sql: str = Field(
+ ...,
+ description="SQL query to execute (supports Jinja2 {{ var }} template
syntax)",
+ )
schema_name: str | None = Field(
None, description="Schema to use for query execution", alias="schema"
)
+ catalog: str | None = Field(None, description="Catalog name for query
execution")
limit: int = Field(
default=1000,
description="Maximum number of rows to return",
@@ -41,8 +45,21 @@ class ExecuteSqlRequest(BaseModel):
timeout: int = Field(
default=30, description="Query timeout in seconds", ge=1, le=300
)
- parameters: dict[str, Any] | None = Field(
- None, description="Parameters for query substitution"
+ template_params: dict[str, Any] | None = Field(
+ None, description="Jinja2 template parameters for SQL rendering"
+ )
+ dry_run: bool = Field(
+ default=False,
+ description="Return transformed SQL without executing (for debugging)",
+ )
+ force_refresh: bool = Field(
+ default=False,
+ description=(
+ "Bypass cache and re-execute query. "
+ "IMPORTANT: Only set to true when the user EXPLICITLY requests "
+ "fresh/updated data (e.g., 'refresh', 'get latest', 're-run'). "
+ "Default to false to reduce database load."
+ ),
)
@field_validator("sql")
@@ -61,11 +78,24 @@ class ColumnInfo(BaseModel):
is_nullable: bool | None = Field(None, description="Whether column allows
NULL")
+class StatementInfo(BaseModel):
+ """Information about a single SQL statement execution."""
+
+ original_sql: str = Field(..., description="Original SQL as submitted")
+ executed_sql: str = Field(
+ ..., description="SQL after transformations (RLS, mutations, limits)"
+ )
+ row_count: int = Field(..., description="Number of rows returned/affected")
+ execution_time_ms: float | None = Field(
+ None, description="Statement execution time in milliseconds"
+ )
+
+
class ExecuteSqlResponse(BaseModel):
"""Response schema for SQL execution results."""
success: bool = Field(..., description="Whether query executed
successfully")
- rows: Any | None = Field(
+ rows: list[dict[str, Any]] | None = Field(
None, description="Query result rows as list of dictionaries"
)
columns: list[ColumnInfo] | None = Field(
@@ -75,12 +105,14 @@ class ExecuteSqlResponse(BaseModel):
affected_rows: int | None = Field(
None, description="Number of rows affected (for DML queries)"
)
- query_id: str | None = Field(None, description="Query tracking ID")
execution_time: float | None = Field(
None, description="Query execution time in seconds"
)
error: str | None = Field(None, description="Error message if query
failed")
error_type: str | None = Field(None, description="Type of error if failed")
+ statements: list[StatementInfo] | None = Field(
+ None, description="Per-statement execution info (for multi-statement
queries)"
+ )
class OpenSqlLabRequest(BaseModel):
diff --git a/superset/mcp_service/sql_lab/sql_lab_utils.py
b/superset/mcp_service/sql_lab/sql_lab_utils.py
deleted file mode 100644
index 10c543c768..0000000000
--- a/superset/mcp_service/sql_lab/sql_lab_utils.py
+++ /dev/null
@@ -1,247 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""
-Utility functions for SQL Lab MCP tools.
-
-This module contains helper functions for SQL execution, validation,
-and database access that are shared across SQL Lab tools.
-"""
-
-import logging
-from typing import Any
-
-logger = logging.getLogger(__name__)
-
-
-def check_database_access(database_id: int) -> Any:
- """Check if user has access to the database."""
- # Import inside function to avoid initialization issues
- from superset import db, security_manager
- from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
- from superset.exceptions import SupersetErrorException, SupersetSecurityException
- from superset.models.core import Database
-
- # Use session query to ensure relationships are loaded
- database = db.session.query(Database).filter_by(id=database_id).first()
-
- if not database:
- raise SupersetErrorException(
- SupersetError(
- message=f"Database with ID {database_id} not found",
- error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
- level=ErrorLevel.ERROR,
- )
- )
-
- # Check database access permissions
- if not security_manager.can_access_database(database):
- raise SupersetSecurityException(
- SupersetError(
- message=f"Access denied to database {database.database_name}",
- error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR,
- level=ErrorLevel.ERROR,
- )
- )
-
- return database
-
-
-def validate_sql_query(sql: str, database: Any) -> None:
- """Validate SQL query for security and syntax."""
- # Import inside function to avoid initialization issues
- from flask import current_app as app
-
- from superset.exceptions import (
- SupersetDisallowedSQLFunctionException,
- SupersetDMLNotAllowedException,
- )
- from superset.sql.parse import SQLScript
-
- # Use SQLScript for proper SQL parsing
- script = SQLScript(sql, database.db_engine_spec.engine)
-
- # Check for DML operations if not allowed
- if script.has_mutation() and not database.allow_dml:
- raise SupersetDMLNotAllowedException()
-
- # Check for disallowed functions from config
- disallowed_functions = app.config.get("DISALLOWED_SQL_FUNCTIONS", {}).get(
- database.db_engine_spec.engine,
- set(),
- )
- if disallowed_functions and script.check_functions_present(disallowed_functions):
- raise SupersetDisallowedSQLFunctionException(disallowed_functions)
-
-
-def execute_sql_query(
- database: Any,
- sql: str,
- schema: str | None,
- limit: int,
- timeout: int,
- parameters: dict[str, Any] | None,
-) -> dict[str, Any]:
- """Execute SQL query and return results."""
- # Import inside function to avoid initialization issues
- from superset.utils.dates import now_as_float
-
- start_time = now_as_float()
-
- # Apply parameters and validate
- sql = _apply_parameters(sql, parameters)
- validate_sql_query(sql, database)
-
- # Apply limit for SELECT queries using SQLScript
- rendered_sql = _apply_limit(sql, limit, database)
-
- # Execute and get results
- results = _execute_query(database, rendered_sql, schema, limit)
-
- # Calculate execution time
- end_time = now_as_float()
- results["execution_time"] = end_time - start_time
-
- return results
-
-
-def _apply_parameters(sql: str, parameters: dict[str, Any] | None) -> str:
- """Apply parameters to SQL query."""
- # Import inside function to avoid initialization issues
- from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
- from superset.exceptions import SupersetErrorException
-
- if parameters:
- try:
- return sql.format(**parameters)
- except KeyError as e:
- raise SupersetErrorException(
- SupersetError(
- message=f"Missing parameter: {e}",
- error_type=SupersetErrorType.INVALID_PAYLOAD_FORMAT_ERROR,
- level=ErrorLevel.ERROR,
- )
- ) from e
- else:
- # Check if SQL contains placeholders when no parameters provided
- import re
-
- placeholders = re.findall(r"{(\w+)}", sql)
- if placeholders:
- raise SupersetErrorException(
- SupersetError(
- message=f"Missing parameter: {placeholders[0]}",
- error_type=SupersetErrorType.INVALID_PAYLOAD_FORMAT_ERROR,
- level=ErrorLevel.ERROR,
- )
- )
- return sql
-
-
-def _apply_limit(sql: str, limit: int, database: Any) -> str:
- """Apply limit to SELECT queries using SQLScript for proper parsing."""
- from superset.sql.parse import LimitMethod, SQLScript
-
- script = SQLScript(sql, database.db_engine_spec.engine)
-
- # Only apply limit to non-mutating (SELECT-like) queries
- if script.has_mutation():
- return sql
-
- # Apply limit to each statement in the script
- for statement in script.statements:
- # Only set limit if not already present
- if statement.get_limit_value() is None:
- statement.set_limit_value(limit, LimitMethod.FORCE_LIMIT)
-
- return script.format()
-
-
-def _execute_query(
- database: Any,
- sql: str,
- schema: str | None,
- limit: int,
-) -> dict[str, Any]:
- """Execute the query and process results."""
- # Import inside function to avoid initialization issues
- from superset.sql.parse import SQLScript
- from superset.utils.core import QuerySource
-
- results = {
- "rows": [],
- "columns": [],
- "row_count": 0,
- "affected_rows": None,
- "execution_time": 0.0,
- }
-
- try:
- # Execute query with timeout
- with database.get_raw_connection(
- catalog=None,
- schema=schema,
- source=QuerySource.SQL_LAB,
- ) as conn:
- cursor = conn.cursor()
- cursor.execute(sql)
-
- # Use SQLScript for proper SQL parsing to determine query type
- script = SQLScript(sql, database.db_engine_spec.engine)
- if script.has_mutation():
- _process_dml_results(cursor, conn, results)
- else:
- _process_select_results(cursor, results, limit)
-
- except Exception as e:
- logger.error("Error executing SQL: %s", e)
- raise
-
- return results
-
-
-def _process_select_results(cursor: Any, results: dict[str, Any], limit: int) -> None:
- """Process SELECT query results."""
- # Fetch results
- data = cursor.fetchmany(limit)
-
- # Get column metadata
- column_info = []
- if cursor.description:
- for col in cursor.description:
- column_info.append(
- {
- "name": col[0],
- "type": str(col[1]) if col[1] else "unknown",
- "is_nullable": col[6] if len(col) > 6 else None,
- }
- )
-
- # Set column info regardless of whether there's data
- if column_info:
- results["columns"] = column_info
-
- # Convert rows to dictionaries
- column_names = [col["name"] for col in column_info]
- results["rows"] = [dict(zip(column_names, row, strict=False)) for row
in data]
- results["row_count"] = len(data)
-
-
-def _process_dml_results(cursor: Any, conn: Any, results: dict[str, Any]) -> None:
- """Process DML query results."""
- results["affected_rows"] = cursor.rowcount
- conn.commit() # pylint: disable=consider-using-transaction
diff --git a/superset/mcp_service/sql_lab/tool/execute_sql.py
b/superset/mcp_service/sql_lab/tool/execute_sql.py
index cb93454084..64fe24a395 100644
--- a/superset/mcp_service/sql_lab/tool/execute_sql.py
+++ b/superset/mcp_service/sql_lab/tool/execute_sql.py
@@ -18,19 +18,26 @@
"""
Execute SQL MCP Tool
-Tool for executing SQL queries against databases with security validation
-and timeout protection.
+Tool for executing SQL queries against databases using the unified
+Database.execute() API with RLS, template rendering, and security validation.
"""
+from __future__ import annotations
+
import logging
+from typing import Any
from fastmcp import Context
+from superset_core.api.types import CacheOptions, QueryOptions, QueryResult, QueryStatus
from superset_core.mcp import tool
-from superset.mcp_service.sql_lab.execute_sql_core import ExecuteSqlCore
+from superset.errors import ErrorLevel, SupersetError, SupersetErrorType
+from superset.exceptions import SupersetErrorException, SupersetSecurityException
from superset.mcp_service.sql_lab.schemas import (
+ ColumnInfo,
ExecuteSqlRequest,
ExecuteSqlResponse,
+ StatementInfo,
)
from superset.mcp_service.utils.schema_utils import parse_request
@@ -40,10 +47,7 @@ logger = logging.getLogger(__name__)
@tool(tags=["mutate"])
@parse_request(ExecuteSqlRequest)
async def execute_sql(request: ExecuteSqlRequest, ctx: Context) ->
ExecuteSqlResponse:
- """Execute SQL query against database.
-
- Returns query results with security validation and timeout protection.
- """
+ """Execute SQL query against database using the unified Database.execute()
API."""
await ctx.info(
"Starting SQL execution: database_id=%s, timeout=%s, limit=%s,
schema=%s"
% (request.database_id, request.timeout, request.limit,
request.schema_name)
@@ -52,36 +56,78 @@ async def execute_sql(request: ExecuteSqlRequest, ctx:
Context) -> ExecuteSqlRes
# Log SQL query details (truncated for security)
sql_preview = request.sql[:100] + "..." if len(request.sql) > 100 else
request.sql
await ctx.debug(
- "SQL query details: sql_preview=%r, sql_length=%s, has_parameters=%s"
+ "SQL query details: sql_preview=%r, sql_length=%s,
has_template_params=%s"
% (
sql_preview,
len(request.sql),
- bool(request.parameters),
+ bool(request.template_params),
)
)
logger.info("Executing SQL query on database ID: %s", request.database_id)
try:
- # Use the ExecuteSqlCore to handle all the logic
- sql_tool = ExecuteSqlCore(use_command_mode=False, logger=logger)
- result = sql_tool.run_tool(request)
+ # Import inside function to avoid initialization issues
+ from superset import db, security_manager
+ from superset.models.core import Database
+
+ # 1. Get database and check access
+ database =
db.session.query(Database).filter_by(id=request.database_id).first()
+ if not database:
+ raise SupersetErrorException(
+ SupersetError(
+ message=f"Database with ID {request.database_id} not
found",
+ error_type=SupersetErrorType.DATABASE_NOT_FOUND_ERROR,
+ level=ErrorLevel.ERROR,
+ )
+ )
+
+ if not security_manager.can_access_database(database):
+ raise SupersetSecurityException(
+ SupersetError(
+ message=f"Access denied to database
{database.database_name}",
+
error_type=SupersetErrorType.DATABASE_SECURITY_ACCESS_ERROR,
+ level=ErrorLevel.ERROR,
+ )
+ )
+
+ # 2. Build QueryOptions
+ # Caching is enabled by default to reduce database load.
+ # force_refresh bypasses cache when user explicitly requests fresh
data.
+ cache_opts = CacheOptions(force_refresh=True) if request.force_refresh
else None
+ options = QueryOptions(
+ catalog=request.catalog,
+ schema=request.schema_name,
+ limit=request.limit,
+ timeout_seconds=request.timeout,
+ template_params=request.template_params,
+ dry_run=request.dry_run,
+ cache=cache_opts,
+ )
+
+ # 3. Execute query
+ result = database.execute(request.sql, options)
+
+ # 4. Convert to MCP response format
+ response = _convert_to_response(result)
# Log successful execution
- if hasattr(result, "data") and result.data:
- row_count = len(result.data) if isinstance(result.data, list) else
1
+ if response.success:
await ctx.info(
"SQL execution completed successfully: rows_returned=%s, "
- "query_duration_ms=%s"
+ "execution_time=%s"
% (
- row_count,
- getattr(result, "query_duration_ms", None),
+ response.row_count,
+ response.execution_time,
)
)
else:
- await ctx.info("SQL execution completed: status=no_data_returned")
+ await ctx.info(
+ "SQL execution failed: error=%s, error_type=%s"
+ % (response.error, response.error_type)
+ )
- return result
+ return response
except Exception as e:
await ctx.error(
@@ -92,3 +138,55 @@ async def execute_sql(request: ExecuteSqlRequest, ctx:
Context) -> ExecuteSqlRes
)
)
raise
+
+
+def _convert_to_response(result: QueryResult) -> ExecuteSqlResponse:
+ """Convert QueryResult to ExecuteSqlResponse."""
+ if result.status != QueryStatus.SUCCESS:
+ return ExecuteSqlResponse(
+ success=False,
+ error=result.error_message,
+ error_type=result.status.value,
+ )
+
+ # Build statement info list
+ statements = [
+ StatementInfo(
+ original_sql=stmt.original_sql,
+ executed_sql=stmt.executed_sql,
+ row_count=stmt.row_count,
+ execution_time_ms=stmt.execution_time_ms,
+ )
+ for stmt in result.statements
+ ]
+
+ # Get first statement's data for backward compatibility
+ first_stmt = result.statements[0] if result.statements else None
+ rows: list[dict[str, Any]] | None = None
+ columns: list[ColumnInfo] | None = None
+ row_count: int | None = None
+ affected_rows: int | None = None
+
+ if first_stmt and first_stmt.data is not None:
+ # SELECT query - convert DataFrame
+ df = first_stmt.data
+ rows = df.to_dict(orient="records")
+ columns = [ColumnInfo(name=col, type=str(df[col].dtype)) for col in
df.columns]
+ row_count = len(df)
+ elif first_stmt:
+ # DML query
+ affected_rows = first_stmt.row_count
+
+ return ExecuteSqlResponse(
+ success=True,
+ rows=rows,
+ columns=columns,
+ row_count=row_count,
+ affected_rows=affected_rows,
+ execution_time=(
+ result.total_execution_time_ms / 1000
+ if result.total_execution_time_ms is not None
+ else None
+ ),
+ statements=statements,
+ )
diff --git a/superset/mcp_service/utils/schema_utils.py
b/superset/mcp_service/utils/schema_utils.py
index 4e97abf405..cc1aa2392d 100644
--- a/superset/mcp_service/utils/schema_utils.py
+++ b/superset/mcp_service/utils/schema_utils.py
@@ -508,10 +508,23 @@ def parse_request(
new_params = []
for name, param in orig_sig.parameters.items():
# Skip ctx parameter - FastMCP tools don't expose it to clients
- if param.annotation is FMContext or (
- hasattr(param.annotation, "__name__")
- and param.annotation.__name__ == "Context"
- ):
+ # Check for Context type, forward reference string, or parameter
named 'ctx'
+ is_context = (
+ param.annotation is FMContext
+ or (
+ hasattr(param.annotation, "__name__")
+ and param.annotation.__name__ == "Context"
+ )
+ or (
+ isinstance(param.annotation, str)
+ and (
+ param.annotation == "Context"
+ or param.annotation.endswith(".Context")
+ )
+ )
+ or name == "ctx" # Fallback: skip any param named 'ctx'
+ )
+ if is_context:
continue
if name == "request":
new_params.append(param.replace(annotation=str |
request_class))
diff --git a/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
b/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
deleted file mode 100644
index 9a08b72a91..0000000000
--- a/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
+++ /dev/null
@@ -1,137 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements. See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership. The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License. You may obtain a copy of the License at
-#
-# http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied. See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-"""Unit tests for MCP SQL Lab utility functions."""
-
-from unittest.mock import MagicMock
-
-import pytest
-
-from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit
-from superset.sql.parse import SQLScript
-
-
-class TestSQLScriptMutationDetection:
- """Tests for SQLScript.has_mutation() used in query type detection."""
-
- def test_simple_select_no_mutation(self):
- """Test simple SELECT query is not a mutation."""
- script = SQLScript("SELECT * FROM table1", "sqlite")
- assert script.has_mutation() is False
-
- def test_cte_query_no_mutation(self):
- """Test CTE (WITH clause) query is not a mutation."""
- cte_sql = """
- WITH cte_name AS (
- SELECT * FROM table1
- )
- SELECT * FROM cte_name
- """
- script = SQLScript(cte_sql, "sqlite")
- assert script.has_mutation() is False
-
- def test_recursive_cte_no_mutation(self):
- """Test recursive CTE is not a mutation."""
- recursive_cte = """
- WITH RECURSIVE cte AS (
- SELECT 1 AS n
- UNION ALL
- SELECT n + 1 FROM cte WHERE n < 10
- )
- SELECT n FROM cte
- """
- script = SQLScript(recursive_cte, "sqlite")
- assert script.has_mutation() is False
-
- def test_multiple_ctes_no_mutation(self):
- """Test query with multiple CTEs is not a mutation."""
- multiple_ctes = """
- WITH
- cte1 AS (SELECT 1 as a),
- cte2 AS (SELECT 2 as b)
- SELECT * FROM cte1, cte2
- """
- script = SQLScript(multiple_ctes, "sqlite")
- assert script.has_mutation() is False
-
- def test_insert_is_mutation(self):
- """Test INSERT query is a mutation."""
- script = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite")
- assert script.has_mutation() is True
-
- def test_update_is_mutation(self):
- """Test UPDATE query is a mutation."""
- script = SQLScript("UPDATE table1 SET col = 1", "sqlite")
- assert script.has_mutation() is True
-
- def test_delete_is_mutation(self):
- """Test DELETE query is a mutation."""
- script = SQLScript("DELETE FROM table1", "sqlite")
- assert script.has_mutation() is True
-
- def test_create_is_mutation(self):
- """Test CREATE query is a mutation."""
- script = SQLScript("CREATE TABLE table1 (id INT)", "sqlite")
- assert script.has_mutation() is True
-
-
-class TestApplyLimit:
- """Tests for _apply_limit function using SQLScript."""
-
- @pytest.fixture
- def mock_database(self):
- """Create a mock database with sqlite engine spec."""
- db = MagicMock()
- db.db_engine_spec.engine = "sqlite"
- return db
-
- def test_adds_limit_to_select(self, mock_database):
- """Test LIMIT is added to SELECT query."""
- result = _apply_limit("SELECT * FROM table1", 100, mock_database)
- assert "LIMIT 100" in result
-
- def test_adds_limit_to_cte(self, mock_database):
- """Test LIMIT is added to CTE query."""
- cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte"
- result = _apply_limit(cte_sql, 50, mock_database)
- assert "LIMIT 50" in result
-
- def test_preserves_existing_limit(self, mock_database):
- """Test existing LIMIT is not modified."""
- sql = "SELECT * FROM table1 LIMIT 10"
- result = _apply_limit(sql, 100, mock_database)
- assert "LIMIT 10" in result
- assert "LIMIT 100" not in result
-
- def test_preserves_existing_limit_in_cte(self, mock_database):
- """Test existing LIMIT in CTE query is not modified."""
- cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5"
- result = _apply_limit(cte_sql, 100, mock_database)
- assert "LIMIT 5" in result
- assert "LIMIT 100" not in result
-
- def test_no_limit_on_insert(self, mock_database):
- """Test LIMIT is not added to INSERT query."""
- sql = "INSERT INTO table1 VALUES (1)"
- result = _apply_limit(sql, 100, mock_database)
- assert result == sql
-
- def test_no_limit_on_update(self, mock_database):
- """Test LIMIT is not added to UPDATE query."""
- sql = "UPDATE table1 SET col = 1"
- result = _apply_limit(sql, 100, mock_database)
- assert result == sql
diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
index 5b3012462d..86b9af9188 100644
--- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
+++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
@@ -17,14 +17,20 @@
"""
Unit tests for execute_sql MCP tool
+
+These tests mock Database.execute() to test the MCP tool's parameter mapping
+and response conversion logic.
"""
import logging
+from typing import Any
from unittest.mock import MagicMock, Mock, patch
+import pandas as pd
import pytest
from fastmcp import Client
from fastmcp.exceptions import ToolError
+from superset_core.api.types import QueryResult, QueryStatus, StatementResult
from superset.mcp_service.app import mcp
@@ -48,6 +54,71 @@ def mock_auth():
yield mock_get_user
+def _create_select_result(
+ rows: list[dict[str, Any]],
+ columns: list[str],
+ original_sql: str = "SELECT * FROM users",
+ executed_sql: str | None = None,
+ execution_time_ms: float = 10.0,
+) -> QueryResult:
+ """Create a mock QueryResult for SELECT queries."""
+ df = pd.DataFrame(rows) if rows else pd.DataFrame(columns=columns)
+ return QueryResult(
+ status=QueryStatus.SUCCESS,
+ statements=[
+ StatementResult(
+ original_sql=original_sql,
+ executed_sql=executed_sql or original_sql,
+ data=df,
+ row_count=len(df),
+ execution_time_ms=execution_time_ms,
+ )
+ ],
+ query_id=None,
+ total_execution_time_ms=execution_time_ms,
+ is_cached=False,
+ )
+
+
+def _create_dml_result(
+ affected_rows: int,
+ original_sql: str = "UPDATE users SET active = true",
+ executed_sql: str | None = None,
+ execution_time_ms: float = 5.0,
+) -> QueryResult:
+ """Create a mock QueryResult for DML queries."""
+ return QueryResult(
+ status=QueryStatus.SUCCESS,
+ statements=[
+ StatementResult(
+ original_sql=original_sql,
+ executed_sql=executed_sql or original_sql,
+ data=None,
+ row_count=affected_rows,
+ execution_time_ms=execution_time_ms,
+ )
+ ],
+ query_id=None,
+ total_execution_time_ms=execution_time_ms,
+ is_cached=False,
+ )
+
+
+def _create_error_result(
+ error_message: str,
+ status: QueryStatus = QueryStatus.FAILED,
+) -> QueryResult:
+ """Create a mock QueryResult for failed queries."""
+ return QueryResult(
+ status=status,
+ statements=[],
+ query_id=None,
+ total_execution_time_ms=0,
+ is_cached=False,
+ error_message=error_message,
+ )
+
+
def _mock_database(
id: int = 1,
database_name: str = "test_db",
@@ -58,26 +129,6 @@ def _mock_database(
database.id = id
database.database_name = database_name
database.allow_dml = allow_dml
-
- # Mock raw connection context manager
- mock_cursor = Mock()
- mock_cursor.description = [
- ("id", "INTEGER", None, None, None, None, False),
- ("name", "VARCHAR", None, None, None, None, True),
- ]
- mock_cursor.fetchmany.return_value = [(1, "test_name")]
- mock_cursor.rowcount = 1
-
- mock_conn = Mock()
- mock_conn.cursor.return_value = mock_cursor
- mock_conn.commit = Mock()
-
- mock_context = MagicMock()
- mock_context.__enter__.return_value = mock_conn
- mock_context.__exit__.return_value = None
-
- database.get_raw_connection.return_value = mock_context
-
return database
@@ -91,8 +142,11 @@ class TestExecuteSql:
self, mock_db, mock_security_manager, mcp_server
):
"""Test basic SELECT query execution."""
- # Setup mocks
mock_database = _mock_database()
+ mock_database.execute.return_value = _create_select_result(
+ rows=[{"id": 1, "name": "test_name"}],
+ columns=["id", "name"],
+ )
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -107,25 +161,41 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
- assert result.data.success is True
- assert result.data.error is None
- assert result.data.row_count == 1
- assert len(result.data.rows) == 1
- assert result.data.rows[0]["id"] == 1
- assert result.data.rows[0]["name"] == "test_name"
- assert len(result.data.columns) == 2
- assert result.data.columns[0].name == "id"
- assert result.data.columns[0].type == "INTEGER"
- assert result.data.execution_time > 0
+ # Use structured_content for dictionary access (Pydantic model
responses)
+ data = result.structured_content
+ assert data["success"] is True
+ assert data["error"] is None
+ assert data["row_count"] == 1
+ assert len(data["rows"]) == 1
+ assert data["rows"][0]["id"] == 1
+ assert data["rows"][0]["name"] == "test_name"
+ assert len(data["columns"]) == 2
+ assert data["columns"][0]["name"] == "id"
+ assert data["execution_time"] > 0
+
+ # Verify Database.execute() was called with correct QueryOptions
+ mock_database.execute.assert_called_once()
+ call_args = mock_database.execute.call_args
+ assert call_args[0][0] == request["sql"]
+ options = call_args[0][1]
+ assert options.limit == 10
+ # Caching is enabled by default (force_refresh=False means
cache=None)
+ assert options.cache is None
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
- async def test_execute_sql_with_parameters(
+ async def test_execute_sql_with_template_params(
self, mock_db, mock_security_manager, mcp_server
):
- """Test SQL execution with parameter substitution."""
+ """Test SQL execution with Jinja2 template parameters."""
mock_database = _mock_database()
+ mock_database.execute.return_value = _create_select_result(
+ rows=[{"order_id": 1, "status": "active"}],
+ columns=["order_id", "status"],
+ original_sql="SELECT * FROM {{ table }} WHERE status = '{{ status
}}'",
+ executed_sql="SELECT * FROM orders WHERE status = 'active'",
+ )
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -133,33 +203,42 @@ class TestExecuteSql:
request = {
"database_id": 1,
- "sql": "SELECT * FROM {table} WHERE status = '{status}' LIMIT
{limit}",
- "parameters": {"table": "orders", "status": "active", "limit":
"5"},
+ "sql": "SELECT * FROM {{ table }} WHERE status = '{{ status }}'",
+ "template_params": {"table": "orders", "status": "active"},
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
- assert result.data.success is True
- assert result.data.error is None
- # Verify parameter substitution happened
- mock_database.get_raw_connection.assert_called_once()
- cursor = ( # fmt: skip
-
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
- )
- # Check that the SQL was formatted with parameters
- executed_sql = cursor.execute.call_args[0][0]
- assert "orders" in executed_sql
- assert "active" in executed_sql
+ # Use structured_content for dictionary access (Pydantic model
responses)
+ data = result.structured_content
+ assert data["success"] is True
+ assert data["error"] is None
+
+ # Verify template_params were passed to QueryOptions
+ call_args = mock_database.execute.call_args
+ options = call_args[0][1]
+ assert options.template_params == {"table": "orders", "status":
"active"}
+
+ # Verify statements contain both original and executed SQL
+ assert data["statements"] is not None
+ assert len(data["statements"]) == 1
+ assert "{{ table }}" in data["statements"][0]["original_sql"]
+ assert "orders" in data["statements"][0]["executed_sql"]
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_database_not_found(
- self, mock_db, mock_security_manager, mcp_server
+ self,
+ mock_db,
+ mock_security_manager, # noqa: PT019
+ mcp_server,
):
"""Test error when database is not found."""
+ # mock_security_manager is patched but not used (error happens first)
+ del mock_security_manager # Silence unused variable warning
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
None
)
@@ -171,15 +250,10 @@ class TestExecuteSql:
}
async with Client(mcp_server) as client:
- result = await client.call_tool("execute_sql", {"request":
request})
-
- assert result.data.success is False
- assert result.data.error is not None
- assert "Database with ID 999 not found" in result.data.error
- assert result.data.error_type == "DATABASE_NOT_FOUND_ERROR"
- assert result.data.rows is None
+ with pytest.raises(ToolError, match="Database with ID 999 not
found"):
+ await client.call_tool("execute_sql", {"request": request})
- @patch("superset.security_manager")
+ @patch("superset.security_manager", new_callable=MagicMock)
@patch("superset.db")
@pytest.mark.asyncio
async def test_execute_sql_access_denied(
@@ -190,10 +264,7 @@ class TestExecuteSql:
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
- # Use Mock instead of AsyncMock for synchronous call
- from unittest.mock import Mock
-
- mock_security_manager.can_access_database = Mock(return_value=False)
+ mock_security_manager.can_access_database.return_value = False
request = {
"database_id": 1,
@@ -202,58 +273,27 @@ class TestExecuteSql:
}
async with Client(mcp_server) as client:
- result = await client.call_tool("execute_sql", {"request":
request})
-
- assert result.data.success is False
- assert result.data.error is not None
- assert "Access denied to database" in result.data.error
- assert result.data.error_type == "SECURITY_ERROR"
-
- @patch("superset.security_manager")
- @patch("superset.db")
- @pytest.mark.asyncio
- async def test_execute_sql_dml_not_allowed(
- self, mock_db, mock_security_manager, mcp_server
- ):
- """Test error when DML operations are not allowed."""
- mock_database = _mock_database(allow_dml=False)
-
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
- mock_database
- )
- mock_security_manager.can_access_database.return_value = True
-
- request = {
- "database_id": 1,
- "sql": "UPDATE users SET name = 'test' WHERE id = 1",
- "limit": 1,
- }
-
- async with Client(mcp_server) as client:
- result = await client.call_tool("execute_sql", {"request":
request})
-
- assert result.data.success is False
- assert result.data.error is not None
- assert result.data.error_type == "DML_NOT_ALLOWED"
+ with pytest.raises(ToolError, match="Access denied to database"):
+ await client.call_tool("execute_sql", {"request": request})
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
- async def test_execute_sql_dml_allowed(
+ async def test_execute_sql_dml_success(
self, mock_db, mock_security_manager, mcp_server
):
- """Test successful DML execution when allowed."""
+ """Test successful DML execution."""
mock_database = _mock_database(allow_dml=True)
+ dml_sql = "UPDATE users SET active = true WHERE last_login >
'2024-01-01'"
+ mock_database.execute.return_value = _create_dml_result(
+ affected_rows=3,
+ original_sql=dml_sql,
+ )
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
- # Mock cursor for DML operation
- cursor = ( # fmt: skip
-
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
- )
- cursor.rowcount = 3 # 3 rows affected
-
request = {
"database_id": 1,
"sql": "UPDATE users SET active = true WHERE last_login >
'2024-01-01'",
@@ -263,15 +303,13 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
- assert result.data.success is True
- assert result.data.error is None
- assert result.data.affected_rows == 3
- assert result.data.rows == [] # Empty rows for DML
- assert result.data.row_count == 0
- # Verify commit was called
- (
-
mock_database.get_raw_connection.return_value.__enter__.return_value.commit.assert_called_once()
- )
+ # Use structured_content for dictionary access (Pydantic model
responses)
+ data = result.structured_content
+ assert data["success"] is True
+ assert data["error"] is None
+ assert data["affected_rows"] == 3
+ assert data["rows"] is None # None for DML
+ assert data["row_count"] is None
@patch("superset.security_manager")
@patch("superset.db")
@@ -281,17 +319,15 @@ class TestExecuteSql:
):
"""Test query that returns no results."""
mock_database = _mock_database()
+ mock_database.execute.return_value = _create_select_result(
+ rows=[],
+ columns=["id", "name"],
+ )
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
- # Mock empty results
- cursor = ( # fmt: skip
-
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
- )
- cursor.fetchmany.return_value = []
-
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 999999",
@@ -301,20 +337,26 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
- assert result.data.success is True
- assert result.data.error is None
- assert result.data.row_count == 0
- assert len(result.data.rows) == 0
- assert len(result.data.columns) == 2 # Column metadata still
returned
+ # Use structured_content for dictionary access (Pydantic model
responses)
+ data = result.structured_content
+ assert data["success"] is True
+ assert data["error"] is None
+ assert data["row_count"] == 0
+ assert len(data["rows"]) == 0
+ assert len(data["columns"]) == 2 # Column metadata still returned
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
- async def test_execute_sql_missing_parameter(
+ async def test_execute_sql_with_schema_and_catalog(
self, mock_db, mock_security_manager, mcp_server
):
- """Test error when required parameter is missing."""
+ """Test SQL execution with schema and catalog specification."""
mock_database = _mock_database()
+ mock_database.execute.return_value = _create_select_result(
+ rows=[{"total": 100}],
+ columns=["total"],
+ )
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -322,56 +364,49 @@ class TestExecuteSql:
request = {
"database_id": 1,
- "sql": "SELECT * FROM {table_name} WHERE id = {user_id}",
- "parameters": {"table_name": "users"}, # Missing user_id
+ "sql": "SELECT COUNT(*) as total FROM orders",
+ "schema": "sales",
+ "catalog": "prod_catalog",
"limit": 1,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
- assert result.data.success is False
- assert result.data.error is not None
- assert "user_id" in result.data.error # Error contains parameter
name
- assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
-
- @patch("superset.security_manager")
- @patch("superset.db")
- @pytest.mark.asyncio
- async def test_execute_sql_empty_parameters_with_placeholders(
- self, mock_db, mock_security_manager, mcp_server
- ):
- """Test error when empty parameters dict is provided but SQL has
- placeholders."""
- mock_database = _mock_database()
-
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
- mock_database
- )
- mock_security_manager.can_access_database.return_value = True
-
- request = {
- "database_id": 1,
- "sql": "SELECT * FROM {table_name} LIMIT 5",
- "parameters": {}, # Empty dict but SQL has {table_name}
- "limit": 5,
- }
-
- async with Client(mcp_server) as client:
- result = await client.call_tool("execute_sql", {"request":
request})
+ # Use structured_content for dictionary access (Pydantic model
responses)
+ data = result.structured_content
+ assert data["success"] is True
- assert result.data.success is False
- assert result.data.error is not None
- assert "Missing parameter: table_name" in result.data.error
- assert result.data.error_type == "INVALID_PAYLOAD_FORMAT_ERROR"
+ # Verify schema and catalog were passed to QueryOptions
+ call_args = mock_database.execute.call_args
+ options = call_args[0][1]
+ assert options.schema == "sales"
+ assert options.catalog == "prod_catalog"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
- async def test_execute_sql_with_schema(
+ async def test_execute_sql_dry_run(
self, mock_db, mock_security_manager, mcp_server
):
- """Test SQL execution with schema specification."""
+ """Test dry_run mode returns transformed SQL without executing."""
mock_database = _mock_database()
+ executed_sql = "SELECT * FROM users WHERE user_id IN (SELECT ...)
LIMIT 100"
+ mock_database.execute.return_value = QueryResult(
+ status=QueryStatus.SUCCESS,
+ statements=[
+ StatementResult(
+ original_sql="SELECT * FROM {{ table }}",
+ executed_sql=executed_sql,
+ data=None,
+ row_count=0,
+ execution_time_ms=0,
+ )
+ ],
+ query_id=None,
+ total_execution_time_ms=0,
+ is_cached=False,
+ )
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -379,30 +414,44 @@ class TestExecuteSql:
request = {
"database_id": 1,
- "sql": "SELECT COUNT(*) as total FROM orders",
- "schema": "sales",
- "limit": 1,
+ "sql": "SELECT * FROM {{ table }}",
+ "template_params": {"table": "users"},
+ "dry_run": True,
+ "limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
- assert result.data.success is True
- assert result.data.error is None
- # Verify schema was passed to get_raw_connection
- # Verify schema was passed
- call_args = mock_database.get_raw_connection.call_args
- assert call_args[1]["schema"] == "sales"
- assert call_args[1]["catalog"] is None
+ # Use structured_content for dictionary access (Pydantic model
responses)
+ data = result.structured_content
+ assert data["success"] is True
+ # Verify dry_run was passed
+ call_args = mock_database.execute.call_args
+ options = call_args[0][1]
+ assert options.dry_run is True
+
+ # Verify statements show transformed SQL
+ assert data["statements"] is not None
+ assert "{{ table }}" in data["statements"][0]["original_sql"]
+ assert "users" in data["statements"][0]["executed_sql"]
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
- async def test_execute_sql_limit_enforcement(
+ async def test_execute_sql_timeout_error(
self, mock_db, mock_security_manager, mcp_server
):
- """Test that LIMIT is added to SELECT queries without one."""
+ """Test that SQL injection attempts are handled safely.
+
+ SQLScript detects the DROP TABLE as a mutation and blocks it
+ before execution when DML is not allowed on the database.
+ """
mock_database = _mock_database()
+ mock_database.execute.return_value = _create_error_result(
+ error_message="Query exceeded the timeout limit",
+ status=QueryStatus.TIMED_OUT,
+ )
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -410,33 +459,50 @@ class TestExecuteSql:
request = {
"database_id": 1,
- "sql": "SELECT * FROM users", # No LIMIT
- "limit": 50,
+ "sql": "SELECT * FROM large_table",
+ "timeout": 5,
+ "limit": 100,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
- assert result.data.success is True
- # Verify LIMIT was added
- cursor = ( # fmt: skip
-
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
- )
- executed_sql = cursor.execute.call_args[0][0]
- assert "LIMIT 50" in executed_sql
+ # Use structured_content for dictionary access (Pydantic model
responses)
+ data = result.structured_content
+ assert data["success"] is False
+ assert data["error"] == "Query exceeded the timeout limit"
+ assert data["error_type"] == "timed_out"
@patch("superset.security_manager")
@patch("superset.db")
@pytest.mark.asyncio
- async def test_execute_sql_sql_injection_prevention(
+ async def test_execute_sql_multi_statement(
self, mock_db, mock_security_manager, mcp_server
):
- """Test that SQL injection attempts are handled safely.
-
- SQLScript detects the DROP TABLE as a mutation and blocks it
- before execution when DML is not allowed on the database.
- """
+ """Test multi-statement SQL execution."""
mock_database = _mock_database()
+ mock_database.execute.return_value = QueryResult(
+ status=QueryStatus.SUCCESS,
+ statements=[
+ StatementResult(
+ original_sql="SELECT 1 as a",
+ executed_sql="SELECT 1 as a",
+ data=pd.DataFrame([{"a": 1}]),
+ row_count=1,
+ execution_time_ms=5.0,
+ ),
+ StatementResult(
+ original_sql="SELECT 2 as b",
+ executed_sql="SELECT 2 as b",
+ data=pd.DataFrame([{"b": 2}]),
+ row_count=1,
+ execution_time_ms=3.0,
+ ),
+ ],
+ query_id=None,
+ total_execution_time_ms=8.0,
+ is_cached=False,
+ )
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
@@ -444,19 +510,25 @@ class TestExecuteSql:
request = {
"database_id": 1,
- "sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
+ "sql": "SELECT 1 as a; SELECT 2 as b;",
"limit": 10,
}
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
- # SQLScript correctly detects DROP TABLE as a mutation
- # and blocks it before execution (improved security)
- assert result.data.success is False
- assert result.data.error is not None
- assert "DML" in result.data.error or "mutates" in result.data.error
- assert result.data.error_type == "DML_NOT_ALLOWED"
+ # Use structured_content for dictionary access (Pydantic model
responses)
+ data = result.structured_content
+ assert data["success"] is True
+ # Statements should contain both
+ assert data["statements"] is not None
+ assert len(data["statements"]) == 2
+ assert data["statements"][0]["original_sql"] == "SELECT 1 as a"
+ assert data["statements"][1]["original_sql"] == "SELECT 2 as b"
+
+ # rows/columns should be from first statement for backward compat
+ assert data["rows"] == [{"a": 1}]
+ assert data["row_count"] == 1
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):
@@ -495,3 +567,39 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
with pytest.raises(ToolError, match="less than or equal to 10000"):
await client.call_tool("execute_sql", {"request": request})
+
+ @patch("superset.security_manager")
+ @patch("superset.db")
+ @pytest.mark.asyncio
+ async def test_execute_sql_force_refresh(
+ self, mock_db, mock_security_manager, mcp_server
+ ):
+ """Test force_refresh bypasses cache."""
+ mock_database = _mock_database()
+ mock_database.execute.return_value = _create_select_result(
+ rows=[{"id": 1}],
+ columns=["id"],
+ )
+
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
+ mock_database
+ )
+ mock_security_manager.can_access_database.return_value = True
+
+ request = {
+ "database_id": 1,
+ "sql": "SELECT id FROM users",
+ "limit": 10,
+ "force_refresh": True,
+ }
+
+ async with Client(mcp_server) as client:
+ result = await client.call_tool("execute_sql", {"request":
request})
+
+ data = result.structured_content
+ assert data["success"] is True
+
+ # Verify force_refresh was passed to CacheOptions
+ call_args = mock_database.execute.call_args
+ options = call_args[0][1]
+ assert options.cache is not None
+ assert options.cache.force_refresh is True