This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 13827271e04 Add validation for table_name and expression_list in DatabricksCopyIntoOperator (#62499)
13827271e04 is described below

commit 13827271e0414d7d46256cebf3d3c79b5594facd
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Mar 3 20:48:19 2026 +0000

    Add validation for table_name and expression_list in DatabricksCopyIntoOperator (#62499)
    
    Validate table_name to ensure each dot-separated segment is a valid SQL identifier
    and prevent it from mutating the generated COPY INTO statement. Block multi-statement
    tokens in expression_list to prevent statement boundary escape and destructive SQL
    injection via projection clauses.
    
    Includes input validation and unit tests covering both valid and invalid identifier
    and expression_list scenarios.
    
    Co-authored-by: Sameer Mesiah <[email protected]>
---
 .../databricks/operators/databricks_sql.py         | 24 +++++-
 .../databricks/operators/test_databricks_copy.py   | 91 ++++++++++++++++++++++
 2 files changed, 114 insertions(+), 1 deletion(-)

diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
index 840fa7bc8b8..b50c434d04c 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -22,6 +22,7 @@ from __future__ import annotations
 import csv
 import json
 import os
+import re
 from collections.abc import Sequence
 from functools import cached_property
 from tempfile import NamedTemporaryFile
@@ -41,6 +42,9 @@ from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
 if TYPE_CHECKING:
     from airflow.providers.common.compat.sdk import Context
 
+_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
+_DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/")
+
 
 class DatabricksSqlOperator(SQLExecuteQueryOperator):
     """
@@ -447,7 +451,25 @@ class DatabricksCopyIntoOperator(BaseOperator):
 
         return formatted_opts
 
+    def _validate_sql_fragments(self) -> None:
+        # Validate table_name segments (supports table, schema.table, catalog.schema.table).
+        parts = self.table_name.split(".")
+        for part in parts:
+            if not part or not _IDENTIFIER_RE.match(part):
+                raise ValueError(
+                    f"Invalid table identifier segment '{part}' in '{self.table_name}'. "
+                    "Only alphanumeric characters and underscores are allowed."
+                )
+
+        # Prevent multi-statement injection via expression_list.
+        if self._expression_list:
+            for token in _DISALLOWED_SQL_TOKENS:
+                if token in self._expression_list:
+                    raise ValueError("expression_list must not contain statement separators or comments.")
+
     def _create_sql_query(self) -> str:
+
+        self._validate_sql_fragments()
         escaper = ParamEscaper()
         maybe_with = ""
         if self._encryption is not None or self._credential is not None:
@@ -484,7 +506,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
                 validation = f"VALIDATE {self._validate} ROWS\n"
             else:
                 raise AirflowException(f"Incorrect data type for validate parameter: {type(self._validate)}")
-        # TODO: think on how to make sure that table_name and expression_list aren't used for SQL injection
+
         sql = f"""COPY INTO {self.table_name}{storage_cred}
 FROM {location}
 FILEFORMAT = {self._file_format}
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
index eddaf1d194a..f00653f8b22 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
@@ -240,6 +240,97 @@ def test_incorrect_params_wrong_format():
         )
 
 
+@pytest.mark.parametrize(
+    "table_name",
+    [
+        "safe; DROP TABLE x",
+        "safe table",
+        "safe-table",
+        "safe()",
+        "safe--comment",
+        "1invalid",
+        ".table",
+        "schema.",
+    ],
+)
+def test_invalid_table_identifier_rejected(table_name):
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name=table_name,
+        task_id=TASK_ID,
+    )
+
+    with pytest.raises(ValueError, match="Invalid table identifier"):
+        op._create_sql_query()
+
+
+@pytest.mark.parametrize(
+    "table_name",
+    [
+        "table",
+        "schema.table",
+        "catalog.schema.table",
+        "_table",
+        "table_123",
+    ],
+)
+def test_valid_table_identifier_allowed(table_name):
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name=table_name,
+        task_id=TASK_ID,
+    )
+
+    sql = op._create_sql_query()
+    assert f"COPY INTO {table_name}" in sql
+
+
+@pytest.mark.parametrize(
+    "expression_list",
+    [
+        "col1; DROP TABLE x",
+        "col1 -- comment",
+        "col1 /* comment */",
+    ],
+)
+def test_expression_list_rejects_multi_statement(expression_list):
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name="test",
+        task_id=TASK_ID,
+        expression_list=expression_list,
+    )
+
+    with pytest.raises(ValueError, match="expression_list"):
+        op._create_sql_query()
+
+
+@pytest.mark.parametrize(
+    "expression_list",
+    [
+        "*",
+        "col1",
+        "col1, col2",
+        "upper(col1) as col1",
+        "cast(_c0 as int) as id",
+    ],
+)
+def test_valid_expression_list_allowed(expression_list):
+    op = DatabricksCopyIntoOperator(
+        file_location=COPY_FILE_LOCATION,
+        file_format="JSON",
+        table_name="test",
+        task_id=TASK_ID,
+        expression_list=expression_list,
+    )
+
+    sql = op._create_sql_query()
+    assert f"SELECT {expression_list}" in sql
+
+
 @pytest.mark.db_test
 def test_templating(create_task_instance_of_operator, session):
     ti = create_task_instance_of_operator(

Reply via email to