Copilot commented on code in PR #64433:
URL: https://github.com/apache/airflow/pull/64433#discussion_r3025334466


##########
providers/common/ai/tests/unit/common/ai/operators/test_agent.py:
##########
@@ -42,11 +43,39 @@ def _make_mock_run_result(output):
 
 def _make_mock_agent(output):
     """Create a mock agent that returns the given output."""
-    mock_agent = MagicMock(spec=["run_sync"])
+    mock_agent = MagicMock(spec=["run_sync", "model", "override"])
     mock_agent.run_sync.return_value = _make_mock_run_result(output)
+    mock_agent.model = "test-model"
+    mock_agent.override.return_value.__enter__ = MagicMock(return_value=None)
+    mock_agent.override.return_value.__exit__ = MagicMock(return_value=False)
     return mock_agent
 
 
+class _DummyToolset:
+    id = "dummy-toolset"
+
+
+def _configure_mock_hook(
+    mock_hook_cls, *, agent=None, conn_type: str = "pydanticai", model_id: str 
| None = None
+):
+    mock_hook = mock_hook_cls.get_hook.return_value
+    mock_hook.conn_type = conn_type
+    mock_hook.model_id = model_id
+    if agent is not None:
+        mock_hook.create_agent.return_value = agent
+    return mock_hook
+
+
+def _make_mock_context(map_index: int = -1):
+    ti = MagicMock(spec=["xcom_push", "dag_id", "run_id", "task_id", 
"map_index"])
+    ti.dag_id = "example_dag"
+    ti.run_id = "run_1"
+    ti.task_id = "test"
+    ti.map_index = map_index
+    ti.xcom_push = MagicMock()

Review Comment:
   `ti` is already created with a spec that includes `xcom_push`, so 
reassigning `ti.xcom_push = MagicMock()` is redundant and also creates an 
unspecced mock. Prefer relying on the existing `ti.xcom_push` mock created by 
`MagicMock(spec=...)`, or assign a specced callable mock if you need custom 
behavior.
   ```suggestion
   
   ```



##########
providers/common/ai/tests/unit/common/ai/operators/test_agent.py:
##########
@@ -300,11 +334,102 @@ def 
test_execute_propagates_hitl_max_iterations_error(self, mock_hook_cls, mock_
             max_hitl_iterations=5,
             hitl_timeout=timedelta(minutes=5),
         )
-        context = MagicMock()
+        context, _ = _make_mock_context()
 
         with pytest.raises(HITLMaxIterationsError, match="Task exceeded max 
iterations"):
             op.execute(context=context)
 
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_pushes_policy_exposure_report_to_xcom(self, 
mock_hook_cls):
+        _configure_mock_hook(mock_hook_cls, agent=_make_mock_agent("ok"))
+        op = AgentOperator(task_id="test", prompt="test", llm_conn_id="my_llm")
+        context, ti = _make_mock_context()
+
+        op.execute(context=context)
+
+        xcom_push_calls = [call.kwargs for call in ti.xcom_push.call_args_list]
+        policy_push = next(call for call in xcom_push_calls if call["key"] == 
XCOM_POLICY_EXPOSURE)
+        assert policy_push["value"]["task"]["dag_id"] == "example_dag"
+        assert policy_push["value"]["task"]["map_index"] == -1
+        assert policy_push["value"]["llm"]["connection_type"] == "pydanticai"
+        assert policy_push["value"]["llm"]["model_id"] is None
+        assert policy_push["value"]["risk"]["level"] == "low"
+
+    @patch("airflow.providers.common.ai.operators.agent.PydanticAIHook", 
autospec=True)
+    def test_execute_policy_report_includes_runtime_notes_and_map_index(self, 
mock_hook_cls):
+        _configure_mock_hook(mock_hook_cls, agent=_make_mock_agent("ok"))
+        op = AgentOperator(
+            task_id="test",
+            prompt="test",
+            llm_conn_id="my_llm",
+            durable=True,
+            enable_tool_logging=True,
+        )
+        context, ti = _make_mock_context(map_index=3)
+
+        with (
+            patch(
+                
"airflow.providers.common.ai.operators.agent.AgentOperator._build_durable_toolsets",
+                autospec=True,
+                return_value=[],
+            ),
+            
patch("airflow.providers.common.ai.durable.storage._get_base_path"),
+            patch("pydantic_ai.models.infer_model", autospec=True, 
return_value=MagicMock()),
+            patch("pydantic_ai.models.wrapper.infer_model", side_effect=lambda 
model: model),
+        ):

Review Comment:
   The patched `infer_model` uses `return_value=MagicMock()` without a spec, 
which can mask interface mismatches with the real model object. Use a specced 
mock (or a lightweight fake) that matches the attributes accessed by 
`CachingModel`/`agent.override` in this code path.



##########
providers/common/ai/tests/unit/common/ai/utils/test_policy_exposure.py:
##########
@@ -0,0 +1,173 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import logging
+from unittest.mock import MagicMock
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.logging import LoggingToolset
+from airflow.providers.common.ai.utils.policy_exposure import (
+    LLMExposure,
+    ResourceExposure,
+    ToolsetExposure,
+    _dedupe_reasons,
+    classify_policy_risk,
+    describe_toolset_exposure,
+    unwrap_toolset,
+)
+
+
+class _UnknownToolset:
+    id = "custom"
+
+
+class _BaseToolset:
+    def describe_policy_exposure(self) -> ToolsetExposure:
+        return ToolsetExposure(
+            toolset_type="BaseToolset",
+            toolset_id="base",
+            summary="base toolset",
+        )
+
+
+class _BrokenToolset:
+    id = "broken"
+
+    def describe_policy_exposure(self) -> ToolsetExposure:
+        raise RuntimeError("boom")
+
+
+def test_unwrap_toolset_returns_base_toolset():
+    wrapped = _BaseToolset()
+    toolset = LoggingToolset(wrapped=wrapped, 
logger=logging.getLogger(__name__))
+
+    base_toolset = unwrap_toolset(toolset)
+
+    assert base_toolset is wrapped
+
+
+def test_describe_toolset_exposure_uses_base_toolset_for_wrappers():
+    wrapped = MagicMock()

Review Comment:
   `MagicMock()` is created without a spec, which can hide bugs by allowing any 
attribute/method access. Consider using 
`MagicMock(spec=["describe_policy_exposure", "id"])` (or a concrete toolset 
class) so the mock matches the expected toolset surface.
   ```suggestion
       wrapped = MagicMock(spec=["describe_policy_exposure", "id"])
   ```



##########
providers/common/ai/tests/unit/common/ai/decorators/test_agent.py:
##########
@@ -37,6 +37,28 @@ def _make_mock_run_result(output):
     return mock_result
 
 
+class _DummyToolset:
+    id = "dummy-toolset"
+
+
+def _configure_mock_hook(mock_hook_cls, *, agent, conn_type: str = 
"pydanticai", model_id: str | None = None):
+    mock_hook = mock_hook_cls.get_hook.return_value
+    mock_hook.create_agent.return_value = agent
+    mock_hook.conn_type = conn_type
+    mock_hook.model_id = model_id
+    return mock_hook
+
+
+def _make_mock_context():
+    ti = MagicMock(spec=["xcom_push", "dag_id", "run_id", "task_id", 
"map_index"])
+    ti.dag_id = "example_dag"
+    ti.run_id = "run_1"
+    ti.task_id = "test"
+    ti.map_index = -1
+    ti.xcom_push = MagicMock()

Review Comment:
   Similar to the operator tests, `ti` is created with a spec that already 
includes `xcom_push`, so reassigning `ti.xcom_push = MagicMock()` is redundant 
and leaves you with an unspecced callable mock. Prefer relying on the existing 
`ti.xcom_push` mock from `MagicMock(spec=...)`, or replace it with a specced 
callable if needed.
   ```suggestion
   
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -119,6 +121,44 @@ def id(self) -> str:
         suffix = "_".join(config.table_name.replace("-", "_") for config in 
self._datasource_configs)
         return f"sql_datafusion_{suffix}"
 
+    def describe_policy_exposure(self) -> ToolsetExposure:
+        resources: list[ResourceExposure] = []
+        for config in self._datasource_configs:
+            resources.append(
+                ResourceExposure(
+                    category="datasource",
+                    name=config.table_name,
+                    access_mode="read_write" if self._allow_writes else "read",
+                    details={"conn_id": config.conn_id, "format": 
config.format},
+                )
+            )
+            resources.append(
+                ResourceExposure(
+                    category="uri",
+                    name=config.uri,
+                    access_mode="read_write" if self._allow_writes else "read",
+                    details={"table_name": config.table_name},
+                )
+            )

Review Comment:
   `DataSourceConfig.uri` is optional for catalog-managed/table-provider 
formats (e.g. Iceberg) and may be an empty string. This implementation always 
adds a `ResourceExposure(category="uri", name=config.uri, ...)`, which can 
yield empty/meaningless URI entries in the report. Consider only emitting a 
`uri` resource when `config.uri` is non-empty (and/or adding a different 
resource describing the catalog/table identifier from `config.options` when 
`is_table_provider` is true).
   ```suggestion
               if config.uri:
                   resources.append(
                       ResourceExposure(
                           category="uri",
                           name=config.uri,
                           access_mode="read_write" if self._allow_writes else 
"read",
                           details={"table_name": config.table_name},
                       )
                   )
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/utils/policy_exposure.py:
##########
@@ -0,0 +1,204 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Shared models and helpers for configured policy exposure snapshots."""
+
+from __future__ import annotations
+
+from datetime import datetime, timezone
+from typing import Any, Literal
+
+from pydantic import BaseModel, Field
+from pydantic_ai.toolsets.wrapper import WrapperToolset
+
+XCOM_POLICY_EXPOSURE = "airflow_common_ai_policy_exposure"
+
+PolicyRiskLevel = Literal["low", "medium", "high"]
+ResourceCategory = Literal[
+    "database",
+    "table",
+    "schema",
+    "datasource",
+    "uri",
+    "hook_method",
+    "mcp_server",
+    "tool_prefix",
+    "unknown",
+]
+AccessMode = Literal["read", "write", "read_write", "unknown"]
+
+
+class TaskIdentity(BaseModel):
+    """Identifying information for the task instance that produced the 
report."""
+
+    dag_id: str
+    run_id: str
+    task_id: str
+    map_index: int = -1
+    operator_type: str
+
+
+class LLMExposure(BaseModel):
+    """Configuration-derived LLM access surface."""
+
+    llm_conn_id: str
+    connection_type: str | None = None
+    model_id: str | None = None
+
+
+class ApprovalExposure(BaseModel):
+    """Review and approval controls applied to the task."""
+
+    enable_hitl_review: bool = False
+    max_hitl_iterations: int | None = None
+
+
+class ResourceExposure(BaseModel):
+    """A single configured resource or capability exposed to the agent."""
+
+    category: ResourceCategory
+    name: str
+    access_mode: AccessMode = "unknown"
+    details: dict[str, Any] = Field(default_factory=dict)
+
+
+class ToolsetExposure(BaseModel):
+    """Exposure summary for one configured toolset."""
+
+    toolset_type: str
+    toolset_id: str | None = None
+    summary: str
+    resources: list[ResourceExposure] = Field(default_factory=list)
+    risk_flags: list[str] = Field(default_factory=list)
+
+
+class PolicyRiskSummary(BaseModel):
+    """High-level risk summary for the configured exposure snapshot."""
+
+    level: PolicyRiskLevel
+    reasons: list[str] = Field(default_factory=list)
+
+
+class PolicyExposureReport(BaseModel):
+    """Configured policy exposure snapshot for an AI task instance."""
+
+    captured_at: datetime = Field(default_factory=lambda: 
datetime.now(timezone.utc))
+    task: TaskIdentity
+    llm: LLMExposure
+    approval: ApprovalExposure
+    toolsets: list[ToolsetExposure] = Field(default_factory=list)
+    runtime_notes: list[str] = Field(default_factory=list)
+    risk: PolicyRiskSummary
+
+
+def classify_policy_risk(
+    *, llm: LLMExposure, toolsets: list[ToolsetExposure], runtime_notes: 
list[str]
+) -> PolicyRiskSummary:
+    """Classify overall risk using deterministic rules based on configured 
access."""
+    reasons: list[str] = []
+    has_write_access = any(
+        resource.access_mode in {"write", "read_write"}
+        for toolset in toolsets
+        for resource in toolset.resources
+    )
+
+    for toolset in toolsets:
+        reasons.extend(toolset.risk_flags)
+
+    if has_write_access and not any(
+        "write-capable" in reason or "write access" in reason for reason in 
reasons
+    ):
+        reasons.append("write-capable tool access configured")
+
+    deduped_reasons = _dedupe_reasons(reasons)
+
+    if any(
+        reason in deduped_reasons
+        for reason in ("unknown toolset exposure", "potentially mutating hook 
methods exposed")
+    ):
+        return PolicyRiskSummary(level="high", reasons=deduped_reasons or 
["unknown toolset exposure"])
+
+    if has_write_access:
+        return PolicyRiskSummary(level="high", reasons=deduped_reasons)
+
+    if deduped_reasons:
+        return PolicyRiskSummary(level="medium", reasons=deduped_reasons)
+
+    if runtime_notes:
+        return PolicyRiskSummary(level="low", reasons=["configured access 
includes runtime controls"])
+
+    return PolicyRiskSummary(level="low", reasons=["no external tool access 
configured"])
+
+
+def unwrap_toolset(toolset: Any) -> Any:
+    """Unwrap known wrapper-style toolsets to the base policy surface."""
+    current = toolset
+    seen: set[int] = set()
+
+    while isinstance(current, WrapperToolset) and current.wrapped is not None:
+        current_id = id(current)
+        if current_id in seen:
+            break
+        seen.add(current_id)
+        current = current.wrapped
+
+    return current
+
+
+def describe_toolset_exposure(toolset: Any) -> ToolsetExposure:
+    """Describe a toolset's configured exposure with a safe fallback for 
unknown toolsets."""
+    base_toolset = unwrap_toolset(toolset)
+    describe = getattr(base_toolset, "describe_policy_exposure", None)
+    if callable(describe):
+        try:
+            exposure = describe()
+            if isinstance(exposure, ToolsetExposure):
+                return exposure
+            return ToolsetExposure(
+                toolset_type=type(base_toolset).__name__,
+                toolset_id=_get_toolset_id(base_toolset),
+                summary="Toolset returned an invalid policy exposure report.",
+                risk_flags=["invalid toolset exposure report"],
+            )
+        except Exception:
+            return ToolsetExposure(
+                toolset_type=type(base_toolset).__name__,
+                toolset_id=_get_toolset_id(base_toolset),
+                summary="Toolset exposure details are unavailable because 
report generation failed.",
+                risk_flags=["toolset exposure report failed"],
+            )

Review Comment:
   `describe_toolset_exposure()` swallows exceptions from 
`describe_policy_exposure()` and returns a fallback exposure, but it does not 
log the failure. The docs for the policy exposure report state that toolset 
report generation failures are logged; consider logging the exception here (or 
re-raising and logging at the caller) so operators have visibility into why a 
toolset’s exposure report was unavailable.



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -119,6 +121,44 @@ def id(self) -> str:
         suffix = "_".join(config.table_name.replace("-", "_") for config in 
self._datasource_configs)
         return f"sql_datafusion_{suffix}"
 
+    def describe_policy_exposure(self) -> ToolsetExposure:
+        resources: list[ResourceExposure] = []
+        for config in self._datasource_configs:
+            resources.append(
+                ResourceExposure(
+                    category="datasource",
+                    name=config.table_name,
+                    access_mode="read_write" if self._allow_writes else "read",
+                    details={"conn_id": config.conn_id, "format": 
config.format},
+                )
+            )
+            resources.append(
+                ResourceExposure(
+                    category="uri",
+                    name=config.uri,
+                    access_mode="read_write" if self._allow_writes else "read",
+                    details={"table_name": config.table_name},
+                )

Review Comment:
   `describe_policy_exposure()` currently stores the raw `config.uri` value 
into the XCom-backed policy exposure report. URIs can sometimes include 
sensitive details (e.g. bucket/key names, local paths, or embedded 
credentials/query params depending on scheme), and XCom values may be broadly 
visible. Consider redacting or normalizing URIs before persisting (e.g., strip 
credentials/query params, or only store scheme + bucket/table identifier).



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to