This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 73d0ee88b50 Add `HookToolset` and `SQLToolset` for agentic LLM
workflows (#62785)
73d0ee88b50 is described below
commit 73d0ee88b50ddc9c2e1e22250a6a11083b975780
Author: Kaxil Naik <[email protected]>
AuthorDate: Tue Mar 3 20:50:43 2026 +0000
Add `HookToolset` and `SQLToolset` for agentic LLM workflows (#62785)
HookToolset: Generic adapter that exposes any Airflow Hook's methods
as pydantic-ai tools via introspection. Requires explicit
allowed_methods list (no auto-discovery). Builds JSON Schema from
method signatures and enriches tool descriptions from docstrings.
SQLToolset: Curated 4-tool database toolset (list_tables, get_schema,
query, check_query) wrapping DbApiHook. Read-only by default with SQL
validation, allowed_tables metadata filtering, and max_rows truncation.
Both implement pydantic-ai's AbstractToolset interface with
sequential=True on all tool definitions to prevent concurrent sync I/O.
* Fix mypy error: annotate result variable in SQLToolset._query
The list comprehension in the else branch produces list[list[Any]]
while the if branch produces list[dict[str, Any]]. Add an explicit
type annotation to satisfy mypy.
* Add toolset/agentic/ctx to spelling wordlist
Sphinx autoapi generates RST from pydantic-ai's AbstractToolset base
class docstrings. These words appear in the auto-generated docs and
need to be in the global wordlist.
Docs for HookToolset (generic hook→tools adapter) and SQLToolset
(curated 4-tool DB toolset). Includes a defense-layers table,
allowed_tables limitation, HookToolset guidelines, recommended
configurations, and production checklist.
---
docs/spelling_wordlist.txt | 7 +
providers/common/ai/docs/index.rst | 1 +
providers/common/ai/docs/toolsets.rst | 273 ++++++++++++++++++++
.../providers/common/ai/toolsets/__init__.py | 35 +++
.../airflow/providers/common/ai/toolsets/hook.py | 267 ++++++++++++++++++++
.../airflow/providers/common/ai/toolsets/sql.py | 231 +++++++++++++++++
.../ai/tests/unit/common/ai/toolsets/__init__.py | 16 ++
.../ai/tests/unit/common/ai/toolsets/test_hook.py | 281 +++++++++++++++++++++
.../ai/tests/unit/common/ai/toolsets/test_sql.py | 233 +++++++++++++++++
9 files changed, 1344 insertions(+)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index fbfbeda2b00..81fd9cc567e 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1,6 +1,7 @@
aarch
abc
AbstractFileSystem
+AbstractToolset
accessor
AccessSecretVersionResponse
accountmaking
@@ -24,6 +25,7 @@ adobjects
AdsInsights
adsinsights
afterall
+agentic
AgentKey
ai
aio
@@ -375,6 +377,7 @@ Ctl
ctl
ctor
Ctrl
+ctx
cubeName
customDataImportUids
customizability
@@ -831,6 +834,7 @@ Gzip
gzipped
hadoop
hadoopcmd
+hardcode
hardcoded
Harenslak
Hashable
@@ -1918,6 +1922,9 @@ tokopedia
tolerations
toml
toolchain
+toolset
+Toolsets
+toolsets
Tooltip
tooltip
tooltips
diff --git a/providers/common/ai/docs/index.rst
b/providers/common/ai/docs/index.rst
index 764ae24c4dd..87fbe994716 100644
--- a/providers/common/ai/docs/index.rst
+++ b/providers/common/ai/docs/index.rst
@@ -36,6 +36,7 @@
Connection types <connections/pydantic_ai>
Hooks <hooks/pydantic_ai>
+ Toolsets <toolsets>
Operators <operators/index>
.. toctree::
diff --git a/providers/common/ai/docs/toolsets.rst
b/providers/common/ai/docs/toolsets.rst
new file mode 100644
index 00000000000..7334a5ae0a4
--- /dev/null
+++ b/providers/common/ai/docs/toolsets.rst
@@ -0,0 +1,273 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/toolsets:
+
+Toolsets — Airflow Hooks as AI Agent Tools
+==========================================
+
+Airflow's 350+ provider hooks already have typed methods, rich docstrings,
+and managed credentials. Toolsets expose them as pydantic-ai tools so that
+LLM agents can call them during multi-turn reasoning.
+
+Two toolsets are included:
+
+- :class:`~airflow.providers.common.ai.toolsets.hook.HookToolset` — generic
+ adapter for any Airflow Hook.
+- :class:`~airflow.providers.common.ai.toolsets.sql.SQLToolset` — curated
+ 4-tool database toolset.
+
+Both implement pydantic-ai's
+`AbstractToolset <https://ai.pydantic.dev/toolsets/>`__ interface and can be
+passed to any pydantic-ai ``Agent``, including via
+:class:`~airflow.providers.common.ai.operators.agent.AgentOperator`.
+
+
+``HookToolset``
+---------------
+
+Generic adapter that exposes selected methods of any Airflow Hook as
+pydantic-ai tools via introspection. Requires an explicit ``allowed_methods``
+list — there is no auto-discovery.
+
+.. code-block:: python
+
+ from airflow.providers.http.hooks.http import HttpHook
+ from airflow.providers.common.ai.toolsets.hook import HookToolset
+
+ http_hook = HttpHook(http_conn_id="my_api")
+
+ toolset = HookToolset(
+ http_hook,
+ allowed_methods=["run"],
+ tool_name_prefix="http_",
+ )
+
+For each listed method, the introspection engine:
+
+1. Builds a JSON Schema from the method signature (``inspect.signature`` +
+ ``get_type_hints``).
+2. Extracts the description from the first paragraph of the docstring.
+3. Enriches parameter descriptions from Sphinx ``:param:`` or Google
+ ``Args:`` blocks.
+
+Parameters
+^^^^^^^^^^
+
+- ``hook``: An instantiated Airflow Hook.
+- ``allowed_methods``: Method names to expose as tools. Required. Methods
+ are validated with ``hasattr`` + ``callable`` at instantiation time.
+- ``tool_name_prefix``: Optional prefix prepended to each tool name
+ (e.g. ``"s3_"`` produces ``"s3_list_keys"``).
+
+
+``SQLToolset``
+--------------
+
+Curated toolset wrapping
+:class:`~airflow.providers.common.sql.hooks.sql.DbApiHook` with four tools:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 20 50
+
+ * - Tool
+ - Description
+ * - ``list_tables``
+ - Lists available table names (filtered by ``allowed_tables`` if set)
+ * - ``get_schema``
+ - Returns column names and types for a table
+ * - ``query``
+ - Executes a SQL query and returns rows as JSON
+ * - ``check_query``
+ - Validates SQL syntax without executing it
+
+.. code-block:: python
+
+ from airflow.providers.common.ai.toolsets.sql import SQLToolset
+
+ toolset = SQLToolset(
+ db_conn_id="postgres_default",
+ allowed_tables=["customers", "orders"],
+ max_rows=20,
+ )
+
+The ``DbApiHook`` is resolved lazily from ``db_conn_id`` on first tool call
+via ``BaseHook.get_connection(conn_id).get_hook()``.
+
+Parameters
+^^^^^^^^^^
+
+- ``db_conn_id``: Airflow connection ID for the database.
+- ``allowed_tables``: Restrict which tables the agent can discover via
+ ``list_tables`` and ``get_schema``. ``None`` (default) exposes all tables.
+ See :ref:`allowed-tables-limitation` for an important caveat.
+- ``schema``: Database schema/namespace for table listing and introspection.
+- ``allow_writes``: Allow data-modifying SQL (INSERT, UPDATE, DELETE, etc.).
+ Default ``False`` — only SELECT-family statements are permitted.
+- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
+
+
+Security
+--------
+
+LLM agents call tools based on natural-language reasoning. This makes them
+powerful but introduces risks that don't exist with deterministic operators.
+
+Defense Layers
+^^^^^^^^^^^^^^
+
+No single layer is sufficient — they work together.
+
+.. list-table::
+ :header-rows: 1
+ :widths: 20 40 40
+
+ * - Layer
+ - What it does
+ - What it does NOT do
+ * - **Airflow Connections**
+ - Credentials are stored in Airflow's secret backend, never in DAG code.
+ The LLM agent cannot see API keys or database passwords.
+ - Does not prevent the agent from using the connection to access data
+ the connection has access to.
+ * - **HookToolset: explicit allow-list**
+ - Only methods listed in ``allowed_methods`` are exposed as tools.
+ Auto-discovery is not supported. Methods are validated at DAG parse
+ time.
+ - Does not restrict what arguments the agent passes to allowed methods.
+ * - **SQLToolset: read-only by default**
+ - ``allow_writes=False`` (default) validates every SQL query through
+ ``validate_sql()`` and rejects INSERT, UPDATE, DELETE, DROP, etc.
+ - Does not prevent the agent from reading sensitive data that the
+ database user has SELECT access to.
+ * - **SQLToolset: allowed_tables**
+ - Restricts which tables appear in ``list_tables`` and ``get_schema``
+ responses, limiting the agent's knowledge of the schema.
+ - Does **not** validate table references in SQL queries. The agent can
+ still query unlisted tables if it guesses the name. See
+ :ref:`allowed-tables-limitation` below.
+ * - **SQLToolset: max_rows**
+ - Truncates query results to ``max_rows`` (default 50), preventing the
+ agent from pulling entire tables into context.
+ - Does not limit the number of queries the agent can make.
+ * - **pydantic-ai: tool call budget**
+ - pydantic-ai's ``max_result_retries`` and ``model_settings`` control
+ how many tool-call rounds the agent can make before stopping.
+ - Requires explicit configuration — the default allows many rounds.
+
+
+.. _allowed-tables-limitation:
+
+The ``allowed_tables`` Limitation
+"""""""""""""""""""""""""""""""""
+
+``allowed_tables`` is a **metadata filter**, not an access control mechanism.
+It hides table names from ``list_tables`` and blocks ``get_schema`` for
+unlisted tables, but does not parse SQL queries to validate table references.
+
+An LLM can craft ``SELECT * FROM secrets`` even when
+``allowed_tables=["orders"]``. Parsing SQL for table references (including
+CTEs, subqueries, aliases, and vendor-specific syntax) is complex and
+error-prone; we chose not to provide a false sense of security.
+
+For query-level restrictions, use database permissions:
+
+.. code-block:: sql
+
+ -- Create a read-only role with access to specific tables only
+ CREATE ROLE airflow_agent_reader;
+ GRANT SELECT ON orders, customers TO airflow_agent_reader;
+ -- Use this role's credentials in the Airflow connection
+
+The Airflow connection should use a database user with the minimum privileges
+required.
+
+
+HookToolset Guidelines
+""""""""""""""""""""""
+
+- List only the methods the agent needs. Never expose ``run()`` or
+ ``get_connection()`` — these give broad access.
+- Prefer read-only methods (``list_*``, ``get_*``, ``describe_*``).
+- The agent controls arguments. If a method accepts a ``path`` parameter,
+ the agent can pass any path the hook has access to.
+
+.. code-block:: python
+
+ # Good: expose only list and read
+ HookToolset(
+ s3_hook,
+ allowed_methods=["list_keys", "read_key"],
+ tool_name_prefix="s3_",
+ )
+
+ # Bad: exposes delete and write operations
+ HookToolset(
+ s3_hook,
+ allowed_methods=["list_keys", "read_key", "delete_object",
"load_string"],
+ )
+
+
+Recommended Configuration
+"""""""""""""""""""""""""
+
+**Read-only analytics** (the most common pattern):
+
+.. code-block:: python
+
+ SQLToolset(
+ db_conn_id="analytics_readonly", # Connection with SELECT-only grants
+ allowed_tables=["orders", "customers"], # Hide other tables from agent
+ allow_writes=False, # Default — validates SQL
+ max_rows=50, # Default — truncate large results
+ )
+
+**Agents that need to modify data** (use with caution):
+
+.. code-block:: python
+
+ SQLToolset(
+ db_conn_id="app_db",
+ allowed_tables=["user_preferences"],
+ allow_writes=True, # Disables SQL validation — agent can INSERT/UPDATE
+ max_rows=100,
+ )
+
+
+Production Checklist
+""""""""""""""""""""
+
+Before deploying an agent task to production:
+
+1. **Connection credentials**: Use Airflow's secret backend. Never hardcode
+ API keys in DAG files.
+2. **Database permissions**: Create a dedicated database user with minimum
+ required grants. Don't reuse the admin connection.
+3. **Tool allow-list**: Review ``allowed_methods`` / ``allowed_tables``. The
+ agent can call any exposed tool with any arguments.
+4. **Read-only default**: Keep ``allow_writes=False`` unless the task
+ specifically requires writes.
+5. **Row limits**: Set ``max_rows`` appropriate to the use case. Large
+ result sets consume LLM context and increase cost.
+6. **Model budget**: Configure pydantic-ai's ``model_settings`` (e.g.
+ ``max_tokens``) and ``retries`` to bound cost and prevent runaway loops.
+7. **System prompt**: Include safety instructions in ``system_prompt`` (e.g.
+ "Only query tables related to the question. Never modify data.").
+8. **Prompt injection**: Be cautious when the prompt includes untrusted data
+ (user input, external API responses, upstream XCom). Consider sanitizing
+ inputs before passing them to the agent.
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/__init__.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/__init__.py
new file mode 100644
index 00000000000..aba5a45ee07
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/__init__.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Toolsets for exposing Airflow hooks as pydantic-ai agent tools."""
+
+from __future__ import annotations
+
+from airflow.providers.common.ai.toolsets.hook import HookToolset
+
+__all__ = ["HookToolset", "SQLToolset"]
+
+
+def __getattr__(name: str):
+ if name == "SQLToolset":
+ try:
+ from airflow.providers.common.ai.toolsets.sql import SQLToolset
+ except ImportError as e:
+ from airflow.providers.common.compat.sdk import
AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
+ return SQLToolset
+ raise AttributeError(f"module {__name__!r} has no attribute {name!r}")
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/hook.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/hook.py
new file mode 100644
index 00000000000..ae3987b6c0e
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/hook.py
@@ -0,0 +1,267 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Generic adapter that exposes Airflow Hook methods as pydantic-ai tools."""
+
+from __future__ import annotations
+
+import inspect
+import json
+import re
+import types
+from typing import TYPE_CHECKING, Any, Union, get_args, get_origin,
get_type_hints
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+ from collections.abc import Callable
+
+ from pydantic_ai._run_context import RunContext
+
+ from airflow.providers.common.compat.sdk import BaseHook
+
+# Single shared validator — accepts any JSON-decoded dict from the LLM.
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# Maps Python types to JSON Schema fragments.
+_TYPE_MAP: dict[type, dict[str, Any]] = {
+ str: {"type": "string"},
+ int: {"type": "integer"},
+ float: {"type": "number"},
+ bool: {"type": "boolean"},
+ list: {"type": "array"},
+ dict: {"type": "object"},
+ bytes: {"type": "string"},
+}
+
+
+class HookToolset(AbstractToolset[Any]):
+ """
+ Expose selected methods of an Airflow Hook as pydantic-ai tools.
+
+ This adapter introspects the method signatures and docstrings of the given
+ hook to build :class:`~pydantic_ai.tools.ToolDefinition` objects that an
LLM
+ agent can call.
+
+ :param hook: An instantiated Airflow Hook.
+ :param allowed_methods: Method names to expose as tools. Required —
+ auto-discovery is intentionally not supported for safety.
+ :param tool_name_prefix: Optional prefix prepended to each tool name
+ (e.g. ``"s3_"`` → ``"s3_list_keys"``).
+ """
+
+ def __init__(
+ self,
+ hook: BaseHook,
+ *,
+ allowed_methods: list[str],
+ tool_name_prefix: str = "",
+ ) -> None:
+ if not allowed_methods:
+ raise ValueError("allowed_methods must be a non-empty list.")
+
+ hook_cls_name = type(hook).__name__
+ for method_name in allowed_methods:
+ if not hasattr(hook, method_name):
+ raise ValueError(
+ f"Hook {hook_cls_name!r} has no method {method_name!r}.
Check your allowed_methods list."
+ )
+ if not callable(getattr(hook, method_name)):
+ raise ValueError(f"{hook_cls_name}.{method_name} is not
callable.")
+
+ self._hook = hook
+ self._allowed_methods = allowed_methods
+ self._tool_name_prefix = tool_name_prefix
+ self._id = f"hook-{type(hook).__name__}"
+
+ @property
+ def id(self) -> str:
+ return self._id
+
+ async def get_tools(self, ctx: RunContext[Any]) -> dict[str,
ToolsetTool[Any]]:
+ tools: dict[str, ToolsetTool[Any]] = {}
+ for method_name in self._allowed_methods:
+ method = getattr(self._hook, method_name)
+ tool_name = f"{self._tool_name_prefix}{method_name}" if
self._tool_name_prefix else method_name
+
+ json_schema = _build_json_schema_from_signature(method)
+ description = _extract_description(method)
+ param_docs = _parse_param_docs(method.__doc__ or "")
+
+ # Enrich parameter descriptions from docstring.
+ for param_name, param_desc in param_docs.items():
+ if param_name in json_schema.get("properties", {}):
+ json_schema["properties"][param_name]["description"] =
param_desc
+
+ # sequential=True because hook methods perform synchronous I/O
+ # (network calls, DB queries) and should not run concurrently.
+ tool_def = ToolDefinition(
+ name=tool_name,
+ description=description,
+ parameters_json_schema=json_schema,
+ sequential=True,
+ )
+ tools[tool_name] = ToolsetTool(
+ toolset=self,
+ tool_def=tool_def,
+ max_retries=1,
+ args_validator=_PASSTHROUGH_VALIDATOR,
+ )
+ return tools
+
+ async def call_tool(
+ self,
+ name: str,
+ tool_args: dict[str, Any],
+ ctx: RunContext[Any],
+ tool: ToolsetTool[Any],
+ ) -> Any:
+ method_name = name.removeprefix(self._tool_name_prefix) if
self._tool_name_prefix else name
+ method: Callable[..., Any] = getattr(self._hook, method_name)
+ result = method(**tool_args)
+ return _serialize_for_llm(result)
+
+
+# ---------------------------------------------------------------------------
+# Private introspection helpers
+# ---------------------------------------------------------------------------
+
+
+def _python_type_to_json_schema(annotation: Any) -> dict[str, Any]:
+ """Convert a Python type annotation to a JSON Schema fragment."""
+ if annotation is inspect.Parameter.empty or annotation is Any:
+ return {"type": "string"}
+
+ origin = get_origin(annotation)
+ args = get_args(annotation)
+
+ # Optional[X] is Union[X, None] — handle both types.UnionType (3.10+) and
typing.Union
+ if origin is types.UnionType or origin is Union:
+ non_none = [a for a in args if a is not type(None)]
+ if len(non_none) == 1:
+ return _python_type_to_json_schema(non_none[0])
+ return {"type": "string"}
+
+ # list[X]
+ if origin is list:
+ items = _python_type_to_json_schema(args[0]) if args else {"type":
"string"}
+ return {"type": "array", "items": items}
+
+ # dict[K, V]
+ if origin is dict:
+ return {"type": "object"}
+
+ # Always return a fresh copy — callers may mutate the dict (e.g. adding
"description").
+ schema = _TYPE_MAP.get(annotation)
+ return dict(schema) if schema else {"type": "string"}
+
+
+def _build_json_schema_from_signature(method: Callable[..., Any]) -> dict[str,
Any]:
+ """Build a JSON Schema ``object`` from a method's signature and type
hints."""
+ sig = inspect.signature(method)
+
+ try:
+ hints = get_type_hints(method)
+ except Exception:
+ hints = {}
+
+ properties: dict[str, Any] = {}
+ required: list[str] = []
+
+ for name, param in sig.parameters.items():
+ if name in ("self", "cls"):
+ continue
+ # Skip **kwargs and *args
+ if param.kind in (param.VAR_POSITIONAL, param.VAR_KEYWORD):
+ continue
+
+ annotation = hints.get(name, param.annotation)
+ prop = _python_type_to_json_schema(annotation)
+ properties[name] = prop
+
+ if param.default is inspect.Parameter.empty:
+ required.append(name)
+
+ schema: dict[str, Any] = {"type": "object", "properties": properties}
+ if required:
+ schema["required"] = required
+ return schema
+
+
+def _extract_description(method: Callable[..., Any]) -> str:
+ """Return the first paragraph of a method's docstring."""
+ doc = inspect.getdoc(method)
+ if not doc:
+ return method.__name__.replace("_", " ").capitalize()
+
+ # First paragraph = everything up to the first blank line.
+ lines: list[str] = []
+ for line in doc.splitlines():
+ if not line.strip():
+ if lines:
+ break
+ continue
+ lines.append(line.strip())
+ return " ".join(lines) if lines else method.__name__.replace("_", "
").capitalize()
+
+
+# Matches Sphinx-style `:param name:` and Google-style `name:` under an
``Args:`` block.
+_SPHINX_PARAM_RE = re.compile(r":param\s+(\w+):\s*(.+?)(?=\n\s*:|$)",
re.DOTALL)
+_GOOGLE_ARGS_RE = re.compile(r"^\s{2,}(\w+)\s*(?:\(.+?\))?:\s*(.+)",
re.MULTILINE)
+
+
+def _parse_param_docs(docstring: str) -> dict[str, str]:
+ """Parse parameter descriptions from Sphinx or Google-style docstrings."""
+ params: dict[str, str] = {}
+
+ # Try Sphinx style first.
+ for match in _SPHINX_PARAM_RE.finditer(docstring):
+ name = match.group(1)
+ desc = " ".join(match.group(2).split())
+ params[name] = desc
+
+ if params:
+ return params
+
+ # Fall back to Google style (``Args:`` section).
+ in_args = False
+ for line in docstring.splitlines():
+ stripped = line.strip()
+ if stripped.lower().startswith("args:"):
+ in_args = True
+ continue
+ if in_args:
+ if stripped and not stripped[0].isspace() and ":" not in stripped:
+ break
+ m = _GOOGLE_ARGS_RE.match(line)
+ if m:
+ params[m.group(1)] = " ".join(m.group(2).split())
+
+ return params
+
+
+def _serialize_for_llm(value: Any) -> str:
+ """Convert a Python return value to a string suitable for an LLM."""
+ if value is None:
+ return "null"
+ if isinstance(value, str):
+ return value
+ try:
+ return json.dumps(value, default=str)
+ except (TypeError, ValueError):
+ return str(value)
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
new file mode 100644
index 00000000000..f60f4b621c3
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/sql.py
@@ -0,0 +1,231 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DbApiHook for agentic database workflows."""
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+try:
+ from airflow.providers.common.ai.utils.sql_validation import validate_sql
as _validate_sql
+ from airflow.providers.common.sql.hooks.sql import DbApiHook
+except ImportError as e:
+ from airflow.providers.common.compat.sdk import
AirflowOptionalProviderFeatureException
+
+ raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+from airflow.providers.common.compat.sdk import BaseHook
+
+if TYPE_CHECKING:
+ from pydantic_ai._run_context import RunContext
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the four SQL tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+ "type": "object",
+ "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+ "type": "object",
+ "properties": {
+ "table_name": {"type": "string", "description": "Name of the table to
inspect."},
+ },
+ "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+ "type": "object",
+ "properties": {
+ "sql": {"type": "string", "description": "SQL query to execute."},
+ },
+ "required": ["sql"],
+}
+
+_CHECK_QUERY_SCHEMA: dict[str, Any] = {
+ "type": "object",
+ "properties": {
+ "sql": {"type": "string", "description": "SQL query to validate."},
+ },
+ "required": ["sql"],
+}
+
+
+class SQLToolset(AbstractToolset[Any]):
+ """
+ Curated toolset that gives an LLM agent safe access to a SQL database.
+
+ Provides four tools — ``list_tables``, ``get_schema``, ``query``, and
+ ``check_query`` — inspired by LangChain's ``SQLDatabaseToolkit`` pattern.
+
+ Uses a :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook` resolved
+ lazily from the given ``db_conn_id``.
+
+ :param db_conn_id: Airflow connection ID for the database.
+ :param allowed_tables: Restrict which tables the agent can discover via
+ ``list_tables`` and ``get_schema``. ``None`` (default) exposes all
tables.
+
+ .. note::
+ ``allowed_tables`` controls metadata visibility only. It does
**not**
+ parse or validate table references in SQL queries. An LLM can still
+ query tables outside this list if it guesses the name. For
query-level
+ restrictions, use database-level permissions (e.g. a read-only role
+ with grants limited to specific tables).
+
+ :param schema: Database schema/namespace for table listing and
introspection.
+ :param allow_writes: Allow data-modifying SQL (INSERT, UPDATE, DELETE,
etc.).
+ Default ``False`` — only SELECT-family statements are permitted.
+ :param max_rows: Maximum number of rows returned from the ``query`` tool.
+ Default ``50``.
+ """
+
+ def __init__(
+ self,
+ db_conn_id: str,
+ *,
+ allowed_tables: list[str] | None = None,
+ schema: str | None = None,
+ allow_writes: bool = False,
+ max_rows: int = 50,
+ ) -> None:
+ self._db_conn_id = db_conn_id
+ self._allowed_tables: frozenset[str] | None =
frozenset(allowed_tables) if allowed_tables else None
+ self._schema = schema
+ self._allow_writes = allow_writes
+ self._max_rows = max_rows
+ self._hook: DbApiHook | None = None
+
+ @property
+ def id(self) -> str:
+ return f"sql-{self._db_conn_id}"
+
+ # ------------------------------------------------------------------
+ # Lazy hook resolution
+ # ------------------------------------------------------------------
+
+ def _get_db_hook(self) -> DbApiHook:
+ if self._hook is None:
+ connection = BaseHook.get_connection(self._db_conn_id)
+ hook = connection.get_hook()
+ if not isinstance(hook, DbApiHook):
+ raise ValueError(
+ f"Connection {self._db_conn_id!r} does not provide a
DbApiHook. "
+ f"Got {type(hook).__name__}."
+ )
+ self._hook = hook
+ return self._hook
+
+ # ------------------------------------------------------------------
+ # AbstractToolset interface
+ # ------------------------------------------------------------------
+
+ async def get_tools(self, ctx: RunContext[Any]) -> dict[str,
ToolsetTool[Any]]:
+ tools: dict[str, ToolsetTool[Any]] = {}
+
+ for name, description, schema in (
+ ("list_tables", "List available table names in the database.",
_LIST_TABLES_SCHEMA),
+ ("get_schema", "Get column names and types for a table.",
_GET_SCHEMA_SCHEMA),
+ ("query", "Execute a SQL query and return rows as JSON.",
_QUERY_SCHEMA),
+ ("check_query", "Validate SQL syntax without executing it.",
_CHECK_QUERY_SCHEMA),
+ ):
+ # sequential=True because all tools use a shared DbApiHook with
+ # synchronous I/O — they must not run concurrently.
+ tool_def = ToolDefinition(
+ name=name,
+ description=description,
+ parameters_json_schema=schema,
+ sequential=True,
+ )
+ tools[name] = ToolsetTool(
+ toolset=self,
+ tool_def=tool_def,
+ max_retries=1,
+ args_validator=_PASSTHROUGH_VALIDATOR,
+ )
+ return tools
+
+ async def call_tool(
+ self,
+ name: str,
+ tool_args: dict[str, Any],
+ ctx: RunContext[Any],
+ tool: ToolsetTool[Any],
+ ) -> Any:
+ if name == "list_tables":
+ return self._list_tables()
+ if name == "get_schema":
+ return self._get_schema(tool_args["table_name"])
+ if name == "query":
+ return self._query(tool_args["sql"])
+ if name == "check_query":
+ return self._check_query(tool_args["sql"])
+ raise ValueError(f"Unknown tool: {name!r}")
+
+ # ------------------------------------------------------------------
+ # Tool implementations
+ # ------------------------------------------------------------------
+
+ def _list_tables(self) -> str:
+ hook = self._get_db_hook()
+ tables: list[str] = hook.inspector.get_table_names(schema=self._schema)
+ if self._allowed_tables is not None:
+ tables = [t for t in tables if t in self._allowed_tables]
+ return json.dumps(tables)
+
+ def _get_schema(self, table_name: str) -> str:
+ if self._allowed_tables is not None and table_name not in
self._allowed_tables:
+ return json.dumps({"error": f"Table {table_name!r} is not in the
allowed tables list."})
+ hook = self._get_db_hook()
+ columns = hook.get_table_schema(table_name, schema=self._schema)
+ return json.dumps(columns)
+
+ def _query(self, sql: str) -> str:
+ if not self._allow_writes:
+ _validate_sql(sql)
+
+ hook = self._get_db_hook()
+ rows = hook.get_records(sql)
+ # Fetch column names from cursor description.
+ col_names: list[str] | None = None
+ if hook.last_description:
+ col_names = [desc[0] for desc in hook.last_description]
+
+ result: list[dict[str, Any]] | list[list[Any]]
+ if rows and col_names:
+ result = [dict(zip(col_names, row)) for row in rows[:
self._max_rows]]
+ else:
+ result = [list(row) for row in (rows or [])[: self._max_rows]]
+
+ truncated = len(rows or []) > self._max_rows
+ output: dict[str, Any] = {"rows": result, "count": len(rows or [])}
+ if truncated:
+ output["truncated"] = True
+ output["max_rows"] = self._max_rows
+ return json.dumps(output, default=str)
+
+ def _check_query(self, sql: str) -> str:
+ try:
+ _validate_sql(sql)
+ return json.dumps({"valid": True})
+ except Exception as e:
+ return json.dumps({"valid": False, "error": str(e)})
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/__init__.py
b/providers/common/ai/tests/unit/common/ai/toolsets/__init__.py
new file mode 100644
index 00000000000..13a83393a91
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/__init__.py
@@ -0,0 +1,16 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_hook.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_hook.py
new file mode 100644
index 00000000000..2a40bdf0c4b
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_hook.py
@@ -0,0 +1,281 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.hook import (
+ HookToolset,
+ _build_json_schema_from_signature,
+ _extract_description,
+ _parse_param_docs,
+ _serialize_for_llm,
+)
+
+
class _FakeHook:
    """Fake hook for testing HookToolset introspection."""

    # NOTE: the method docstrings below (and the deliberate absence of one
    # on ``no_docstring``) are test fixtures — sibling tests assert on the
    # extracted descriptions and ":param:" docs, so do not edit them.

    def list_keys(self, bucket: str, prefix: str = "") -> list[str]:
        """List object keys in a bucket.

        :param bucket: Name of the S3 bucket.
        :param prefix: Key prefix to filter by.
        """
        return [f"{prefix}file1.txt", f"{prefix}file2.txt"]

    def read_file(self, key: str) -> str:
        """Read a file from storage."""
        return f"contents of {key}"

    def no_docstring(self, x: int) -> int:
        # Undocumented on purpose: exercises the description fallback path.
        return x * 2
+
+
class TestHookToolsetInit:
    """Constructor validation for HookToolset."""

    def test_requires_non_empty_allowed_methods(self):
        with pytest.raises(ValueError, match="non-empty"):
            HookToolset(MagicMock(), allowed_methods=[])

    def test_rejects_nonexistent_method(self):
        hook = _FakeHook()
        with pytest.raises(ValueError, match="has no method 'nonexistent'"):
            HookToolset(hook, allowed_methods=["nonexistent"])

    def test_rejects_non_callable_attribute(self):
        # MagicMock attributes are callable by default, so a real object with
        # a plain (non-callable) attribute is required to hit this branch.
        # (Previously a stray unused MagicMock was created here; removed.)
        class HookWithAttr:
            data = [1, 2, 3]

        with pytest.raises(ValueError, match="not callable"):
            HookToolset(HookWithAttr(), allowed_methods=["data"])

    def test_id_includes_hook_class_name(self):
        ts = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        assert "FakeHook" in ts.id
+
+
class TestHookToolsetGetTools:
    """Shapes of the tool definitions produced by ``HookToolset.get_tools``."""

    def _tools(self, toolset):
        # get_tools is async; drive it on a fresh event loop.
        return asyncio.run(toolset.get_tools(ctx=MagicMock()))

    def test_returns_tools_for_allowed_methods(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys", "read_file"])
        tool_map = self._tools(toolset)
        assert set(tool_map.keys()) == {"list_keys", "read_file"}

    def test_tool_definitions_have_correct_schemas(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        tool_def = self._tools(toolset)["list_keys"].tool_def

        schema = tool_def.parameters_json_schema
        assert tool_def.name == "list_keys"
        assert "bucket" in schema["properties"]
        assert "prefix" in schema["properties"]
        assert "bucket" in schema["required"]
        # prefix has a default, so it's not required
        assert "prefix" not in schema.get("required", [])

    def test_tool_name_prefix(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"], tool_name_prefix="s3_")
        assert "s3_list_keys" in self._tools(toolset)

    def test_description_from_docstring(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        tool_map = self._tools(toolset)
        assert tool_map["list_keys"].tool_def.description == "List object keys in a bucket."

    def test_description_fallback_for_no_docstring(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["no_docstring"])
        tool_map = self._tools(toolset)
        assert tool_map["no_docstring"].tool_def.description == "No docstring"

    def test_tools_are_sequential(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        assert self._tools(toolset)["list_keys"].tool_def.sequential is True

    def test_param_docs_enriched_in_schema(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        props = self._tools(toolset)["list_keys"].tool_def.parameters_json_schema["properties"]
        assert "description" in props["bucket"]
        assert "S3 bucket" in props["bucket"]["description"]
+
+
class TestHookToolsetCallTool:
    """Dispatch from tool name to the underlying hook method."""

    def test_dispatches_to_hook_method(self):
        toolset = HookToolset(_FakeHook(), allowed_methods=["list_keys"])
        tool_map = asyncio.run(toolset.get_tools(ctx=MagicMock()))

        output = asyncio.run(
            toolset.call_tool(
                "list_keys",
                {"bucket": "my-bucket", "prefix": "data/"},
                ctx=MagicMock(),
                tool=tool_map["list_keys"],
            )
        )
        assert "data/file1.txt" in output

    def test_dispatches_with_prefix(self):
        # Prefixed tool names must still route to the unprefixed hook method.
        toolset = HookToolset(_FakeHook(), allowed_methods=["read_file"], tool_name_prefix="storage_")
        tool_map = asyncio.run(toolset.get_tools(ctx=MagicMock()))

        output = asyncio.run(
            toolset.call_tool(
                "storage_read_file", {"key": "test.txt"}, ctx=MagicMock(), tool=tool_map["storage_read_file"]
            )
        )
        assert output == "contents of test.txt"
+
+
class TestBuildJsonSchemaFromSignature:
    """JSON Schema generation from plain Python signatures."""

    def test_basic_types(self):
        def sample(name: str, count: int, rate: float, active: bool):
            pass

        schema = _build_json_schema_from_signature(sample)
        props = schema["properties"]
        assert props["name"] == {"type": "string"}
        assert props["count"] == {"type": "integer"}
        assert props["rate"] == {"type": "number"}
        assert props["active"] == {"type": "boolean"}
        assert set(schema["required"]) == {"name", "count", "rate", "active"}

    def test_optional_params_not_required(self):
        def sample(name: str, prefix: str = ""):
            pass

        assert _build_json_schema_from_signature(sample)["required"] == ["name"]

    def test_list_type(self):
        def sample(items: list[str]):
            pass

        schema = _build_json_schema_from_signature(sample)
        assert schema["properties"]["items"] == {"type": "array", "items": {"type": "string"}}

    def test_no_annotation_defaults_to_string(self):
        # Unannotated parameters fall back to "string".
        def sample(x):
            pass

        assert _build_json_schema_from_signature(sample)["properties"]["x"] == {"type": "string"}

    def test_skips_self_and_cls(self):
        class Owner:
            def method(self, x: int):
                pass

        schema = _build_json_schema_from_signature(Owner().method)
        assert "self" not in schema["properties"]

    def test_skips_var_args(self):
        # *args / **kwargs cannot be represented in a fixed schema.
        def sample(x: int, *args, **kwargs):
            pass

        assert set(_build_json_schema_from_signature(sample)["properties"].keys()) == {"x"}
+
+
class TestExtractDescription:
    """Tests for the docstring → tool-description extraction helper.

    The inner ``fn`` docstrings are the fixtures under test; their exact
    wording and indentation are load-bearing — do not reformat them.
    """

    def test_first_paragraph(self):
        # Only the first paragraph (up to the first blank line) is kept.
        def fn():
            """First paragraph.

            Second paragraph with details.
            """

        assert _extract_description(fn) == "First paragraph."

    def test_multiline_first_paragraph(self):
        # Line breaks inside the first paragraph collapse to single spaces.
        def fn():
            """First line of
            the first paragraph.

            Second paragraph.
            """

        assert _extract_description(fn) == "First line of the first paragraph."

    def test_no_docstring_uses_method_name(self):
        # Fallback: humanize the name ("some_method" -> "Some method").
        def some_method():
            pass

        assert _extract_description(some_method) == "Some method"
+
+
class TestParseParamDocs:
    """Tests for ``_parse_param_docs`` across supported docstring styles.

    The ``docstring`` literals are fixtures; their text and indentation
    feed the parser directly, so keep them exactly as written.
    """

    def test_sphinx_style(self):
        # reST/Sphinx ":param name: ..." lines.
        docstring = """Do something.

        :param name: The name of the thing.
        :param count: How many items.
        """
        result = _parse_param_docs(docstring)
        assert result["name"] == "The name of the thing."
        assert result["count"] == "How many items."

    def test_google_style(self):
        # Google-style "Args:" section with indented entries.
        docstring = """Do something.

        Args:
            name: The name of the thing.
            count: How many items.
        """
        result = _parse_param_docs(docstring)
        assert result["name"] == "The name of the thing."
        assert result["count"] == "How many items."
+
+
class TestSerializeForLlm:
    """Serialization of hook return values into LLM-friendly strings."""

    def test_string_passthrough(self):
        assert _serialize_for_llm("hello") == "hello"

    def test_none_returns_null(self):
        assert _serialize_for_llm(None) == "null"

    def test_dict_to_json(self):
        assert _serialize_for_llm({"key": "value"}) == '{"key": "value"}'

    def test_list_to_json(self):
        assert _serialize_for_llm([1, 2, 3]) == "[1, 2, 3]"

    def test_non_serializable_falls_back_to_str(self):
        # Objects json can't handle are rendered via str().
        assert "object" in _serialize_for_llm(object())
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
new file mode 100644
index 00000000000..0573acd2a77
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_sql.py
@@ -0,0 +1,233 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import MagicMock, PropertyMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.sql import SQLToolset
+from airflow.providers.common.ai.utils.sql_validation import SQLSafetyError
+
+
def _make_mock_db_hook(
    table_names: list[str] | None = None,
    table_schema: list[dict[str, str]] | None = None,
    records: list[tuple] | None = None,
    last_description: list[tuple] | None = None,
):
    """Create a mock DbApiHook with sensible defaults.

    Defaults are applied only when an argument is ``None``, so callers can
    pass empty lists (e.g. ``records=[]`` or ``last_description=[]``) to
    simulate empty results — the previous ``records or default`` truthiness
    check silently replaced explicitly passed empty values.
    """
    from airflow.providers.common.sql.hooks.sql import DbApiHook

    if table_names is None:
        table_names = ["users", "orders"]
    if table_schema is None:
        table_schema = [
            {"name": "id", "type": "INTEGER"},
            {"name": "name", "type": "VARCHAR"},
        ]
    if records is None:
        records = [(1, "Alice"), (2, "Bob")]
    if last_description is None:
        last_description = [("id",), ("name",)]

    mock = MagicMock(spec=DbApiHook)
    mock.inspector = MagicMock()
    mock.inspector.get_table_names.return_value = table_names
    mock.get_table_schema.return_value = table_schema
    mock.get_records.return_value = records
    # last_description is a property on DbApiHook, so it must be attached
    # to the mock's type, not the instance.
    type(mock).last_description = PropertyMock(return_value=last_description)
    return mock
+
+
class TestSQLToolsetInit:
    """Constructor behavior of SQLToolset."""

    def test_id_includes_conn_id(self):
        toolset = SQLToolset("my_pg")
        assert toolset.id == "sql-my_pg"
+
+
class TestSQLToolsetGetTools:
    """The fixed four-tool surface exposed by SQLToolset."""

    def test_returns_four_tools(self):
        toolset = SQLToolset("pg_default")
        tool_map = asyncio.run(toolset.get_tools(ctx=MagicMock()))
        assert set(tool_map.keys()) == {"list_tables", "get_schema", "query", "check_query"}

    def test_tool_definitions_have_descriptions(self):
        toolset = SQLToolset("pg_default")
        tool_map = asyncio.run(toolset.get_tools(ctx=MagicMock()))
        assert all(tool.tool_def.description for tool in tool_map.values())
+
+
class TestSQLToolsetListTables:
    """The list_tables tool, with and without an allow-list."""

    @staticmethod
    def _call(toolset):
        return asyncio.run(toolset.call_tool("list_tables", {}, ctx=MagicMock(), tool=MagicMock()))

    def test_returns_all_tables(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook(table_names=["users", "orders", "products"])

        assert json.loads(self._call(toolset)) == ["users", "orders", "products"]

    def test_filters_by_allowed_tables(self):
        toolset = SQLToolset("pg_default", allowed_tables=["orders"])
        toolset._hook = _make_mock_db_hook(table_names=["users", "orders", "products"])

        assert json.loads(self._call(toolset)) == ["orders"]
+
+
class TestSQLToolsetGetSchema:
    """The get_schema tool, including allow-list enforcement."""

    def test_returns_column_info(self):
        toolset = SQLToolset("pg_default")
        mock_hook = _make_mock_db_hook()
        toolset._hook = mock_hook

        raw = asyncio.run(
            toolset.call_tool("get_schema", {"table_name": "users"}, ctx=MagicMock(), tool=MagicMock())
        )
        assert json.loads(raw) == [{"name": "id", "type": "INTEGER"}, {"name": "name", "type": "VARCHAR"}]
        mock_hook.get_table_schema.assert_called_once_with("users", schema=None)

    def test_blocks_table_not_in_allowed_list(self):
        toolset = SQLToolset("pg_default", allowed_tables=["orders"])
        toolset._hook = _make_mock_db_hook()

        raw = asyncio.run(
            toolset.call_tool("get_schema", {"table_name": "secrets"}, ctx=MagicMock(), tool=MagicMock())
        )
        payload = json.loads(raw)
        assert "error" in payload
        assert "secrets" in payload["error"]
+
+
class TestSQLToolsetQuery:
    """The query tool: row rendering, truncation, and write protection."""

    @staticmethod
    def _run_query(toolset, sql):
        return asyncio.run(toolset.call_tool("query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))

    def test_returns_rows_as_json(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook(
            records=[(1, "Alice"), (2, "Bob")],
            last_description=[("id",), ("name",)],
        )

        payload = json.loads(self._run_query(toolset, "SELECT id, name FROM users"))
        assert payload["rows"] == [{"id": 1, "name": "Alice"}, {"id": 2, "name": "Bob"}]
        assert payload["count"] == 2

    def test_truncates_at_max_rows(self):
        toolset = SQLToolset("pg_default", max_rows=1)
        toolset._hook = _make_mock_db_hook(
            records=[(1, "Alice"), (2, "Bob"), (3, "Charlie")],
            last_description=[("id",), ("name",)],
        )

        payload = json.loads(self._run_query(toolset, "SELECT id, name FROM users"))
        assert len(payload["rows"]) == 1
        assert payload["truncated"] is True
        assert payload["count"] == 3

    def test_blocks_unsafe_sql_by_default(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook()

        with pytest.raises(SQLSafetyError, match="not allowed"):
            self._run_query(toolset, "DROP TABLE users")

    def test_allows_writes_when_enabled(self):
        toolset = SQLToolset("pg_default", allow_writes=True)
        toolset._hook = _make_mock_db_hook(
            records=[(1,)],
            last_description=[("count",)],
        )

        # Should not raise even with INSERT; the mock doesn't actually
        # execute the statement, it only returns the canned records.
        payload = json.loads(self._run_query(toolset, "INSERT INTO users VALUES (3, 'Eve')"))
        assert "rows" in payload
+
+
class TestSQLToolsetCheckQuery:
    """The check_query tool: validation without execution."""

    @staticmethod
    def _check(toolset, sql):
        raw = asyncio.run(toolset.call_tool("check_query", {"sql": sql}, ctx=MagicMock(), tool=MagicMock()))
        return json.loads(raw)

    def test_valid_select(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook()

        assert self._check(toolset, "SELECT 1")["valid"] is True

    def test_invalid_sql(self):
        toolset = SQLToolset("pg_default")
        toolset._hook = _make_mock_db_hook()

        verdict = self._check(toolset, "DROP TABLE users")
        assert verdict["valid"] is False
        assert "error" in verdict
+
+
class TestSQLToolsetHookResolution:
    """Lazy resolution and caching of the underlying DbApiHook."""

    @staticmethod
    def _wire(mock_base_hook, hook):
        # Route BaseHook.get_connection().get_hook() to the given hook.
        conn = MagicMock(spec=["get_hook"])
        conn.get_hook.return_value = hook
        mock_base_hook.get_connection.return_value = conn

    @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
    def test_lazy_resolves_db_hook(self, mock_base_hook):
        from airflow.providers.common.sql.hooks.sql import DbApiHook

        db_hook = MagicMock(spec=DbApiHook)
        self._wire(mock_base_hook, db_hook)

        toolset = SQLToolset("pg_default")
        assert toolset._get_db_hook() is db_hook
        mock_base_hook.get_connection.assert_called_once_with("pg_default")

    @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
    def test_raises_for_non_dbapi_hook(self, mock_base_hook):
        self._wire(mock_base_hook, MagicMock())  # Not a DbApiHook

        toolset = SQLToolset("bad_conn")
        with pytest.raises(ValueError, match="does not provide a DbApiHook"):
            toolset._get_db_hook()

    @patch("airflow.providers.common.ai.toolsets.sql.BaseHook", autospec=True)
    def test_caches_hook_after_first_resolution(self, mock_base_hook):
        from airflow.providers.common.sql.hooks.sql import DbApiHook

        self._wire(mock_base_hook, MagicMock(spec=DbApiHook))

        toolset = SQLToolset("pg_default")
        toolset._get_db_hook()
        toolset._get_db_hook()

        # Only called once because result is cached.
        mock_base_hook.get_connection.assert_called_once()