This is an automated email from the ASF dual-hosted git repository.

rusackas pushed a commit to branch master
in repository https://gitbox.apache.org/repos/asf/superset.git


The following commit(s) were added to refs/heads/master by this push:
     new e3e6b0e18b fix(mcp): use SQLScript for all SQL parsing in execute_sql (#36599)
e3e6b0e18b is described below

commit e3e6b0e18bddf2cf523131c5df98c567e6e3cda7
Author: Amin Ghadersohi <[email protected]>
AuthorDate: Sat Dec 20 23:52:56 2025 -0500

    fix(mcp): use SQLScript for all SQL parsing in execute_sql (#36599)
---
 superset/mcp_service/sql_lab/sql_lab_utils.py      |  64 +++++-----
 .../mcp_service/sql_lab/test_sql_lab_utils.py      | 137 +++++++++++++++++++++
 .../mcp_service/sql_lab/tool/test_execute_sql.py   |  18 +--
 3 files changed, 180 insertions(+), 39 deletions(-)

diff --git a/superset/mcp_service/sql_lab/sql_lab_utils.py b/superset/mcp_service/sql_lab/sql_lab_utils.py
index 6844e26a49..10c543c768 100644
--- a/superset/mcp_service/sql_lab/sql_lab_utils.py
+++ b/superset/mcp_service/sql_lab/sql_lab_utils.py
@@ -70,26 +70,22 @@ def validate_sql_query(sql: str, database: Any) -> None:
         SupersetDisallowedSQLFunctionException,
         SupersetDMLNotAllowedException,
     )
+    from superset.sql.parse import SQLScript
 
-    # Simplified validation without complex parsing
-    sql_upper = sql.upper().strip()
+    # Use SQLScript for proper SQL parsing
+    script = SQLScript(sql, database.db_engine_spec.engine)
 
     # Check for DML operations if not allowed
-    dml_keywords = ["INSERT", "UPDATE", "DELETE", "DROP", "CREATE", "ALTER", "TRUNCATE"]
-    if any(sql_upper.startswith(keyword) for keyword in dml_keywords):
-        if not database.allow_dml:
-            raise SupersetDMLNotAllowedException()
+    if script.has_mutation() and not database.allow_dml:
+        raise SupersetDMLNotAllowedException()
 
     # Check for disallowed functions from config
     disallowed_functions = app.config.get("DISALLOWED_SQL_FUNCTIONS", {}).get(
-        "sqlite",
-        set(),  # Default to sqlite for now
+        database.db_engine_spec.engine,
+        set(),
     )
-    if disallowed_functions:
-        sql_lower = sql.lower()
-        for func in disallowed_functions:
-            if f"{func.lower()}(" in sql_lower:
-                raise SupersetDisallowedSQLFunctionException(disallowed_functions)
+    if disallowed_functions and script.check_functions_present(disallowed_functions):
+        raise SupersetDisallowedSQLFunctionException(disallowed_functions)
 
 
 def execute_sql_query(
@@ -110,8 +106,8 @@ def execute_sql_query(
     sql = _apply_parameters(sql, parameters)
     validate_sql_query(sql, database)
 
-    # Apply limit for SELECT queries
-    rendered_sql = _apply_limit(sql, limit)
+    # Apply limit for SELECT queries using SQLScript
+    rendered_sql = _apply_limit(sql, limit, database)
 
     # Execute and get results
     results = _execute_query(database, rendered_sql, schema, limit)
@@ -156,12 +152,23 @@ def _apply_parameters(sql: str, parameters: dict[str, Any] | None) -> str:
     return sql
 
 
-def _apply_limit(sql: str, limit: int) -> str:
-    """Apply limit to SELECT queries if not already present."""
-    sql_lower = sql.lower().strip()
-    if sql_lower.startswith("select") and "limit" not in sql_lower:
-        return f"{sql.rstrip().rstrip(';')} LIMIT {limit}"
-    return sql
+def _apply_limit(sql: str, limit: int, database: Any) -> str:
+    """Apply limit to SELECT queries using SQLScript for proper parsing."""
+    from superset.sql.parse import LimitMethod, SQLScript
+
+    script = SQLScript(sql, database.db_engine_spec.engine)
+
+    # Only apply limit to non-mutating (SELECT-like) queries
+    if script.has_mutation():
+        return sql
+
+    # Apply limit to each statement in the script
+    for statement in script.statements:
+        # Only set limit if not already present
+        if statement.get_limit_value() is None:
+            statement.set_limit_value(limit, LimitMethod.FORCE_LIMIT)
+
+    return script.format()
 
 
 def _execute_query(
@@ -172,6 +179,7 @@ def _execute_query(
 ) -> dict[str, Any]:
     """Execute the query and process results."""
     # Import inside function to avoid initialization issues
+    from superset.sql.parse import SQLScript
     from superset.utils.core import QuerySource
 
     results = {
@@ -192,11 +200,12 @@ def _execute_query(
             cursor = conn.cursor()
             cursor.execute(sql)
 
-            # Process results based on query type
-            if _is_select_query(sql):
-                _process_select_results(cursor, results, limit)
-            else:
+            # Use SQLScript for proper SQL parsing to determine query type
+            script = SQLScript(sql, database.db_engine_spec.engine)
+            if script.has_mutation():
                 _process_dml_results(cursor, conn, results)
+            else:
+                _process_select_results(cursor, results, limit)
 
     except Exception as e:
         logger.error("Error executing SQL: %s", e)
@@ -205,11 +214,6 @@ def _execute_query(
     return results
 
 
-def _is_select_query(sql: str) -> bool:
-    """Check if SQL is a SELECT query."""
-    return sql.lower().strip().startswith("select")
-
-
 def _process_select_results(cursor: Any, results: dict[str, Any], limit: int) -> None:
     """Process SELECT query results."""
     # Fetch results
diff --git a/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py b/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
new file mode 100644
index 0000000000..9a08b72a91
--- /dev/null
+++ b/tests/unit_tests/mcp_service/sql_lab/test_sql_lab_utils.py
@@ -0,0 +1,137 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+"""Unit tests for MCP SQL Lab utility functions."""
+
+from unittest.mock import MagicMock
+
+import pytest
+
+from superset.mcp_service.sql_lab.sql_lab_utils import _apply_limit
+from superset.sql.parse import SQLScript
+
+
+class TestSQLScriptMutationDetection:
+    """Tests for SQLScript.has_mutation() used in query type detection."""
+
+    def test_simple_select_no_mutation(self):
+        """Test simple SELECT query is not a mutation."""
+        script = SQLScript("SELECT * FROM table1", "sqlite")
+        assert script.has_mutation() is False
+
+    def test_cte_query_no_mutation(self):
+        """Test CTE (WITH clause) query is not a mutation."""
+        cte_sql = """
+        WITH cte_name AS (
+            SELECT * FROM table1
+        )
+        SELECT * FROM cte_name
+        """
+        script = SQLScript(cte_sql, "sqlite")
+        assert script.has_mutation() is False
+
+    def test_recursive_cte_no_mutation(self):
+        """Test recursive CTE is not a mutation."""
+        recursive_cte = """
+        WITH RECURSIVE cte AS (
+            SELECT 1 AS n
+            UNION ALL
+            SELECT n + 1 FROM cte WHERE n < 10
+        )
+        SELECT n FROM cte
+        """
+        script = SQLScript(recursive_cte, "sqlite")
+        assert script.has_mutation() is False
+
+    def test_multiple_ctes_no_mutation(self):
+        """Test query with multiple CTEs is not a mutation."""
+        multiple_ctes = """
+        WITH
+            cte1 AS (SELECT 1 as a),
+            cte2 AS (SELECT 2 as b)
+        SELECT * FROM cte1, cte2
+        """
+        script = SQLScript(multiple_ctes, "sqlite")
+        assert script.has_mutation() is False
+
+    def test_insert_is_mutation(self):
+        """Test INSERT query is a mutation."""
+        script = SQLScript("INSERT INTO table1 VALUES (1)", "sqlite")
+        assert script.has_mutation() is True
+
+    def test_update_is_mutation(self):
+        """Test UPDATE query is a mutation."""
+        script = SQLScript("UPDATE table1 SET col = 1", "sqlite")
+        assert script.has_mutation() is True
+
+    def test_delete_is_mutation(self):
+        """Test DELETE query is a mutation."""
+        script = SQLScript("DELETE FROM table1", "sqlite")
+        assert script.has_mutation() is True
+
+    def test_create_is_mutation(self):
+        """Test CREATE query is a mutation."""
+        script = SQLScript("CREATE TABLE table1 (id INT)", "sqlite")
+        assert script.has_mutation() is True
+
+
+class TestApplyLimit:
+    """Tests for _apply_limit function using SQLScript."""
+
+    @pytest.fixture
+    def mock_database(self):
+        """Create a mock database with sqlite engine spec."""
+        db = MagicMock()
+        db.db_engine_spec.engine = "sqlite"
+        return db
+
+    def test_adds_limit_to_select(self, mock_database):
+        """Test LIMIT is added to SELECT query."""
+        result = _apply_limit("SELECT * FROM table1", 100, mock_database)
+        assert "LIMIT 100" in result
+
+    def test_adds_limit_to_cte(self, mock_database):
+        """Test LIMIT is added to CTE query."""
+        cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte"
+        result = _apply_limit(cte_sql, 50, mock_database)
+        assert "LIMIT 50" in result
+
+    def test_preserves_existing_limit(self, mock_database):
+        """Test existing LIMIT is not modified."""
+        sql = "SELECT * FROM table1 LIMIT 10"
+        result = _apply_limit(sql, 100, mock_database)
+        assert "LIMIT 10" in result
+        assert "LIMIT 100" not in result
+
+    def test_preserves_existing_limit_in_cte(self, mock_database):
+        """Test existing LIMIT in CTE query is not modified."""
+        cte_sql = "WITH cte AS (SELECT 1) SELECT * FROM cte LIMIT 5"
+        result = _apply_limit(cte_sql, 100, mock_database)
+        assert "LIMIT 5" in result
+        assert "LIMIT 100" not in result
+
+    def test_no_limit_on_insert(self, mock_database):
+        """Test LIMIT is not added to INSERT query."""
+        sql = "INSERT INTO table1 VALUES (1)"
+        result = _apply_limit(sql, 100, mock_database)
+        assert result == sql
+
+    def test_no_limit_on_update(self, mock_database):
+        """Test LIMIT is not added to UPDATE query."""
+        sql = "UPDATE table1 SET col = 1"
+        result = _apply_limit(sql, 100, mock_database)
+        assert result == sql
diff --git a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
index bbf9e410b0..5b3012462d 100644
--- a/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
+++ b/tests/unit_tests/mcp_service/sql_lab/tool/test_execute_sql.py
@@ -431,19 +431,17 @@ class TestExecuteSql:
     async def test_execute_sql_sql_injection_prevention(
         self, mock_db, mock_security_manager, mcp_server
     ):
-        """Test that SQL injection attempts are handled safely."""
+        """Test that SQL injection attempts are handled safely.
+
+        SQLScript detects the DROP TABLE as a mutation and blocks it
+        before execution when DML is not allowed on the database.
+        """
         mock_database = _mock_database()
         mock_db.session.query.return_value.filter_by.return_value.first.return_value = (
             mock_database
         )
         mock_security_manager.can_access_database.return_value = True
 
-        # Mock execute to raise an exception
-        cursor = (  # fmt: skip
-            mock_database.get_raw_connection.return_value.__enter__.return_value.cursor.return_value
-        )
-        cursor.execute.side_effect = Exception("Syntax error")
-
         request = {
             "database_id": 1,
             "sql": "SELECT * FROM users WHERE id = 1; DROP TABLE users;--",
@@ -453,10 +451,12 @@ class TestExecuteSql:
         async with Client(mcp_server) as client:
             result = await client.call_tool("execute_sql", {"request": request})
 
+            # SQLScript correctly detects DROP TABLE as a mutation
+            # and blocks it before execution (improved security)
             assert result.data.success is False
             assert result.data.error is not None
-            assert "Syntax error" in result.data.error  # Contains actual error
-            assert result.data.error_type == "EXECUTION_ERROR"
+            assert "DML" in result.data.error or "mutates" in result.data.error
+            assert result.data.error_type == "DML_NOT_ALLOWED"
 
     @pytest.mark.asyncio
     async def test_execute_sql_empty_query_validation(self, mcp_server):

Reply via email to