This is an automated email from the ASF dual-hosted git repository.
potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 13827271e04 Add validation for table_name and expression_list in
DatabricksCopyIntoOperator (#62499)
13827271e04 is described below
commit 13827271e0414d7d46256cebf3d3c79b5594facd
Author: SameerMesiah97 <[email protected]>
AuthorDate: Tue Mar 3 20:48:19 2026 +0000
Add validation for table_name and expression_list in
DatabricksCopyIntoOperator (#62499)
Validate table_name to ensure each dot-separated segment is a valid SQL
identifier
and prevent it from mutating the generated COPY INTO statement. Block
multi-statement
tokens in expression_list to prevent statement boundary escape and
destructive SQL
injection via projection clauses.
Includes input validation and unit tests covering both valid and invalid
identifier
and expression_list scenarios.
Co-authored-by: Sameer Mesiah <[email protected]>
---
.../databricks/operators/databricks_sql.py | 24 +++++-
.../databricks/operators/test_databricks_copy.py | 91 ++++++++++++++++++++++
2 files changed, 114 insertions(+), 1 deletion(-)
diff --git a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
index 840fa7bc8b8..b50c434d04c 100644
--- a/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
+++ b/providers/databricks/src/airflow/providers/databricks/operators/databricks_sql.py
@@ -22,6 +22,7 @@ from __future__ import annotations
import csv
import json
import os
+import re
from collections.abc import Sequence
from functools import cached_property
from tempfile import NamedTemporaryFile
@@ -41,6 +42,9 @@ from airflow.providers.databricks.hooks.databricks_sql import DatabricksSqlHook
if TYPE_CHECKING:
from airflow.providers.common.compat.sdk import Context
+_IDENTIFIER_RE = re.compile(r"^[A-Za-z_][A-Za-z0-9_]*$")
+_DISALLOWED_SQL_TOKENS = (";", "--", "/*", "*/")
+
class DatabricksSqlOperator(SQLExecuteQueryOperator):
"""
@@ -447,7 +451,25 @@ class DatabricksCopyIntoOperator(BaseOperator):
return formatted_opts
+    def _validate_sql_fragments(self) -> None:
+        # Validate table_name segments (supports table, schema.table, catalog.schema.table).
+        parts = self.table_name.split(".")
+        for part in parts:
+            if not part or not _IDENTIFIER_RE.match(part):
+                raise ValueError(
+                    f"Invalid table identifier segment '{part}' in '{self.table_name}'. "
+                    "Only alphanumeric characters and underscores are allowed."
+                )
+
+        # Prevent multi-statement injection via expression_list.
+        if self._expression_list:
+            for token in _DISALLOWED_SQL_TOKENS:
+                if token in self._expression_list:
+                    raise ValueError("expression_list must not contain statement separators or comments.")
+
     def _create_sql_query(self) -> str:
+
+        self._validate_sql_fragments()
escaper = ParamEscaper()
maybe_with = ""
if self._encryption is not None or self._credential is not None:
@@ -484,7 +506,7 @@ class DatabricksCopyIntoOperator(BaseOperator):
                 validation = f"VALIDATE {self._validate} ROWS\n"
             else:
                 raise AirflowException(f"Incorrect data type for validate parameter: {type(self._validate)}")
-        # TODO: think on how to make sure that table_name and expression_list aren't used for SQL injection
+
         sql = f"""COPY INTO {self.table_name}{storage_cred}
 FROM {location}
 FILEFORMAT = {self._file_format}
diff --git a/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py b/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
index eddaf1d194a..f00653f8b22 100644
--- a/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
+++ b/providers/databricks/tests/unit/databricks/operators/test_databricks_copy.py
@@ -240,6 +240,97 @@ def test_incorrect_params_wrong_format():
)
+@pytest.mark.parametrize(
+ "table_name",
+ [
+ "safe; DROP TABLE x",
+ "safe table",
+ "safe-table",
+ "safe()",
+ "safe--comment",
+ "1invalid",
+ ".table",
+ "schema.",
+ ],
+)
+def test_invalid_table_identifier_rejected(table_name):
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name=table_name,
+ task_id=TASK_ID,
+ )
+
+ with pytest.raises(ValueError, match="Invalid table identifier"):
+ op._create_sql_query()
+
+
+@pytest.mark.parametrize(
+ "table_name",
+ [
+ "table",
+ "schema.table",
+ "catalog.schema.table",
+ "_table",
+ "table_123",
+ ],
+)
+def test_valid_table_identifier_allowed(table_name):
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name=table_name,
+ task_id=TASK_ID,
+ )
+
+ sql = op._create_sql_query()
+ assert f"COPY INTO {table_name}" in sql
+
+
+@pytest.mark.parametrize(
+ "expression_list",
+ [
+ "col1; DROP TABLE x",
+ "col1 -- comment",
+ "col1 /* comment */",
+ ],
+)
+def test_expression_list_rejects_multi_statement(expression_list):
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name="test",
+ task_id=TASK_ID,
+ expression_list=expression_list,
+ )
+
+ with pytest.raises(ValueError, match="expression_list"):
+ op._create_sql_query()
+
+
+@pytest.mark.parametrize(
+ "expression_list",
+ [
+ "*",
+ "col1",
+ "col1, col2",
+ "upper(col1) as col1",
+ "cast(_c0 as int) as id",
+ ],
+)
+def test_valid_expression_list_allowed(expression_list):
+ op = DatabricksCopyIntoOperator(
+ file_location=COPY_FILE_LOCATION,
+ file_format="JSON",
+ table_name="test",
+ task_id=TASK_ID,
+ expression_list=expression_list,
+ )
+
+ sql = op._create_sql_query()
+ assert f"SELECT {expression_list}" in sql
+
+
@pytest.mark.db_test
def test_templating(create_task_instance_of_operator, session):
ti = create_task_instance_of_operator(