This is an automated email from the ASF dual-hosted git repository.
kaxilnaik pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
new 1184e7cec04 AIP-99: Add tool call logging for all AI operators (#62901)
1184e7cec04 is described below
commit 1184e7cec04a88db362bce38d39936a97d2b13f3
Author: Kaxil Naik <[email protected]>
AuthorDate: Thu Mar 5 12:29:14 2026 +0000
AIP-99: Add tool call logging for all AI operators (#62901)
Two-layer logging for pydantic-ai agent runs:
1. Real-time tool call logging via LoggingToolset (wraps pydantic-ai's
WrapperToolset). Logs tool name at INFO, args at DEBUG (to avoid
leaking sensitive data), and elapsed time. Uses Airflow log groups
(::group::/::endgroup::) for collapsible sections in the UI.
2. Post-run summary via log_run_summary() called after agent.run_sync().
Logs model name, token usage, and tool call sequence. Output logged
at DEBUG level with truncation.
AgentOperator gets an enable_tool_logging parameter (default True) to
opt out of toolset wrapping when needed.
---
providers/common/ai/docs/operators/agent.rst | 51 +++++++
providers/common/ai/docs/operators/llm.rst | 8 ++
providers/common/ai/docs/operators/llm_branch.rst | 7 +
.../ai/docs/operators/llm_schema_compare.rst | 7 +
providers/common/ai/docs/operators/llm_sql.rst | 8 ++
providers/common/ai/docs/toolsets.rst | 21 +++
.../airflow/providers/common/ai/operators/agent.py | 12 +-
.../airflow/providers/common/ai/operators/llm.py | 2 +
.../providers/common/ai/operators/llm_branch.py | 2 +
.../common/ai/operators/llm_schema_compare.py | 3 +-
.../providers/common/ai/operators/llm_sql.py | 2 +
.../providers/common/ai/toolsets/logging.py | 62 +++++++++
.../airflow/providers/common/ai/utils/logging.py | 90 ++++++++++++
.../common/ai/tests/unit/common/ai/conftest.py | 35 +++++
.../tests/unit/common/ai/decorators/test_agent.py | 36 +++--
.../ai/tests/unit/common/ai/decorators/test_llm.py | 20 ++-
.../unit/common/ai/decorators/test_llm_branch.py | 20 ++-
.../ai/decorators/test_llm_schema_compare.py | 16 ++-
.../unit/common/ai/decorators/test_llm_sql.py | 20 ++-
.../tests/unit/common/ai/operators/test_agent.py | 38 +++++-
.../ai/tests/unit/common/ai/operators/test_llm.py | 20 ++-
.../unit/common/ai/operators/test_llm_branch.py | 30 ++--
.../common/ai/operators/test_llm_schema_compare.py | 37 +++--
.../tests/unit/common/ai/operators/test_llm_sql.py | 16 ++-
.../tests/unit/common/ai/toolsets/test_logging.py | 111 +++++++++++++++
.../ai/tests/unit/common/ai/utils/test_logging.py | 151 +++++++++++++++++++++
26 files changed, 754 insertions(+), 71 deletions(-)
diff --git a/providers/common/ai/docs/operators/agent.rst
b/providers/common/ai/docs/operators/agent.rst
index bafd2114589..01d1efc0cc4 100644
--- a/providers/common/ai/docs/operators/agent.rst
+++ b/providers/common/ai/docs/operators/agent.rst
@@ -125,10 +125,61 @@ Parameters
``BaseModel`` for structured output.
- ``toolsets``: List of pydantic-ai toolsets (``SQLToolset``, ``HookToolset``,
etc.).
+- ``enable_tool_logging``: Wrap each toolset in
+ :class:`~airflow.providers.common.ai.toolsets.logging.LoggingToolset` so that
+ every tool call is logged in real time. Default ``True``.
- ``agent_params``: Additional keyword arguments passed to the pydantic-ai
``Agent`` constructor (e.g. ``retries``, ``model_settings``).
+Logging
+-------
+
+All AI operators automatically log a post-run summary after ``run_sync()``
+completes. ``AgentOperator`` additionally wraps toolsets for real-time
+per-tool-call logging (controlled by ``enable_tool_logging``).
+
+**Real-time tool call logging** (AgentOperator only) — each tool call is
+logged as it happens:
+
+.. code-block:: text
+
+ INFO - Tool call: list_tables
+ INFO - Tool list_tables returned in 0.12s
+ INFO - Tool call: get_schema
+ INFO - Tool get_schema returned in 0.08s
+ INFO - Tool call: query
+ INFO - Tool query returned in 0.34s
+
+Tool arguments are logged at DEBUG level to avoid leaking sensitive data at
+the default log level.
+
+**Post-run summary** (all operators) — after the LLM run finishes, a summary
+is logged with model name, token usage, and the full tool call sequence:
+
+.. code-block:: text
+
+ INFO - LLM run complete: model=gpt-5, requests=4, tool_calls=3,
input_tokens=2847, output_tokens=512, total_tokens=3359
+ INFO - Tool call sequence: list_tables -> get_schema -> query
+
+At DEBUG level, the LLM output is also logged (truncated to 500 characters).
+
+Both layers use Airflow's ``::group::`` / ``::endgroup::`` log markers, which
+render as collapsible sections in the Airflow UI task log viewer.
+
+To disable real-time tool logging while keeping the post-run summary:
+
+.. code-block:: python
+
+ AgentOperator(
+ task_id="my_agent",
+ prompt="...",
+ llm_conn_id="my_llm",
+ toolsets=[SQLToolset(db_conn_id="my_db")],
+ enable_tool_logging=False,
+ )
+
+
Security
--------
diff --git a/providers/common/ai/docs/operators/llm.rst
b/providers/common/ai/docs/operators/llm.rst
index 21df887b041..af2abb8f714 100644
--- a/providers/common/ai/docs/operators/llm.rst
+++ b/providers/common/ai/docs/operators/llm.rst
@@ -118,3 +118,11 @@ Parameters
for structured output.
- ``agent_params``: Additional keyword arguments passed to the pydantic-ai
``Agent``
constructor (e.g. ``retries``, ``model_settings``, ``tools``). Supports
Jinja templating.
+
+Logging
+-------
+
+After each LLM call, the operator logs a summary with model name, token usage,
+and request count at INFO level. At DEBUG level, the LLM output is also logged
+(truncated to 500 characters). See :ref:`AgentOperator — Logging
<howto/operator:agent>`
+for details on the log format.
diff --git a/providers/common/ai/docs/operators/llm_branch.rst
b/providers/common/ai/docs/operators/llm_branch.rst
index 9d1bc059a5e..4e7630bbbb5 100644
--- a/providers/common/ai/docs/operators/llm_branch.rst
+++ b/providers/common/ai/docs/operators/llm_branch.rst
@@ -95,3 +95,10 @@ Parameters
task ID. When ``True`` the LLM may return one or more task IDs.
- ``agent_params``: Additional keyword arguments passed to the pydantic-ai
``Agent``
constructor (e.g. ``retries``, ``model_settings``). Supports Jinja
templating.
+
+Logging
+-------
+
+After each LLM call, the operator logs a summary with model name, token usage,
+and request count at INFO level. See :ref:`AgentOperator — Logging
<howto/operator:agent>`
+for details on the log format.
diff --git a/providers/common/ai/docs/operators/llm_schema_compare.rst
b/providers/common/ai/docs/operators/llm_schema_compare.rst
index d2e0ab5cff1..f6767e04dbe 100644
--- a/providers/common/ai/docs/operators/llm_schema_compare.rst
+++ b/providers/common/ai/docs/operators/llm_schema_compare.rst
@@ -162,3 +162,10 @@ Parameters
catalog-managed sources.
- ``context_strategy``: Whether to fetch primary keys, foreign keys, and
indexes. One of ``full`` or ``basic``; ``full`` is
strongly recommended for cross-system comparisons. Default is ``full``.
+
+Logging
+-------
+
+After each LLM call, the operator logs a summary with model name, token usage,
+and request count at INFO level. See :ref:`AgentOperator — Logging
<howto/operator:agent>`
+for details on the log format.
diff --git a/providers/common/ai/docs/operators/llm_sql.rst
b/providers/common/ai/docs/operators/llm_sql.rst
index 5497dcfcf89..443e4952f74 100644
--- a/providers/common/ai/docs/operators/llm_sql.rst
+++ b/providers/common/ai/docs/operators/llm_sql.rst
@@ -132,3 +132,11 @@ By default, the operator validates generated SQL using an
allowlist approach:
You can disable validation with ``validate_sql=False`` or customize the allowed
statement types with ``allowed_sql_types``.
+
+Logging
+-------
+
+After each LLM call, the operator logs a summary with model name, token usage,
+and request count at INFO level. At DEBUG level, the generated SQL is also
+logged (truncated to 500 characters). See :ref:`AgentOperator — Logging
<howto/operator:agent>`
+for details on the log format.
diff --git a/providers/common/ai/docs/toolsets.rst
b/providers/common/ai/docs/toolsets.rst
index b9c3feb7477..a8dd9e1512c 100644
--- a/providers/common/ai/docs/toolsets.rst
+++ b/providers/common/ai/docs/toolsets.rst
@@ -185,6 +185,27 @@ Parameters
support DDL for in-memory tables; this guard blocks those by default.
- ``max_rows``: Maximum rows returned from the ``query`` tool. Default ``50``.
+``LoggingToolset``
+------------------
+
+:class:`~airflow.providers.common.ai.toolsets.logging.LoggingToolset` is a
+``WrapperToolset`` that intercepts ``call_tool()`` to log each tool invocation
+in real time. ``AgentOperator`` applies it automatically (see
+``enable_tool_logging``), but you can also use it directly with any pydantic-ai
+``Agent``:
+
+.. code-block:: python
+
+ from airflow.providers.common.ai.toolsets.logging import LoggingToolset
+ from airflow.providers.common.ai.toolsets.sql import SQLToolset
+
+ sql_toolset = SQLToolset(db_conn_id="my_db")
+ logged_toolset = LoggingToolset(wrapped=sql_toolset, logger=my_logger)
+
+Each tool call produces two INFO log lines (name + timing) and optional
+DEBUG-level argument logging. Exceptions are logged and re-raised.
+
+
Security
--------
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
index ca4d61c86ec..a9e00c12078 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/agent.py
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+from airflow.providers.common.ai.utils.logging import log_run_summary,
wrap_toolsets_for_logging
from airflow.providers.common.compat.sdk import BaseOperator
if TYPE_CHECKING:
@@ -51,6 +52,9 @@ class AgentOperator(BaseOperator):
``BaseModel`` subclass for structured output.
:param toolsets: List of pydantic-ai toolsets the agent can use
(e.g. ``SQLToolset``, ``HookToolset``).
+ :param enable_tool_logging: When ``True`` (default), wraps each toolset in
a
+ ``LoggingToolset`` that logs tool calls with timing at INFO level and
+ arguments at DEBUG level. Set to ``False`` to disable.
:param agent_params: Additional keyword arguments passed to the pydantic-ai
``Agent`` constructor (e.g. ``retries``, ``model_settings``).
"""
@@ -72,6 +76,7 @@ class AgentOperator(BaseOperator):
system_prompt: str = "",
output_type: type = str,
toolsets: list[AbstractToolset] | None = None,
+ enable_tool_logging: bool = True,
agent_params: dict[str, Any] | None = None,
**kwargs: Any,
) -> None:
@@ -83,6 +88,7 @@ class AgentOperator(BaseOperator):
self.system_prompt = system_prompt
self.output_type = output_type
self.toolsets = toolsets
+ self.enable_tool_logging = enable_tool_logging
self.agent_params = agent_params or {}
@cached_property
@@ -93,7 +99,10 @@ class AgentOperator(BaseOperator):
def execute(self, context: Context) -> Any:
extra_kwargs = dict(self.agent_params)
if self.toolsets:
- extra_kwargs["toolsets"] = self.toolsets
+ if self.enable_tool_logging:
+ extra_kwargs["toolsets"] =
wrap_toolsets_for_logging(self.toolsets, self.log)
+ else:
+ extra_kwargs["toolsets"] = self.toolsets
agent: Agent[None, Any] = self.llm_hook.create_agent(
output_type=self.output_type,
instructions=self.system_prompt,
@@ -101,6 +110,7 @@ class AgentOperator(BaseOperator):
)
result = agent.run_sync(self.prompt)
+ log_run_summary(self.log, result)
output = result.output
if isinstance(output, BaseModel):
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
index b7cd69ecf1e..98702bc0102 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm.py
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any
from pydantic import BaseModel
from airflow.providers.common.ai.hooks.pydantic_ai import PydanticAIHook
+from airflow.providers.common.ai.utils.logging import log_run_summary
from airflow.providers.common.compat.sdk import BaseOperator
if TYPE_CHECKING:
@@ -92,6 +93,7 @@ class LLMOperator(BaseOperator):
output_type=self.output_type, instructions=self.system_prompt,
**self.agent_params
)
result = agent.run_sync(self.prompt)
+ log_run_summary(self.log, result)
output = result.output
if isinstance(output, BaseModel):
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
index b7f3028ec92..7a2cda2a671 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_branch.py
@@ -23,6 +23,7 @@ from enum import Enum
from typing import TYPE_CHECKING, Any
from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.utils.logging import log_run_summary
from airflow.providers.standard.operators.branch import BranchMixIn
if TYPE_CHECKING:
@@ -81,6 +82,7 @@ class LLMBranchOperator(LLMOperator, BranchMixIn):
**self.agent_params,
)
result = agent.run_sync(self.prompt)
+ log_run_summary(self.log, result)
output = result.output
branches: str | list[str]
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py
index c46dcbe7c7c..5e13eb55966 100644
---
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py
+++
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py
@@ -26,6 +26,7 @@ from typing import TYPE_CHECKING, Any, Literal
from pydantic import BaseModel, Field
from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.utils.logging import log_run_summary
from airflow.providers.common.compat.sdk import AirflowException, BaseHook
if TYPE_CHECKING:
@@ -309,7 +310,7 @@ class LLMSchemaCompareOperator(LLMOperator):
)
self.log.info("Running LLM schema comparison...")
result = agent.run_sync(self.prompt)
- self.log.info("LLM schema comparison completed.")
+ log_run_summary(self.log, result)
output_result = result.output.model_dump()
self.log.info("Schema comparison result: \n %s",
json.dumps(output_result, indent=2))
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
index 81bec01ec77..369c76ae3ff 100644
--- a/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
+++ b/providers/common/ai/src/airflow/providers/common/ai/operators/llm_sql.py
@@ -34,6 +34,7 @@ except ImportError as e:
raise AirflowOptionalProviderFeatureException(e)
from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.ai.utils.logging import log_run_summary
from airflow.providers.common.compat.sdk import BaseHook
if TYPE_CHECKING:
@@ -140,6 +141,7 @@ class LLMSQLQueryOperator(LLMOperator):
output_type=str, instructions=full_system_prompt,
**self.agent_params
)
result = agent.run_sync(self.prompt)
+ log_run_summary(self.log, result)
sql = self._strip_llm_output(result.output)
if self.validate_sql:
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/toolsets/logging.py
b/providers/common/ai/src/airflow/providers/common/ai/toolsets/logging.py
new file mode 100644
index 00000000000..0ce234f7550
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/toolsets/logging.py
@@ -0,0 +1,62 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Logging wrapper toolset for pydantic-ai tool calls."""
+
+from __future__ import annotations
+
+import json
+import logging
+import time
+from dataclasses import dataclass, field
+from typing import TYPE_CHECKING, Any
+
+from pydantic_ai.toolsets.wrapper import WrapperToolset
+
+if TYPE_CHECKING:
+ from pydantic_ai.toolsets.abstract import ToolsetTool
+
+ from airflow.sdk.types import Logger
+
+
+@dataclass
+class LoggingToolset(WrapperToolset[Any]):
+ """Wrap a toolset to log each tool call with timing."""
+
+ logger: Logger | logging.Logger = field(default_factory=lambda:
logging.getLogger(__name__))
+
+ async def call_tool(
+ self,
+ name: str,
+ tool_args: dict[str, Any],
+ ctx: Any,
+ tool: ToolsetTool[Any],
+ ) -> Any:
+ self.logger.info("::group::Tool call: %s", name)
+ if tool_args:
+ self.logger.debug("Tool args: %s", json.dumps(tool_args,
default=str))
+ start = time.monotonic()
+ try:
+ result = await self.wrapped.call_tool(name, tool_args, ctx, tool)
+ elapsed = time.monotonic() - start
+ self.logger.info("Tool %s returned in %.2fs", name, elapsed)
+ self.logger.info("::endgroup::")
+ return result
+ except Exception:
+ elapsed = time.monotonic() - start
+ self.logger.exception("Tool %s failed after %.2fs", name, elapsed)
+ self.logger.info("::endgroup::")
+ raise
diff --git
a/providers/common/ai/src/airflow/providers/common/ai/utils/logging.py
b/providers/common/ai/src/airflow/providers/common/ai/utils/logging.py
new file mode 100644
index 00000000000..47cabf7ce6c
--- /dev/null
+++ b/providers/common/ai/src/airflow/providers/common/ai/utils/logging.py
@@ -0,0 +1,90 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Logging utilities for pydantic-ai agent runs."""
+
+from __future__ import annotations
+
+import logging
+from typing import TYPE_CHECKING, Any
+
+from pydantic_ai.messages import ToolCallPart
+
+from airflow.providers.common.ai.toolsets.logging import LoggingToolset
+
+if TYPE_CHECKING:
+ from pydantic_ai.result import AgentRunResult
+ from pydantic_ai.toolsets.abstract import AbstractToolset
+
+ from airflow.sdk.types import Logger
+
+_MAX_OUTPUT_LEN = 500
+
+
+def log_run_summary(logger: Logger | logging.Logger, result:
AgentRunResult[Any]) -> None:
+ """Log model name, token usage, and tool call sequence from an agent
run."""
+ usage = result.usage()
+ model_name = getattr(result.response, "model_name", "unknown")
+ logger.info(
+ "::group::LLM run complete: model=%s, requests=%s, tool_calls=%s, "
+ "input_tokens=%s, output_tokens=%s, total_tokens=%s",
+ model_name,
+ usage.requests,
+ usage.tool_calls,
+ usage.input_tokens,
+ usage.output_tokens,
+ usage.total_tokens,
+ )
+
+ tool_names = _extract_tool_sequence(result)
+ if tool_names:
+ logger.info("Tool call sequence: %s", " -> ".join(tool_names))
+
+ _log_output_debug(logger, result.output)
+ logger.info("::endgroup::")
+
+
+def _log_output_debug(logger: Logger | logging.Logger, output: Any) -> None:
+ """Log a truncated representation of the agent output at DEBUG level."""
+ if not logger.isEnabledFor(logging.DEBUG):
+ return
+ from pydantic import BaseModel
+
+ if isinstance(output, BaseModel):
+ text = repr(output.model_dump())
+ else:
+ text = repr(output)
+ if len(text) > _MAX_OUTPUT_LEN:
+ text = text[:_MAX_OUTPUT_LEN] + "..."
+ logger.debug("Output: %s", text)
+
+
+def _extract_tool_sequence(result: AgentRunResult[Any]) -> list[str]:
+ """Extract ordered tool names from the message history."""
+ tool_names: list[str] = []
+ for message in result.all_messages():
+ for part in getattr(message, "parts", []):
+ if isinstance(part, ToolCallPart):
+ tool_names.append(part.tool_name)
+ return tool_names
+
+
+def wrap_toolsets_for_logging(
+ toolsets: list[AbstractToolset[Any]],
+ logger: Logger | logging.Logger,
+) -> list[AbstractToolset[Any]]:
+ """Wrap each toolset in a LoggingToolset."""
+ return [LoggingToolset(wrapped=ts, logger=logger) for ts in toolsets]
diff --git a/providers/common/ai/tests/unit/common/ai/conftest.py
b/providers/common/ai/tests/unit/common/ai/conftest.py
new file mode 100644
index 00000000000..0e6a49f36cb
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/conftest.py
@@ -0,0 +1,35 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+from unittest.mock import MagicMock
+
+
+def make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary.
+
+ Returns a MagicMock with .output, .usage(), .response, and .all_messages()
+ configured so that log_run_summary can read them without error.
+ """
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0,
total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
index 99e2e35aafc..52ed82c53cd 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_agent.py
@@ -19,8 +19,22 @@ from __future__ import annotations
from unittest.mock import MagicMock, patch
import pytest
+from pydantic import BaseModel
from airflow.providers.common.ai.decorators.agent import
_AgentDecoratedOperator
+from airflow.providers.common.ai.toolsets.logging import LoggingToolset
+
+
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0,
total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
class TestAgentDecoratedOperator:
@@ -31,9 +45,7 @@ class TestAgentDecoratedOperator:
def test_execute_calls_callable_and_returns_output(self, mock_hook_cls):
"""The callable's return value becomes the agent prompt."""
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = "The top customer is Acme Corp."
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result("The top
customer is Acme Corp.")
mock_hook_cls.return_value.create_agent.return_value = mock_agent
def my_prompt():
@@ -65,9 +77,7 @@ class TestAgentDecoratedOperator:
def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
"""op_kwargs are resolved by the callable to build the prompt."""
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = "done"
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result("done")
mock_hook_cls.return_value.create_agent.return_value = mock_agent
def my_prompt(topic):
@@ -88,9 +98,7 @@ class TestAgentDecoratedOperator:
def test_execute_passes_toolsets_through(self, mock_hook_cls):
"""Toolsets passed to the decorator are forwarded to the agent."""
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = "result"
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result("result")
mock_hook_cls.return_value.create_agent.return_value = mock_agent
mock_toolset = MagicMock()
@@ -104,20 +112,20 @@ class TestAgentDecoratedOperator:
op.execute(context={})
create_call = mock_hook_cls.return_value.create_agent.call_args
- assert create_call[1]["toolsets"] == [mock_toolset]
+ passed_toolsets = create_call[1]["toolsets"]
+ assert len(passed_toolsets) == 1
+ assert isinstance(passed_toolsets[0], LoggingToolset)
+ assert passed_toolsets[0].wrapped is mock_toolset
@patch("airflow.providers.common.ai.operators.agent.PydanticAIHook",
autospec=True)
def test_execute_structured_output(self, mock_hook_cls):
"""BaseModel output is serialized with model_dump."""
- from pydantic import BaseModel
class Summary(BaseModel):
text: str
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = Summary(text="Great results")
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value =
_make_mock_run_result(Summary(text="Great results"))
mock_hook_cls.return_value.create_agent.return_value = mock_agent
op = _AgentDecoratedOperator(
diff --git a/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
index d7161dfef1c..ff3240529ff 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm.py
@@ -23,6 +23,18 @@ import pytest
from airflow.providers.common.ai.decorators.llm import _LLMDecoratedOperator
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0,
total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
+
+
class TestLLMDecoratedOperator:
def test_custom_operator_name(self):
assert _LLMDecoratedOperator.custom_operator_name == "@task.llm"
@@ -31,9 +43,7 @@ class TestLLMDecoratedOperator:
def test_execute_calls_callable_and_returns_output(self, mock_hook_cls):
"""The callable's return value becomes the LLM prompt."""
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = "This is a summary."
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result("This is a
summary.")
mock_hook_cls.return_value.create_agent.return_value = mock_agent
def my_prompt():
@@ -65,9 +75,7 @@ class TestLLMDecoratedOperator:
def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
"""op_kwargs are resolved by the callable to build the prompt."""
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = "done"
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result("done")
mock_hook_cls.return_value.create_agent.return_value = mock_agent
def my_prompt(topic):
diff --git
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
index 66620426a3b..8daf2799b5f 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_branch.py
@@ -25,6 +25,18 @@ from airflow.providers.common.ai.decorators.llm_branch
import _LLMBranchDecorate
from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0,
total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
+
+
class TestLLMBranchDecoratedOperator:
def test_custom_operator_name(self):
assert _LLMBranchDecoratedOperator.custom_operator_name ==
"@task.llm_branch"
@@ -36,9 +48,7 @@ class TestLLMBranchDecoratedOperator:
downstream_enum = Enum("DownstreamTasks", {"positive": "positive",
"negative": "negative"})
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = downstream_enum.positive
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value =
_make_mock_run_result(downstream_enum.positive)
mock_hook_cls.return_value.create_agent.return_value = mock_agent
mock_do_branch.return_value = "positive"
@@ -81,9 +91,7 @@ class TestLLMBranchDecoratedOperator:
downstream_enum = Enum("DownstreamTasks", {"task_a": "task_a"})
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = downstream_enum.task_a
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value =
_make_mock_run_result(downstream_enum.task_a)
mock_hook_cls.return_value.create_agent.return_value = mock_agent
def my_prompt(ticket_type):
diff --git
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
index 81ee53806c3..7722ea4e450 100644
---
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
+++
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_schema_compare.py
@@ -27,6 +27,18 @@ from
airflow.providers.common.ai.operators.llm_schema_compare import (
)
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0,
total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
+
+
def _make_compare_result():
return SchemaCompareResult(
mismatches=[],
@@ -36,10 +48,8 @@ def _make_compare_result():
def _make_mock_agent(output: SchemaCompareResult):
- mock_result = MagicMock(spec=["output"])
- mock_result.output = output
mock_agent = MagicMock(spec=["run_sync"])
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result(output)
return mock_agent
diff --git
a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
index 849e9c1e1fd..a1b7e3e3f36 100644
--- a/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/decorators/test_llm_sql.py
@@ -23,6 +23,18 @@ import pytest
from airflow.providers.common.ai.decorators.llm_sql import
_LLMSQLDecoratedOperator
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0, total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
+
+
class TestLLMSQLDecoratedOperator:
def test_custom_operator_name(self):
assert _LLMSQLDecoratedOperator.custom_operator_name == "@task.llm_sql"
@@ -31,9 +43,7 @@ class TestLLMSQLDecoratedOperator:
def test_execute_calls_callable_and_uses_result_as_prompt(self,
mock_hook_cls):
"""The user's callable return value becomes the LLM prompt."""
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = "SELECT 1"
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result("SELECT 1")
mock_hook_cls.return_value.create_agent.return_value = mock_agent
def my_prompt_fn():
@@ -65,9 +75,7 @@ class TestLLMSQLDecoratedOperator:
def test_execute_merges_op_kwargs_into_callable(self, mock_hook_cls):
"""op_kwargs are resolved by the callable to build the prompt."""
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = "SELECT 1"
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result("SELECT 1")
mock_hook_cls.return_value.create_agent.return_value = mock_agent
def my_prompt_fn(table_name):
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
index 3d949854189..960876ecdda 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_agent.py
@@ -22,14 +22,25 @@ import pytest
from pydantic import BaseModel
from airflow.providers.common.ai.operators.agent import AgentOperator
+from airflow.providers.common.ai.toolsets.logging import LoggingToolset
+
+
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0, total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
def _make_mock_agent(output):
"""Create a mock agent that returns the given output."""
- mock_result = MagicMock(spec=["output"])
- mock_result.output = output
mock_agent = MagicMock(spec=["run_sync"])
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result(output)
return mock_agent
@@ -80,6 +91,27 @@ class TestAgentOperatorExecute:
)
op.execute(context=MagicMock())
+ create_call = mock_hook_cls.return_value.create_agent.call_args
+ passed_toolsets = create_call[1]["toolsets"]
+ assert len(passed_toolsets) == 1
+ assert isinstance(passed_toolsets[0], LoggingToolset)
+ assert passed_toolsets[0].wrapped is mock_toolset
+
+ @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", autospec=True)
+ def test_enable_tool_logging_false_skips_wrapping(self, mock_hook_cls):
+ """enable_tool_logging=False passes toolsets through unwrapped."""
+ mock_hook_cls.return_value.create_agent.return_value = _make_mock_agent("done")
+
+ mock_toolset = MagicMock()
+ op = AgentOperator(
+ task_id="test",
+ prompt="Do something",
+ llm_conn_id="my_llm",
+ toolsets=[mock_toolset],
+ enable_tool_logging=False,
+ )
+ op.execute(context=MagicMock())
+
create_call = mock_hook_cls.return_value.create_agent.call_args
assert create_call[1]["toolsets"] == [mock_toolset]
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
index de7b35000af..59ed0d446a2 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm.py
@@ -23,6 +23,18 @@ from pydantic import BaseModel
from airflow.providers.common.ai.operators.llm import LLMOperator
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0, total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
+
+
class TestLLMOperator:
def test_template_fields(self):
expected = {"prompt", "llm_conn_id", "model_id", "system_prompt",
"agent_params"}
@@ -32,9 +44,7 @@ class TestLLMOperator:
def test_execute_returns_string_output(self, mock_hook_cls):
"""Default output_type=str returns the LLM string directly."""
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = "Paris is the capital of France."
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result("Paris is the capital of France.")
mock_hook_cls.return_value.create_agent.return_value = mock_agent
op = LLMOperator(task_id="test", prompt="What is the capital of
France?", llm_conn_id="my_llm")
@@ -53,9 +63,7 @@ class TestLLMOperator:
names: list[str]
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = Entities(names=["Alice", "Bob"])
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result(Entities(names=["Alice", "Bob"]))
mock_hook_cls.return_value.create_agent.return_value = mock_agent
op = LLMOperator(
diff --git
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
index d94fc552178..ffd475f71b8 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_branch.py
@@ -25,6 +25,18 @@ from airflow.providers.common.ai.operators.llm import
LLMOperator
from airflow.providers.common.ai.operators.llm_branch import LLMBranchOperator
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0, total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
+
+
class TestLLMBranchOperator:
def test_inherits_from_skipmixin_is_true(self):
assert LLMBranchOperator.inherits_from_skipmixin is True
@@ -51,9 +63,7 @@ class TestLLMBranchOperator:
downstream_enum = Enum("DownstreamTasks", {"task_a": "task_a",
"task_b": "task_b"})
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = downstream_enum.task_a
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result(downstream_enum.task_a)
mock_hook_cls.return_value.create_agent.return_value = mock_agent
mock_do_branch.return_value = "task_a"
@@ -80,9 +90,9 @@ class TestLLMBranchOperator:
)
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = [downstream_enum.task_a, downstream_enum.task_c]
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result(
+ [downstream_enum.task_a, downstream_enum.task_c]
+ )
mock_hook_cls.return_value.create_agent.return_value = mock_agent
mock_do_branch.return_value = ["task_a", "task_c"]
@@ -107,9 +117,7 @@ class TestLLMBranchOperator:
downstream_enum = Enum("DownstreamTasks", {"task_a": "task_a"})
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = downstream_enum.task_a
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result(downstream_enum.task_a)
mock_hook_cls.return_value.create_agent.return_value = mock_agent
op = LLMBranchOperator(
@@ -134,9 +142,7 @@ class TestLLMBranchOperator:
)
mock_agent = MagicMock(spec=["run_sync"])
- mock_result = MagicMock(spec=["output"])
- mock_result.output = downstream_enum.billing
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result(downstream_enum.billing)
mock_hook_cls.return_value.create_agent.return_value = mock_agent
op = LLMBranchOperator(
diff --git
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_schema_compare.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_schema_compare.py
index f778f88472f..784d94eb7fc 100644
---
a/providers/common/ai/tests/unit/common/ai/operators/test_llm_schema_compare.py
+++
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_schema_compare.py
@@ -30,6 +30,19 @@ from airflow.providers.common.sql.config import
DataSourceConfig
from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0, total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
+
+
_BASE_KWARGS = dict(task_id="test_task", prompt="test prompt",
llm_conn_id="llm_conn")
@@ -243,8 +256,8 @@ class TestLLMSchemaCompareOperator:
mock_llm_hook = mock.Mock()
mock_agent = mock.Mock()
- mock_agent.run_sync.return_value.output = SchemaCompareResult(
- compatible=True, mismatches=[], summary="All good"
+ mock_agent.run_sync.return_value = _make_mock_run_result(
+ SchemaCompareResult(compatible=True, mismatches=[], summary="All good")
)
mock_llm_hook.create_agent.return_value = mock_agent
op.llm_hook = mock_llm_hook
@@ -319,8 +332,10 @@ class TestLLMSchemaCompareOperator:
mock_llm_hook = mock.Mock()
mock_agent = mock.Mock()
- mock_agent.run_sync.return_value.output = SchemaCompareResult(
- compatible=True, mismatches=[], summary="S3 and Postgres schemas are compatible"
+ mock_agent.run_sync.return_value = _make_mock_run_result(
+ SchemaCompareResult(
+ compatible=True, mismatches=[], summary="S3 and Postgres schemas are compatible"
+ )
)
mock_llm_hook.create_agent.return_value = mock_agent
op.llm_hook = mock_llm_hook
@@ -395,8 +410,8 @@ class TestLLMSchemaCompareOperator:
mock_llm_hook = mock.Mock()
mock_agent = mock.Mock()
- mock_agent.run_sync.return_value.output = SchemaCompareResult(
- compatible=True, mismatches=[], summary="Schemas are compatible"
+ mock_agent.run_sync.return_value = _make_mock_run_result(
+ SchemaCompareResult(compatible=True, mismatches=[], summary="Schemas are compatible")
)
mock_llm_hook.create_agent.return_value = mock_agent
op.llm_hook = mock_llm_hook
@@ -444,10 +459,12 @@ class TestLLMSchemaCompareOperator:
mock_llm_hook = mock.Mock()
mock_agent = mock.Mock()
- mock_agent.run_sync.return_value.output = SchemaCompareResult(
- compatible=False,
- mismatches=[],
- summary="Timestamp column type differs between Parquet and CSV",
+ mock_agent.run_sync.return_value = _make_mock_run_result(
+ SchemaCompareResult(
+ compatible=False,
+ mismatches=[],
+ summary="Timestamp column type differs between Parquet and CSV",
+ )
)
mock_llm_hook.create_agent.return_value = mock_agent
op.llm_hook = mock_llm_hook
diff --git a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
index 445df28c17c..b0a91f52edc 100644
--- a/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
+++ b/providers/common/ai/tests/unit/common/ai/operators/test_llm_sql.py
@@ -25,12 +25,22 @@ from airflow.providers.common.ai.utils.sql_validation
import SQLSafetyError
from airflow.providers.common.sql.config import DataSourceConfig
+def _make_mock_run_result(output):
+ """Create a mock AgentRunResult compatible with log_run_summary."""
+ mock_result = MagicMock()
+ mock_result.output = output
+ mock_result.usage.return_value = MagicMock(
+ requests=1, tool_calls=0, input_tokens=0, output_tokens=0, total_tokens=0
+ )
+ mock_result.response = MagicMock(model_name="test-model")
+ mock_result.all_messages.return_value = []
+ return mock_result
+
+
def _make_mock_agent(output: str):
"""Create a mock agent that returns the given output string."""
- mock_result = MagicMock(spec=["output"])
- mock_result.output = output
mock_agent = MagicMock(spec=["run_sync"])
- mock_agent.run_sync.return_value = mock_result
+ mock_agent.run_sync.return_value = _make_mock_run_result(output)
return mock_agent
diff --git a/providers/common/ai/tests/unit/common/ai/toolsets/test_logging.py
b/providers/common/ai/tests/unit/common/ai/toolsets/test_logging.py
new file mode 100644
index 00000000000..2bf88a4cd25
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/toolsets/test_logging.py
@@ -0,0 +1,111 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from unittest.mock import AsyncMock, MagicMock
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.logging import LoggingToolset
+
+
[email protected]
+def wrapped_toolset():
+ ts = AsyncMock()
+ ts.id = "test-toolset"
+ ts.get_tools = AsyncMock(return_value={"tool_a": MagicMock()})
+ return ts
+
+
[email protected]
+def logger():
+ return logging.getLogger("test.logging_toolset")
+
+
[email protected]
+def logging_toolset(wrapped_toolset, logger):
+ return LoggingToolset(wrapped=wrapped_toolset, logger=logger)
+
+
+class TestLoggingToolset:
+ @pytest.mark.asyncio
+ async def test_logs_tool_call_name(self, logging_toolset, wrapped_toolset, logger, caplog):
+ wrapped_toolset.call_tool = AsyncMock(return_value="result")
+ ctx = MagicMock()
+ tool = MagicMock()
+
+ with caplog.at_level(logging.INFO, logger="test.logging_toolset"):
+ await logging_toolset.call_tool("list_tables", {"schema": "public"}, ctx, tool)
+
+ assert any("::group::Tool call: list_tables" in r.message for r in
caplog.records)
+
+ @pytest.mark.asyncio
+ async def test_logs_args_at_debug_level(self, logging_toolset,
wrapped_toolset, logger, caplog):
+ wrapped_toolset.call_tool = AsyncMock(return_value="result")
+ ctx = MagicMock()
+ tool = MagicMock()
+
+ with caplog.at_level(logging.DEBUG, logger="test.logging_toolset"):
+ await logging_toolset.call_tool("list_tables", {"schema":
"public"}, ctx, tool)
+
+ debug_records = [r for r in caplog.records if r.levelno == logging.DEBUG]
+ assert any('Tool args: {"schema": "public"}' in r.message for r in debug_records)
+
+ @pytest.mark.asyncio
+ async def test_logs_timing(self, logging_toolset, wrapped_toolset, logger,
caplog):
+ wrapped_toolset.call_tool = AsyncMock(return_value="ok")
+ ctx = MagicMock()
+ tool = MagicMock()
+
+ with caplog.at_level(logging.INFO, logger="test.logging_toolset"):
+ await logging_toolset.call_tool("query", {}, ctx, tool)
+
+ assert any("Tool query returned in" in r.message for r in
caplog.records)
+ assert any("::endgroup::" in r.message for r in caplog.records)
+
+ @pytest.mark.asyncio
+ async def test_logs_error_on_exception(self, logging_toolset,
wrapped_toolset, logger, caplog):
+ wrapped_toolset.call_tool = AsyncMock(side_effect=RuntimeError("boom"))
+ ctx = MagicMock()
+ tool = MagicMock()
+
+ with caplog.at_level(logging.INFO, logger="test.logging_toolset"):
+ with pytest.raises(RuntimeError, match="boom"):
+ await logging_toolset.call_tool("bad_tool", {}, ctx, tool)
+
+ assert any("Tool bad_tool failed after" in r.message for r in
caplog.records)
+ assert any("::endgroup::" in r.message for r in caplog.records)
+
+ @pytest.mark.asyncio
+ async def test_delegates_get_tools(self, logging_toolset, wrapped_toolset):
+ ctx = MagicMock()
+ tools = await logging_toolset.get_tools(ctx)
+
+ assert tools == {"tool_a":
wrapped_toolset.get_tools.return_value["tool_a"]}
+ wrapped_toolset.get_tools.assert_awaited_once_with(ctx)
+
+ @pytest.mark.asyncio
+ async def test_empty_args_not_logged(self, logging_toolset,
wrapped_toolset, caplog):
+ wrapped_toolset.call_tool = AsyncMock(return_value="ok")
+ ctx = MagicMock()
+ tool = MagicMock()
+
+ with caplog.at_level(logging.DEBUG, logger="test.logging_toolset"):
+ await logging_toolset.call_tool("list_tables", {}, ctx, tool)
+
+ assert not any("Tool args:" in r.message for r in caplog.records)
diff --git a/providers/common/ai/tests/unit/common/ai/utils/test_logging.py
b/providers/common/ai/tests/unit/common/ai/utils/test_logging.py
new file mode 100644
index 00000000000..230335a0e02
--- /dev/null
+++ b/providers/common/ai/tests/unit/common/ai/utils/test_logging.py
@@ -0,0 +1,151 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from unittest.mock import MagicMock
+
+from pydantic import BaseModel
+from pydantic_ai.messages import (
+ ModelResponse,
+ ModelResponsePart,
+ ToolCallPart,
+)
+
+from airflow.providers.common.ai.toolsets.logging import LoggingToolset
+from airflow.providers.common.ai.utils.logging import (
+ _log_output_debug,
+ log_run_summary,
+ wrap_toolsets_for_logging,
+)
+
+
+def _make_mock_result(model_name="gpt-5", tool_names=None, usage_kwargs=None):
+ """Build a mock AgentRunResult with usage, response, and messages."""
+ usage_kwargs = usage_kwargs or {
+ "requests": 4,
+ "tool_calls": 3,
+ "input_tokens": 2847,
+ "output_tokens": 512,
+ "total_tokens": 3359,
+ }
+ result = MagicMock()
+ result.usage.return_value = MagicMock(**usage_kwargs)
+ result.response = MagicMock(model_name=model_name)
+
+ messages: list = []
+ if tool_names:
+ parts: list[ModelResponsePart] = [ToolCallPart(tool_name=name, args="{}") for name in tool_names]
+ messages.append(ModelResponse(parts=parts))
+ result.all_messages.return_value = messages
+ return result
+
+
+class TestLogRunSummary:
+ def test_logs_usage(self, caplog):
+ logger = logging.getLogger("test.log_run_summary")
+ result = _make_mock_result()
+
+ with caplog.at_level(logging.INFO, logger="test.log_run_summary"):
+ log_run_summary(logger, result)
+
+ records = [r for r in caplog.records if r.name ==
"test.log_run_summary"]
+ summary_line = records[0].message
+ assert summary_line.startswith("::group::")
+ assert "model=gpt-5" in summary_line
+ assert "requests=4" in summary_line
+ assert "tool_calls=3" in summary_line
+ assert "input_tokens=2847" in summary_line
+ assert "output_tokens=512" in summary_line
+ assert "total_tokens=3359" in summary_line
+ assert records[-1].message == "::endgroup::"
+
+ def test_logs_tool_sequence(self, caplog):
+ logger = logging.getLogger("test.log_run_summary")
+ result = _make_mock_result(tool_names=["list_tables", "get_schema",
"query"])
+
+ with caplog.at_level(logging.INFO, logger="test.log_run_summary"):
+ log_run_summary(logger, result)
+
+ records = [r for r in caplog.records if r.name == "test.log_run_summary"]
+ tool_line = records[1].message
+ assert tool_line == "Tool call sequence: list_tables -> get_schema -> query"
+ assert records[-1].message == "::endgroup::"
+
+ def test_no_tools_skips_sequence_line(self, caplog):
+ logger = logging.getLogger("test.log_run_summary")
+ result = _make_mock_result(tool_names=None)
+
+ with caplog.at_level(logging.INFO, logger="test.log_run_summary"):
+ log_run_summary(logger, result)
+
+ records = [r for r in caplog.records if r.name ==
"test.log_run_summary"]
+ assert len(records) == 2 # summary line + endgroup (no tool sequence)
+ assert records[-1].message == "::endgroup::"
+
+
+class TestLogOutputDebug:
+ def test_logs_string_output(self, caplog):
+ logger = logging.getLogger("test.output_debug")
+ with caplog.at_level(logging.DEBUG, logger="test.output_debug"):
+ _log_output_debug(logger, "Hello world")
+
+ debug_records = [r for r in caplog.records if r.levelno ==
logging.DEBUG]
+ assert any("Output: 'Hello world'" in r.message for r in debug_records)
+
+ def test_logs_pydantic_model_dump(self, caplog):
+ class Info(BaseModel):
+ name: str
+
+ logger = logging.getLogger("test.output_debug")
+ with caplog.at_level(logging.DEBUG, logger="test.output_debug"):
+ _log_output_debug(logger, Info(name="Alice"))
+
+ debug_records = [r for r in caplog.records if r.levelno ==
logging.DEBUG]
+ assert any("'name': 'Alice'" in r.message for r in debug_records)
+
+ def test_truncates_long_output(self, caplog):
+ logger = logging.getLogger("test.output_debug")
+ long_text = "x" * 1000
+ with caplog.at_level(logging.DEBUG, logger="test.output_debug"):
+ _log_output_debug(logger, long_text)
+
+ debug_records = [r for r in caplog.records if r.levelno ==
logging.DEBUG]
+ assert any(r.message.endswith("...") for r in debug_records)
+
+ def test_skipped_when_debug_disabled(self, caplog):
+ logger = logging.getLogger("test.output_debug")
+ with caplog.at_level(logging.INFO, logger="test.output_debug"):
+ _log_output_debug(logger, "should not appear")
+
+ debug_records = [r for r in caplog.records if r.levelno ==
logging.DEBUG]
+ assert len(debug_records) == 0
+
+
+class TestWrapToolsetsForLogging:
+ def test_wraps_each_toolset(self):
+ ts_a = MagicMock()
+ ts_b = MagicMock()
+ logger = logging.getLogger("test.wrap")
+
+ wrapped = wrap_toolsets_for_logging([ts_a, ts_b], logger)
+
+ assert len(wrapped) == 2
+ assert all(isinstance(w, LoggingToolset) for w in wrapped)
+ assert wrapped[0].wrapped is ts_a
+ assert wrapped[1].wrapped is ts_b
+ assert wrapped[0].logger is logger