This is an automated email from the ASF dual-hosted git repository.
rusackas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git
The following commit(s) were added to refs/heads/master by this push:
new e3e6b0e18b fix(mcp): use SQLScript for all SQL parsing in execute_sql (#36599)
e3e6b0e18b is described below
commit e3e6b0e18bddf2cf523131c5df98c567e6e3cda7
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Sat Dec 20 23:52:56 2025 -0500
fix(mcp): use SQLScript for all SQL parsing in execute_sql (#36599)
---
superset/mcp_service/sql_lab/sql_lab_utils.py | 64 +++++-----
.../mcp_service/sql_lab/test_sql_lab_utils.py | 137 +++++++++++++++++++++
.../mcp_service/sql_lab/tool/test_execute_sql.py | 18 +--
3 files changed, 180 insertions(+), 39 deletions(-)
diff --git a/superset/mcp_service/sql_lab/sql_lab_utils.py b/superset/mcp_service/sql_lab/sql_lab_utils.py
index 6844e26a49..10c543c768 100644
--- a/superset/mcp_service/sql_lab/sql_lab_utils.py
+++ b/superset/mcp_service/sql_lab/sql_lab_utils.py
@@ -70,26 +70,22 @@ def validate_sql_query(sql: str, database: Any) -> None:
SupersetDisallowedSQLFunctionException,
SupersetDMLNotAllowedException,
)
+ from superset.sql.parse import SQLScript
- # Simplified validation without complex parsing
- sql_upper = sql.upper().strip()
+ # Use SQLScript for proper SQL parsing
+ script = SQLScript(sql, database.db_engine_spec.engine)
# Check for DML operations if not allowed
- dml_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER",
"TRUNCATE"]
- if any(sql_upper.startswith(keyword) for keyword in dml_keywords):
- if not database.allow_dml:
- raise SupersetDMLNotAllowedException()
+ if script.has_mutation() and not database.allow_dml:
+ raise SupersetDMLNotAllowedException()
# Check for disallowed functions from config
disallowed_functions = app.config.get("DISALLOWED_SQL_FUNCTIONS", {}).get(
- "sqlite",
- set(), # Default to sqlite for now
+ database.db_engine_spec.engine,
+ set(),
)
- if disallowed_functions:
- sql_lower = sql.lower()
- for func in disallowed_functions:
- if f"{func.lower()}(" in sql_lower:
- raise
SupersetDisallowedSQLFunctionException(disallowed_functions)
+ if disallowed_functions and
script.check_functions_present(disallowed_functions):
+ raise SupersetDisallowedSQLFunctionException(disallowed_functions)
def execute_sql_query(
@@ -110,8 +106,8 @@ def execute_sql_query(
sql = _apply_parameters(sql, parameters)
validate_sql_query(sql, database)
- # Apply limit for SELECT queries
- rendered_sql = _apply_limit(sql, limit)
+ # Apply limit for SELECT queries using SQLScript
+ rendered_sql = _apply_limit(sql, limit, database)
# Execute and get results
results = _execute_query(database, rendered_sql, schema, limit)
@@ -156,12 +152,23 @@ def _apply_parameters(sql: str, parameters: dict[str, Any] | None) -> str:
return sql
-def _apply_limit(sql: str, limit: int) -> str:
- """Apply limit to SELECT queries if not already present."""
- sql_lower = sql.lower().strip()
- if sql_lower.startswith("select") and "limit" not in sql_lower:
- return f"{sql.rstrip().rstrip(';')} LIMIT {limit}"
- return sql
+def _apply_limit(sql: str, limit: int, database: Any) -> str:
+ """Apply limit to SELECT queries using SQLScript for proper parsing."""
+ from superset.sql.parse import LimitMethod, SQLScript
+
+ script = SQLScript(sql, database.db_engine_spec.engine)
+
+ # Only apply limit to non-mutating (SELECT-like) queries
+ if script.has_mutation():
+ return sql
+
+ # Apply limit to each statement in the script
+ for statement in script.statements:
+ # Only set limit if not already present
+ if statement.get_limit_value() is None:
+ statement.set_limit_value(limit, LimitMethod.FORCE_LIMIT)
+
+ return script.format()
def _execute_query(
@@ -172,6 +179,7 @@ def _execute_query(
) -> dict[str, Any]:
"""Execute the query and process results."""
# Import inside function to avoid initialization issues
+ from superset.sql.parse import SQLScript
from superset.utils.core import QuerySource
results = {
@@ -192,11 +200,12 @@ def _execute_query(
cursor = conn.cursor()
cursor.execute(sql)
- # Process results based on query type
- if _is_select_query(sql):
- _process_select_results(cursor, results, limit)
- else:
+ # Use SQLScript for proper SQL parsing to determine query type
+ script = SQLScript(sql, database.db_engine_spec.engine)
+ if script.has_mutation():
_process_dml_results(cursor, conn, results)
+ else:
+ _process_select_results(cursor, results, limit)
except Exception as e:
logger.error("Error executing SQL: %s", e)
@@ -205,11 +214,6 @@ def _execute_query(
return results
-def _is_select_query(sql: str) -> bool:
- """Check if SQL is a SELECT query."""
- return sql.lower().strip().startswith("select")
-
-
def _process_select_results(cursor: Any, results: dict[str, Any], limit: int)
-> None:
"""Process SELECT query results."""
# Fetch results
diff --git a/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py b/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
new file mode 100644
index 0000000000..9a08b72a91
--- /dev/null
+++ b/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for MCP SQL Lab utility functions."""
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit
+from superset.sql.parse import SQLScript
+
+
+class TestSQLScriptMutationDetection:
+ """Tests for SQLScript.has_mutation() used in query type detection."""
+
+ def test_simple_select_no_mutation(self):
+ """Test simple SELECT query is not a mutation."""
+ script = SQLScript("SELECT * FROM table1", "sqlite")
+ assert script.has_mutation() is False
+
+ def test_cte_query_no_mutation(self):
+ """Test CTE (WITH clause) query is not a mutation."""
+ cte_sql = """
+ WITH cte_name AS (
+ SELECT * FROM table1
+ )
+ SELECT * FROM cte_name
+ """
+ script = SQLScript(cte_sql, "sqlite")
+ assert script.has_mutation() is False
+
+ def test_recursive_cte_no_mutation(self):
+ """Test recursive CTE is not a mutation."""
+ recursive_cte = """
+ WITH RECURSIVE cte AS (
+ SELECT 1 AS n
+ UNION ALL
+ SELECT n + 1 FROM cte WHERE n < 10
+ )
+ SELECT n FROM cte
+ """
+ script = SQLScript(recursive_cte, "sqlite")
+ assert script.has_mutation() is False
+
+ def test_multiple_ctes_no_mutation(self):
+ """Test query with multiple CTEs is not a mutation."""
+ multiple_ctes = """
+ WITH
+ cte1 AS (SELECT 1 as a),
+ cte2 AS (SELECT 2 as b)
+ SELECT * FROM cte1, cte2
+ """
+ script = SQLScript(multiple_ctes, "sqlite")
+ assert script.has_mutation() is False
+
+ def test_insert_is_mutation(self):
+ """Test INSERT query is a mutation."""
+ script = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite")
+ assert script.has_mutation() is True
+
+ def test_update_is_mutation(self):
+ """Test UPDATE query is a mutation."""
+ script = SQLScript("UPDATE table1 SET col = 1", "sqlite")
+ assert script.has_mutation() is True
+
+ def test_delete_is_mutation(self):
+ """Test DELETE query is a mutation."""
+ script = SQLScript("DELETE FROM table1", "sqlite")
+ assert script.has_mutation() is True
+
+ def test_create_is_mutation(self):
+ """Test CREATE query is a mutation."""
+ script = SQLScript("CREATE TABLE table1 (id INT)", "sqlite")
+ assert script.has_mutation() is True
+
+
+class TestApplyLimit:
+ """Tests for _apply_limit function using SQLScript."""
+
+ @pytest.fixture
+ def mock_database(self):
+ """Create a mock database with sqlite engine spec."""
+ db = MagicMock()
+ db.db_engine_spec.engine = "sqlite"
+ return db
+
+ def test_adds_limit_to_select(self, mock_database):
+ """Test LIMIT is added to SELECT query."""
+ result = _apply_limit("SELECT * FROM table1", 100, mock_database)
+ assert "LIMIT 100" in result
+
+ def test_adds_limit_to_cte(self, mock_database):
+ """Test LIMIT is added to CTE query."""
+ cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte"
+ result = _apply_limit(cte_sql, 50, mock_database)
+ assert "LIMIT 50" in result
+
+ def test_preserves_existing_limit(self, mock_database):
+ """Test existing LIMIT is not modified."""
+ sql = "SELECT * FROM table1 LIMIT 10"
+ result = _apply_limit(sql, 100, mock_database)
+ assert "LIMIT 10" in result
+ assert "LIMIT 100" not in result
+
+ def test_preserves_existing_limit_in_cte(self, mock_database):
+ """Test existing LIMIT in CTE query is not modified."""
+ cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5"
+ result = _apply_limit(cte_sql, 100, mock_database)
+ assert "LIMIT 5" in result
+ assert "LIMIT 100" not in result
+
+ def test_no_limit_on_insert(self, mock_database):
+ """Test LIMIT is not added to INSERT query."""
+ sql = "INSERT INTO table1 VALUES (1)"
+ result = _apply_limit(sql, 100, mock_database)
+ assert result == sql
+
+ def test_no_limit_on_update(self, mock_database):
+ """Test LIMIT is not added to UPDATE query."""
+ sql = "UPDATE table1 SET col = 1"
+ result = _apply_limit(sql, 100, mock_database)
+ assert result == sql
diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
index bbf9e410b0..5b3012462d 100644
--- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
+++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
@@ -431,19 +431,17 @@ class TestExecuteSql:
async def test_execute_sql_sql_injection_prevention(
self, mock_db, mock_security_manager, mcp_server
):
- """Test that SQL injection attempts are handled safely."""
+ """Test that SQL injection attempts are handled safely.
+
+ SQLScript detects the DROP TABLE as a mutation and blocks it
+ before execution when DML is not allowed on the database.
+ """
mock_database = _mock_database()
mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
mock_database
)
mock_security_manager.can_access_database.return_value = True
- # Mock execute to raise an exception
- cursor = ( # fmt: skip
-
mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
- )
- cursor.execute.side_effect = Exception("Syntax error")
-
request = {
"database_id": 1,
"sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
@@ -453,10 +451,12 @@ class TestExecuteSql:
async with Client(mcp_server) as client:
result = await client.call_tool("execute_sql", {"request":
request})
+ # SQLScript correctly detects DROP TABLE as a mutation
+ # and blocks it before execution (improved security)
assert result.data.success is False
assert result.data.error is not None
- assert "Syntax error" in result.data.error # Contains actual error
- assert result.data.error_type == "EXECUTION_ERROR"
+ assert "DML" in result.data.error or "mutates" in result.data.error
+ assert result.data.error_type == "DML_NOT_ALLOWED"
@pytest.mark.asyncio
async def test_execute_sql_empty_query_validation(self, mcp_server):