This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 318bb25ffc7 AIP-99: Add `AgentOperator` and `@task.agent` for agentic LLM workflows (#62825)
318bb25ffc7 is described below
commit 318bb25ffc78aa2fbe921042417ece99b83f3665
Author: Kaxil Naik <[email protected]>
AuthorDate: Wed Mar 4 12:52:49 2026 +0000
AIP-99: Add `AgentOperator` and `@task.agent` for agentic LLM workflows (#62825)
Docs for HookToolset (generic hook→tools adapter) and SQLToolset
(curated 4-tool DB toolset). Includes defense layers table,
allowed_tables limitation, HookToolset guidelines, recommended
configurations, and production checklist.
AgentOperator runs a pydantic-ai Agent with tools and multi-turn
reasoning. The operator builds the agent from an Airflow connection
(llm_conn_id) and optional toolsets (HookToolset, SQLToolset, etc.),
keeping credentials in the secret backend.
@task.agent decorator wraps AgentOperator — the decorated function
returns the prompt string; all other params are passed through.
Includes docs with security section (defense layers, allowed_tables
limitation, HookToolset guidelines, production checklist), example
DAGs, and unit tests.
Help users choose between LLMOperator, LLMBranchOperator,
LLMSQLQueryOperator, and AgentOperator with a comparison table
and short descriptions of when to use each.
closes #62826
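The prompt contract described in the message above (the decorated callable must return a non-empty string, which becomes the agent's prompt) can be sketched in isolation. This is a minimal standalone illustration of the validation the decorator performs at execution time, not the provider code itself; `resolve_prompt` is a hypothetical helper name:

```python
def resolve_prompt(python_callable, *args, **kwargs):
    """Call the user's decorated function and validate its return value,
    mirroring the check @task.agent performs before running the agent."""
    prompt = python_callable(*args, **kwargs)
    # The agent needs a usable prompt: reject non-strings and blank strings.
    if not isinstance(prompt, str) or not prompt.strip():
        raise TypeError(
            "The returned value from the @task.agent callable must be a non-empty string."
        )
    return prompt
```

Anything else (connection resolution, toolset wiring, the multi-turn run loop) is handled by AgentOperator, which the decorator delegates to after this check.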
---
docs/spelling_wordlist.txt | 1 +
providers/common/ai/docs/operators/agent.rst | 138 ++++++++++++++++
providers/common/ai/docs/operators/index.rst | 40 +++++
providers/common/ai/provider.yaml | 4 +
.../providers/common/ai/decorators/agent.py | 123 ++++++++++++++
.../common/ai/example_dags/example_agent.py | 178 +++++++++++++++++++++
.../providers/common/ai/get_provider_info.py | 3 +
.../airflow/providers/common/ai/operators/agent.py | 108 +++++++++++++
.../tests/unit/common/ai/decorators/test_agent.py | 131 +++++++++++++++
.../tests/unit/common/ai/operators/test_agent.py | 138 ++++++++++++++++
10 files changed, 864 insertions(+)
diff --git a/docs/spelling_wordlist.txt b/docs/spelling_wordlist.txt
index 81fd9cc567e..2cbae4da50e 100644
--- a/docs/spelling_wordlist.txt
+++ b/docs/spelling_wordlist.txt
@@ -1760,6 +1760,7 @@ stackdriver
stacklevel
stacktrace
starttls
+stateful
StatefulSet
StatefulSets
statics
diff --git a/providers/common/ai/docs/operators/agent.rst b/providers/common/ai/docs/operators/agent.rst
new file mode 100644
index 00000000000..bafd2114589
--- /dev/null
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -0,0 +1,138 @@
+ .. Licensed to the Apache Software Foundation (ASF) under one
+ or more contributor license agreements. See the NOTICE file
+ distributed with this work for additional information
+ regarding copyright ownership. The ASF licenses this file
+ to you under the Apache License, Version 2.0 (the
+ "License"); you may not use this file except in compliance
+ with the License. You may obtain a copy of the License at
+
+ .. http://www.apache.org/licenses/LICENSE-2.0
+
+ .. Unless required by applicable law or agreed to in writing,
+ software distributed under the License is distributed on an
+ "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+ KIND, either express or implied. See the License for the
+ specific language governing permissions and limitations
+ under the License.
+
+.. _howto/operator:agent:
+
+``AgentOperator`` & ``@task.agent``
+===================================
+
+Use :class:`~airflow.providers.common.ai.operators.agent.AgentOperator` or
+the ``@task.agent`` decorator to run an LLM agent with **tools** — the agent
+reasons about the prompt, calls tools (database queries, API calls, etc.) in
+a multi-turn loop, and returns a final answer.
+
+This is different from
+:class:`~airflow.providers.common.ai.operators.llm.LLMOperator`, which sends
+a single prompt and returns the output. ``AgentOperator`` manages a stateful
+tool-call loop where the LLM decides which tools to call and when to stop.
+
+.. seealso::
+ :ref:`Connection configuration <howto/connection:pydantic_ai>`
+
+
+SQL Agent
+---------
+
+The most common pattern: give an agent access to a database so it can answer
+questions by writing and executing SQL.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+ :language: python
+ :start-after: [START howto_operator_agent_sql]
+ :end-before: [END howto_operator_agent_sql]
+
+The ``SQLToolset`` provides four tools to the agent:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 20 50
+
+ * - Tool
+ - Description
+ * - ``list_tables``
+ - Lists available table names (filtered by ``allowed_tables`` if set)
+ * - ``get_schema``
+ - Returns column names and types for a table
+ * - ``query``
+ - Executes a SQL query and returns rows as JSON
+ * - ``check_query``
+ - Validates SQL syntax without executing it
+
+
+Hook-based Tools
+----------------
+
+Wrap any Airflow Hook's methods as agent tools using ``HookToolset``. Only
+methods you explicitly list are exposed — there is no auto-discovery.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+ :language: python
+ :start-after: [START howto_operator_agent_hook]
+ :end-before: [END howto_operator_agent_hook]
+
+
+TaskFlow Decorator
+------------------
+
+The ``@task.agent`` decorator wraps ``AgentOperator``. The function returns
+the prompt string; all other parameters are passed to the operator.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+ :language: python
+ :start-after: [START howto_decorator_agent]
+ :end-before: [END howto_decorator_agent]
+
+
+Structured Output
+-----------------
+
+Set ``output_type`` to a Pydantic ``BaseModel`` subclass to get structured
+data back. The result is serialized via ``model_dump()`` for XCom.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+ :language: python
+ :start-after: [START howto_decorator_agent_structured]
+ :end-before: [END howto_decorator_agent_structured]
+
+
+Chaining with Downstream Tasks
+-------------------------------
+
+The agent's output is pushed to XCom like any other operator, so downstream
+tasks can consume it.
+
+.. exampleinclude:: /../../ai/src/airflow/providers/common/ai/example_dags/example_agent.py
+ :language: python
+ :start-after: [START howto_agent_chain]
+ :end-before: [END howto_agent_chain]
+
+
+Parameters
+----------
+
+- ``prompt``: The prompt to send to the agent (operator) or the return value
+ of the decorated function (decorator).
+- ``llm_conn_id``: Airflow connection ID for the LLM provider.
+- ``model_id``: Model identifier (e.g. ``"openai:gpt-5"``). Overrides the
+ connection's extra field.
+- ``system_prompt``: System-level instructions for the agent. Supports Jinja
+ templating.
+- ``output_type``: Expected output type (default: ``str``). Set to a Pydantic
+ ``BaseModel`` for structured output.
+- ``toolsets``: List of pydantic-ai toolsets (``SQLToolset``, ``HookToolset``,
+ etc.).
+- ``agent_params``: Additional keyword arguments passed to the pydantic-ai
+ ``Agent`` constructor (e.g. ``retries``, ``model_settings``).
+
+
+Security
+--------
+
+.. seealso::
+ :ref:`Toolsets — Security <howto/toolsets>` for defense layers,
+ ``allowed_tables`` limitations, ``HookToolset`` guidelines, recommended
+ configurations, and the production checklist.
diff --git a/providers/common/ai/docs/operators/index.rst b/providers/common/ai/docs/operators/index.rst
index 5ca15266335..e961931925a 100644
--- a/providers/common/ai/docs/operators/index.rst
+++ b/providers/common/ai/docs/operators/index.rst
@@ -18,6 +18,46 @@
Common AI Operators
===================
+Choosing the right operator
+---------------------------
+
+The common-ai provider ships four operators (and matching ``@task`` decorators). Use this table
+to pick the one that fits your use case:
+
+.. list-table::
+ :header-rows: 1
+ :widths: 40 30 30
+
+ * - Need
+ - Operator
+ - Decorator
+ * - Single prompt → text or structured output
+ - :class:`~airflow.providers.common.ai.operators.llm.LLMOperator`
+ - ``@task.llm``
+ * - LLM picks which downstream task runs
+ - :class:`~airflow.providers.common.ai.operators.llm_branch.LLMBranchOperator`
+ - ``@task.llm_branch``
+ * - Natural-language → SQL generation (no execution)
+ - :class:`~airflow.providers.common.ai.operators.llm_sql.LLMSQLQueryOperator`
+ - ``@task.llm_sql``
+ * - Multi-turn reasoning with tools (DB queries, API calls, etc.)
+ - :class:`~airflow.providers.common.ai.operators.agent.AgentOperator`
+ - ``@task.agent``
+
+**LLMOperator / @task.llm** — stateless, single-turn calls. Use this for classification,
+summarization, extraction, or any prompt that produces one response. Supports structured output
+via a ``response_format`` Pydantic model.
+
+**AgentOperator / @task.agent** — multi-turn tool-calling loop. The model decides which tools to
+invoke and when to stop. Use this when the LLM needs to take actions (query databases, call APIs,
+read files) to produce its answer. You configure available tools through ``toolsets``.
+
+AgentOperator *works* without toolsets — pydantic-ai supports tool-less agents for multi-turn
+reasoning — but if you don't need tools, ``LLMOperator`` is simpler and more explicit.
+
+Operator guides
+---------------
+
.. toctree::
:maxdepth: 1
:glob:
diff --git a/providers/common/ai/provider.yaml b/providers/common/ai/provider.yaml
index 7e2cc85bf19..24507cd9277 100644
--- a/providers/common/ai/provider.yaml
+++ b/providers/common/ai/provider.yaml
@@ -32,6 +32,7 @@ integrations:
- integration-name: Common AI
external-doc-url: https://airflow.apache.org/docs/apache-airflow-providers-common-ai/
how-to-guide:
+ - /docs/apache-airflow-providers-common-ai/operators/agent.rst
- /docs/apache-airflow-providers-common-ai/operators/llm.rst
- /docs/apache-airflow-providers-common-ai/operators/llm_branch.rst
- /docs/apache-airflow-providers-common-ai/operators/llm_sql.rst
@@ -70,12 +71,15 @@ connection-types:
operators:
- integration-name: Common AI
python-modules:
+ - airflow.providers.common.ai.operators.agent
- airflow.providers.common.ai.operators.llm
- airflow.providers.common.ai.operators.llm_branch
- airflow.providers.common.ai.operators.llm_sql
- airflow.providers.common.ai.operators.llm_schema_compare
task-decorators:
+ - class-name: airflow.providers.common.ai.decorators.agent.agent_task
+ name: agent
- class-name: airflow.providers.common.ai.decorators.llm.llm_task
name: llm
- class-name: airflow.providers.common.ai.decorators.llm_branch.llm_branch_task
diff --git a/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
new file mode 100644
index 00000000000..40c55f630c0
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/decorators/agent.py
@@ -0,0 +1,123 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""
+TaskFlow decorator for agentic LLM workflows.
+
+The user writes a function that **returns the prompt string**. The decorator
+handles hook creation, agent configuration with toolsets, multi-turn reasoning,
+and output serialization.
+"""
+
+from __future__ import annotations
+
+from collections.abc import Callable, Collection, Mapping, Sequence
+from typing import TYPE_CHECKING, Any, ClassVar
+
+from airflow.providers.common.ai.operators.agent import AgentOperator
+from airflow.providers.common.compat.sdk import (
+ DecoratedOperator,
+ TaskDecorator,
+ context_merge,
+ determine_kwargs,
+ task_decorator_factory,
+)
+from airflow.sdk.definitions._internal.types import SET_DURING_EXECUTION
+
+if TYPE_CHECKING:
+ from airflow.sdk import Context
+
+
+class _AgentDecoratedOperator(DecoratedOperator, AgentOperator):
+ """
+ Wraps a callable that returns a prompt for an agentic LLM workflow.
+
+ The user function is called at execution time to produce the prompt string.
+ All other parameters (``llm_conn_id``, ``toolsets``, ``system_prompt``, etc.)
+ are passed through to :class:`~airflow.providers.common.ai.operators.agent.AgentOperator`.
+
+ :param python_callable: A reference to a callable that returns the prompt string.
+ :param op_args: Positional arguments for the callable.
+ :param op_kwargs: Keyword arguments for the callable.
+ """
+
+ template_fields: Sequence[str] = (
+ *DecoratedOperator.template_fields,
+ *AgentOperator.template_fields,
+ )
+ template_fields_renderers: ClassVar[dict[str, str]] = {
+ **DecoratedOperator.template_fields_renderers,
+ }
+
+ custom_operator_name: str = "@task.agent"
+
+ def __init__(
+ self,
+ *,
+ python_callable: Callable,
+ op_args: Collection[Any] | None = None,
+ op_kwargs: Mapping[str, Any] | None = None,
+ **kwargs,
+ ) -> None:
+ super().__init__(
+ python_callable=python_callable,
+ op_args=op_args,
+ op_kwargs=op_kwargs,
+ prompt=SET_DURING_EXECUTION,
+ **kwargs,
+ )
+
+ def execute(self, context: Context) -> Any:
+ context_merge(context, self.op_kwargs)
+ kwargs = determine_kwargs(self.python_callable, self.op_args, context)
+
+ self.prompt = self.python_callable(*self.op_args, **kwargs)
+
+ if not isinstance(self.prompt, str) or not self.prompt.strip():
+ raise TypeError("The returned value from the @task.agent callable must be a non-empty string.")
+
+ self.render_template_fields(context)
+ return AgentOperator.execute(self, context)
+
+
+def agent_task(
+ python_callable: Callable | None = None,
+ **kwargs,
+) -> TaskDecorator:
+ """
+ Wrap a function that returns a prompt into an agentic LLM task.
+
+ The function body constructs the prompt (can use Airflow context, XCom, etc.).
+ The decorator handles hook creation, agent configuration with toolsets,
+ multi-turn reasoning, and output serialization.
+
+ Usage::
+
+ @task.agent(
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="You are a data analyst.",
+ toolsets=[SQLToolset(db_conn_id="postgres_default")],
+ )
+ def analyze(question: str):
+ return f"Answer: {question}"
+
+ :param python_callable: Function to decorate.
+ """
+ return task_decorator_factory(
+ python_callable=python_callable,
+ decorated_operator_class=_AgentDecoratedOperator,
+ **kwargs,
+ )
diff --git a/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
new file mode 100644
index 00000000000..985d1019818
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/example_dags/example_agent.py
@@ -0,0 +1,178 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Example DAGs demonstrating AgentOperator, @task.agent, and toolsets."""
+
+from __future__ import annotations
+
+from airflow.providers.common.ai.operators.agent import AgentOperator
+from airflow.providers.common.ai.toolsets.hook import HookToolset
+from airflow.providers.common.ai.toolsets.sql import SQLToolset
+from airflow.providers.common.compat.sdk import dag, task
+
+# ---------------------------------------------------------------------------
+# 1. SQL Agent: answer a question using database tools
+# ---------------------------------------------------------------------------
+
+
+# [START howto_operator_agent_sql]
+@dag
+def example_agent_operator_sql():
+ AgentOperator(
+ task_id="analyst",
+ prompt="What are the top 5 customers by order count?",
+ llm_conn_id="pydantic_ai_default",
+ system_prompt=(
+ "You are a SQL analyst. Use the available tools to explore "
+ "the schema and answer the question with data."
+ ),
+ toolsets=[
+ SQLToolset(
+ db_conn_id="postgres_default",
+ allowed_tables=["customers", "orders"],
+ max_rows=20,
+ )
+ ],
+ )
+
+
+# [END howto_operator_agent_sql]
+
+example_agent_operator_sql()
+
+
+# ---------------------------------------------------------------------------
+# 2. Hook-based tools: wrap an existing hook for the agent
+# ---------------------------------------------------------------------------
+
+
+# [START howto_operator_agent_hook]
+@dag
+def example_agent_operator_hook():
+ from airflow.providers.http.hooks.http import HttpHook
+
+ http_hook = HttpHook(http_conn_id="my_api")
+
+ AgentOperator(
+ task_id="api_explorer",
+ prompt="What endpoints are available and what does /status return?",
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="You are an API explorer. Use the tools to discover and call endpoints.",
+ toolsets=[
+ HookToolset(
+ http_hook,
+ allowed_methods=["run"],
+ tool_name_prefix="http_",
+ )
+ ],
+ )
+
+
+# [END howto_operator_agent_hook]
+
+example_agent_operator_hook()
+
+
+# ---------------------------------------------------------------------------
+# 3. @task.agent decorator with dynamic prompt
+# ---------------------------------------------------------------------------
+
+
+# [START howto_decorator_agent]
+@dag
+def example_agent_decorator():
+ @task.agent(
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="You are a data analyst. Use tools to answer questions.",
+ toolsets=[
+ SQLToolset(
+ db_conn_id="postgres_default",
+ allowed_tables=["orders"],
+ )
+ ],
+ )
+ def analyze(question: str):
+ return f"Answer this question about our orders data: {question}"
+
+ analyze("What was our total revenue last month?")
+
+
+# [END howto_decorator_agent]
+
+example_agent_decorator()
+
+
+# ---------------------------------------------------------------------------
+# 4. Structured output — agent returns a Pydantic model
+# ---------------------------------------------------------------------------
+
+
+# [START howto_decorator_agent_structured]
+@dag
+def example_agent_structured_output():
+ from pydantic import BaseModel
+
+ class Analysis(BaseModel):
+ summary: str
+ top_items: list[str]
+ row_count: int
+
+ @task.agent(
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="You are a data analyst. Return structured results.",
+ output_type=Analysis,
+ toolsets=[SQLToolset(db_conn_id="postgres_default")],
+ )
+ def analyze(question: str):
+ return f"Analyze: {question}"
+
+ analyze("What are the trending products this week?")
+
+
+# [END howto_decorator_agent_structured]
+
+example_agent_structured_output()
+
+
+# ---------------------------------------------------------------------------
+# 5. Chaining: agent output feeds into downstream tasks via XCom
+# ---------------------------------------------------------------------------
+
+
+# [START howto_agent_chain]
+@dag
+def example_agent_chain():
+ @task.agent(
+ llm_conn_id="pydantic_ai_default",
+ system_prompt="You are a SQL analyst.",
+ toolsets=[SQLToolset(db_conn_id="postgres_default", allowed_tables=["orders"])],
+ )
+ def investigate(question: str):
+ return f"Investigate: {question}"
+
+ @task
+ def send_report(analysis: str):
+ """Send the agent's analysis to a downstream system."""
+ print(f"Report: {analysis}")
+ return analysis
+
+ result = investigate("Summarize order trends for last quarter")
+ send_report(result)
+
+
+# [END howto_agent_chain]
+
+example_agent_chain()
diff --git a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
index 77a3c1b86c0..e5113d7fb3d 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/get_provider_info.py
@@ -31,6 +31,7 @@ def get_provider_info():
"integration-name": "Common AI",
"external-doc-url":
"https://airflow.apache.org/docs/apache-airflow-providers-common-ai/",
"how-to-guide": [
+ "/docs/apache-airflow-providers-common-ai/operators/agent.rst",
"/docs/apache-airflow-providers-common-ai/operators/llm.rst",
"/docs/apache-airflow-providers-common-ai/operators/llm_branch.rst",
"/docs/apache-airflow-providers-common-ai/operators/llm_sql.rst",
@@ -72,6 +73,7 @@ def get_provider_info():
{
"integration-name": "Common AI",
"python-modules": [
+ "airflow.providers.common.ai.operators.agent",
"airflow.providers.common.ai.operators.llm",
"airflow.providers.common.ai.operators.llm_branch",
"airflow.providers.common.ai.operators.llm_sql",
@@ -80,6 +82,7 @@ def get_provider_info():
}
],
"task-decorators": [
+            {"class-name": "airflow.providers.common.ai.decorators.agent.agent_task", "name": "agent"},
{"class-name":
"airflow.providers.common.ai.decorators.llm.llm_task", "name": "llm"},
{
"class-name":
"airflow.providers.common.ai.decorators.llm_branch.llm_branch_task",
diff --git a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
new file mode 100644
index 00000000000..ca4d61c86ec
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -0,0 +1,108 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for running pydantic-ai agents with tools and multi-turn
reasoning."""
+
+from __future__ import annotations
+
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any
+
+from pydantic import BaseModel
+
+from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+from airflow.providers.common.compat.sdk import BaseOperator
+
+if TYPE_CHECKING:
+ from pydantic_ai import Agent
+ from pydantic_ai.toolsets.abstract import AbstractToolset
+
+ from airflow.sdk import Context
+
+
+class AgentOperator(BaseOperator):
+ """
+ Run a pydantic-ai Agent with tools and multi-turn reasoning.
+
+ Provide ``llm_conn_id`` and optional ``toolsets`` to let the operator build
+ and run the agent. The agent reasons about the prompt, calls tools in a
+ multi-turn loop, and returns a final answer.
+
+ :param prompt: The prompt to send to the agent.
+ :param llm_conn_id: Connection ID for the LLM provider.
+ :param model_id: Model identifier (e.g. ``"openai:gpt-5"``).
+ Overrides the model stored in the connection's extra field.
+ :param system_prompt: System-level instructions for the agent.
+ :param output_type: Expected output type. Default ``str``. Set to a Pydantic
+ ``BaseModel`` subclass for structured output.
+ :param toolsets: List of pydantic-ai toolsets the agent can use
+ (e.g. ``SQLToolset``, ``HookToolset``).
+ :param agent_params: Additional keyword arguments passed to the pydantic-ai
+ ``Agent`` constructor (e.g. ``retries``, ``model_settings``).
+ """
+
+ template_fields: Sequence[str] = (
+ "prompt",
+ "llm_conn_id",
+ "model_id",
+ "system_prompt",
+ "agent_params",
+ )
+
+ def __init__(
+ self,
+ *,
+ prompt: str,
+ llm_conn_id: str,
+ model_id: str | None = None,
+ system_prompt: str = "",
+ output_type: type = str,
+ toolsets: list[AbstractToolset] | None = None,
+ agent_params: dict[str, Any] | None = None,
+ **kwargs: Any,
+ ) -> None:
+ super().__init__(**kwargs)
+
+ self.prompt = prompt
+ self.llm_conn_id = llm_conn_id
+ self.model_id = model_id
+ self.system_prompt = system_prompt
+ self.output_type = output_type
+ self.toolsets = toolsets
+ self.agent_params = agent_params or {}
+
+ @cached_property
+ def llm_hook(self) -> PydanticAIHook:
+ """Return PydanticAIHook for the configured LLM connection."""
+ return PydanticAIHook(llm_conn_id=self.llm_conn_id, model_id=self.model_id)
+
+ def execute(self, context: Context) -> Any:
+ extra_kwargs = dict(self.agent_params)
+ if self.toolsets:
+ extra_kwargs["toolsets"] = self.toolsets
+ agent: Agent[None, Any] = self.llm_hook.create_agent(
+ output_type=self.output_type,
+ instructions=self.system_prompt,
+ **extra_kwargs,
+ )
+
+ result = agent.run_sync(self.prompt)
+ output = result.output
+
+ if isinstance(output, BaseModel):
+ return output.model_dump()
+ return output
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
new file mode 100644
index 00000000000..99e2e35aafc
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
@@ -0,0 +1,131 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.decorators.agent import _AgentDecoratedOperator
+
+
+class TestAgentDecoratedOperator:
+ def test_custom_operator_name(self):
+ assert _AgentDecoratedOperator.custom_operator_name == "@task.agent"
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_execute_calls_callable_and_returns_output(self, mock_hook_cls):
+ """The callable's return value becomes the agent prompt."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = "The top customer is Acme Corp."
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+ def my_prompt():
+ return "Who is our top customer?"
+
+ op = _AgentDecoratedOperator(task_id="test", python_callable=my_prompt, llm_conn_id="my_llm")
+ result = op.execute(context={})
+
+ assert result == "The top customer is Acme Corp."
+ assert op.prompt == "Who is our top customer?"
+ mock_agent.run_sync.assert_called_once_with("Who is our top customer?")
+
+ @pytest.mark.parametrize(
+ "return_value",
+ [42, "", " ", None],
+ ids=["non-string", "empty", "whitespace", "none"],
+ )
+ def test_execute_raises_on_invalid_prompt(self, return_value):
+ """TypeError when the callable returns a non-string or blank string."""
+ op = _AgentDecoratedOperator(
+ task_id="test",
+ python_callable=lambda: return_value,
+ llm_conn_id="my_llm",
+ )
+ with pytest.raises(TypeError, match="non-empty string"):
+ op.execute(context={})
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
+ """op_kwargs are resolved by the callable to build the prompt."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = "done"
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+ def my_prompt(topic):
+ return f"Analyze {topic}"
+
+ op = _AgentDecoratedOperator(
+ task_id="test",
+ python_callable=my_prompt,
+ llm_conn_id="my_llm",
+ op_kwargs={"topic": "revenue trends"},
+ )
+ op.execute(context={"task_instance": MagicMock()})
+
+ assert op.prompt == "Analyze revenue trends"
+ mock_agent.run_sync.assert_called_once_with("Analyze revenue trends")
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_execute_passes_toolsets_through(self, mock_hook_cls):
+ """Toolsets passed to the decorator are forwarded to the agent."""
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = "result"
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+ mock_toolset = MagicMock()
+
+ op = _AgentDecoratedOperator(
+ task_id="test",
+ python_callable=lambda: "Do something",
+ llm_conn_id="my_llm",
+ toolsets=[mock_toolset],
+ )
+ op.execute(context={})
+
+ create_call = mock_hook_cls.return_value.create_agent.call_args
+ assert create_call[1]["toolsets"] == [mock_toolset]
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_execute_structured_output(self, mock_hook_cls):
+ """BaseModel output is serialized with model_dump."""
+ from pydantic import BaseModel
+
+ class Summary(BaseModel):
+ text: str
+
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = Summary(text="Great results")
+ mock_agent.run_sync.return_value = mock_result
+ mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+ op = _AgentDecoratedOperator(
+ task_id="test",
+ python_callable=lambda: "Summarize",
+ llm_conn_id="my_llm",
+ output_type=Summary,
+ )
+ result = op.execute(context={})
+
+ assert result == {"text": "Great results"}
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
new file mode 100644
index 00000000000..3d949854189
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -0,0 +1,138 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock, patch
+
+import pytest
+from pydantic import BaseModel
+
+from airflow.providers.common.ai.operators.agent import AgentOperator
+
+
+def _make_mock_agent(output):
+ """Create a mock agent that returns the given output."""
+ mock_result = MagicMock(spec=["output"])
+ mock_result.output = output
+ mock_agent = MagicMock(spec=["run_sync"])
+ mock_agent.run_sync.return_value = mock_result
+ return mock_agent
+
+
+class TestAgentOperatorValidation:
+    def test_requires_llm_conn_id(self):
+        with pytest.raises(TypeError):
+            AgentOperator(task_id="test", prompt="hello")
+
+
+class TestAgentOperatorTemplateFields:
+    def test_template_fields(self):
+        expected = {"prompt", "llm_conn_id", "model_id", "system_prompt", "agent_params"}
+        assert set(AgentOperator.template_fields) == expected
+
+
+class TestAgentOperatorExecute:
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_execute_creates_agent_from_hook(self, mock_hook_cls):
+        mock_agent = _make_mock_agent("The answer is 42.")
+        mock_hook_cls.return_value.create_agent.return_value = mock_agent
+
+        op = AgentOperator(
+            task_id="test",
+            prompt="What is the answer?",
+            llm_conn_id="my_llm",
+            system_prompt="You are helpful.",
+        )
+        result = op.execute(context=MagicMock())
+
+        assert result == "The answer is 42."
+        mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", model_id=None)
+        mock_hook_cls.return_value.create_agent.assert_called_once_with(
+            output_type=str, instructions="You are helpful."
+        )
+        mock_agent.run_sync.assert_called_once_with("What is the answer?")
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_execute_passes_toolsets_in_agent_kwargs(self, mock_hook_cls):
+        """Toolsets are passed through to the agent constructor."""
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("done")
+
+        mock_toolset = MagicMock()
+        op = AgentOperator(
+            task_id="test",
+            prompt="Do something",
+            llm_conn_id="my_llm",
+            toolsets=[mock_toolset],
+        )
+        op.execute(context=MagicMock())
+
+        create_call = mock_hook_cls.return_value.create_agent.call_args
+        assert create_call[1]["toolsets"] == [mock_toolset]
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_execute_passes_agent_params(self, mock_hook_cls):
+        """agent_params are unpacked into create_agent."""
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("ok")
+
+        op = AgentOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            agent_params={"retries": 3, "model_settings": {"temperature": 0}},
+        )
+        op.execute(context=MagicMock())
+
+        create_call = mock_hook_cls.return_value.create_agent.call_args
+        assert create_call[1]["retries"] == 3
+        assert create_call[1]["model_settings"] == {"temperature": 0}
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_execute_structured_output(self, mock_hook_cls):
+        """Structured output via BaseModel is serialized with model_dump."""
+
+        class Summary(BaseModel):
+            text: str
+            score: float
+
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent(
+            Summary(text="Great", score=0.95)
+        )
+
+        op = AgentOperator(
+            task_id="test",
+            prompt="Analyze this",
+            llm_conn_id="my_llm",
+            output_type=Summary,
+        )
+        result = op.execute(context=MagicMock())
+
+        assert result == {"text": "Great", "score": 0.95}
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+    def test_execute_with_model_id(self, mock_hook_cls):
+        """model_id is passed to PydanticAIHook."""
+        mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("ok")
+
+        op = AgentOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            model_id="openai:gpt-5",
+        )
+        op.execute(context=MagicMock())
+
+        mock_hook_cls.assert_called_once_with(llm_conn_id="my_llm", model_id="openai:gpt-5")
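
The structured-output tests above assert that when `output_type` is a pydantic `BaseModel`, `execute` returns a plain dict (presumably via `model_dump()`), while plain-string outputs pass through unchanged. A minimal stdlib-only sketch of that serialization step, outside the diff for illustration — `FakeModel` and `serialize_output` are hypothetical stand-ins, not part of the provider:

```python
class FakeModel:
    """Stand-in for a pydantic BaseModel instance (assumption: the
    operator only relies on the model_dump() method)."""

    def __init__(self, **fields):
        self._fields = dict(fields)

    def model_dump(self):
        # Mirrors pydantic's model_dump(): model fields as a plain dict.
        return dict(self._fields)


def serialize_output(output):
    """Sketch of the post-run serialization the tests exercise:
    BaseModel-like results become XCom-friendly dicts, strings pass through."""
    if hasattr(output, "model_dump"):
        return output.model_dump()
    return output


print(serialize_output(FakeModel(text="Great", score=0.95)))  # {'text': 'Great', 'score': 0.95}
print(serialize_output("The answer is 42."))                  # The answer is 42.
```

Returning a dict rather than the `BaseModel` instance keeps the XCom value JSON-serializable, which is what the `assert result == {"text": "Great", "score": 0.95}` checks above depend on.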