Copilot commented on code in PR #62850:
URL: https://github.com/apache/airflow/pull/62850#discussion_r2886669063


##########
providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py:
##########
@@ -0,0 +1,300 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
+
+
+def _make_mock_datasource_config(table_name: str = "sales_data"):
+    """Create a mock DataSourceConfig."""
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+    mock = MagicMock(spec=DataSourceConfig)
+    mock.table_name = table_name
+    return mock
+
+
+def _make_mock_engine(
+    registered_tables: dict[str, str] | None = None,
+    schema_fields: list[tuple[str, str]] | None = None,
+    query_result: dict[str, list] | None = None,
+):
+    """Create a mock DataFusionEngine with sensible defaults."""
+    mock = MagicMock()
+    tables = registered_tables or {"sales_data": "s3://bucket/sales/"}
+    mock.registered_tables = tables
+    mock.session_context.catalog().schema().table_names.return_value = 
list(tables.keys())

Review Comment:
   `_make_mock_engine` creates a `MagicMock()` without a `spec` for the engine, 
then configures deep chained attributes on it. This can hide typos and API 
mismatches in the `SessionContext` calls under test. Prefer using 
`create_autospec` / `MagicMock(spec=...)` for the engine and its 
`session_context` to keep the tests aligned with the real API.



##########
providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py:
##########
@@ -0,0 +1,300 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
+
+
+def _make_mock_datasource_config(table_name: str = "sales_data"):
+    """Create a mock DataSourceConfig."""
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+    mock = MagicMock(spec=DataSourceConfig)
+    mock.table_name = table_name
+    return mock
+
+
+def _make_mock_engine(
+    registered_tables: dict[str, str] | None = None,
+    schema_fields: list[tuple[str, str]] | None = None,
+    query_result: dict[str, list] | None = None,
+):
+    """Create a mock DataFusionEngine with sensible defaults."""
+    mock = MagicMock()
+    tables = registered_tables or {"sales_data": "s3://bucket/sales/"}
+    mock.registered_tables = tables
+    mock.session_context.catalog().schema().table_names.return_value = 
list(tables.keys())
+    mock.session_context.table_exist.side_effect = lambda name: name in tables
+
+    fields = schema_fields or [("id", "Int64"), ("amount", "Float64")]
+    arrow_fields = []
+    for name, ftype in fields:
+        field = MagicMock()
+        field.name = name
+        field.type = ftype
+        arrow_fields.append(field)
+    for tname in tables:
+        mock.session_context.table(tname).schema.return_value = arrow_fields
+
+    mock.execute_query.return_value = (
+        query_result
+        if query_result is not None
+        else {
+            "id": [1, 2],
+            "amount": [10.5, 20.0],
+        }
+    )
+    return mock
+
+
+class TestDataFusionToolsetInit:
+    def test_id_includes_table_names(self):
+        cfg_a = _make_mock_datasource_config("alpha")
+        cfg_b = _make_mock_datasource_config("beta")
+        ts = DataFusionToolset([cfg_b, cfg_a])
+        assert ts.id == "sql_datafusion_beta_alpha"
+
+    def test_single_table_id(self):
+        cfg = _make_mock_datasource_config("orders")
+        ts = DataFusionToolset([cfg])
+        assert ts.id == "sql_datafusion_orders"
+
+
+class TestDataFusionToolsetGetTools:
+    def test_returns_three_tools(self):
+        cfg = _make_mock_datasource_config()
+        ts = DataFusionToolset([cfg])
+        tools = asyncio.run(ts.get_tools(ctx=MagicMock()))
+        assert set(tools.keys()) == {"list_tables", "get_schema", "query"}

Review Comment:
   These tests frequently pass `MagicMock()` for `ctx` and `tool` without a 
spec. Using a spec (e.g., `RunContext` for `ctx` and `ToolsetTool` for `tool`) 
helps ensure the call signature stays correct and avoids false-positive tests 
if the underlying interface changes.



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.ai.utils.sql_validation import 
SQLSafetyError, validate_sql as _validate_sql
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+    from airflow.providers.common.sql.datafusion.exceptions import 
QueryExecutionException
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+log = logging.getLogger(__name__)
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the three DataFusion tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "table_name": {"type": "string", "description": "Name of the table to 
inspect."},
+    },
+    "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "sql": {"type": "string", "description": "SQL query to execute."},
+    },
+    "required": ["sql"],
+}
+
+
+class DataFusionToolset(AbstractToolset[Any]):
+    """
+    Curated toolset that gives an LLM agent SQL access to object-storage data 
via Apache DataFusion.
+
+    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
+    backed by
+    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.
+
+    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
+    local storage. Multiple configs can be registered so that SQL queries can
+    join across tables.
+
+    Requires the ``datafusion`` extra of 
``apache-airflow-providers-common-sql``.
+
+    :param datasource_configs: One or more DataFusion data-source 
configurations.
+    :param allow_writes: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
+        INSERT INTO, etc.). Default ``False`` — only SELECT-family statements
+        are permitted.
+    :param max_rows: Maximum number of rows returned from the ``query`` tool.
+        Default ``50``.
+    """
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        *,
+        allow_writes: bool = False,
+        max_rows: int = 50,
+    ) -> None:
+        self._datasource_configs = datasource_configs
+        self._allow_writes = allow_writes
+        self._max_rows = max_rows
+        self._engine: DataFusionEngine | None = None
+
+    @property
+    def id(self) -> str:
+        suffix = "_".join(config.table_name.replace("-", "_") for config in 
self._datasource_configs)
+        return f"sql_datafusion_{suffix}"
+
+    def _get_engine(self) -> DataFusionEngine:
+        """Lazily create and configure a DataFusionEngine from 
*datasource_configs*."""
+        if self._engine is None:
+            engine = DataFusionEngine()
+            for config in self._datasource_configs:
+                engine.register_datasource(config)
+            self._engine = engine
+        return self._engine
+
+    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, 
ToolsetTool[Any]]:
+        tools: dict[str, ToolsetTool[Any]] = {}
+
+        for name, description, schema in (
+            ("list_tables", "List available table names.", 
_LIST_TABLES_SCHEMA),
+            ("get_schema", "Get column names and types for a table.", 
_GET_SCHEMA_SCHEMA),
+            ("query", "Execute a SQL query and return rows as JSON.", 
_QUERY_SCHEMA),
+        ):
+            tool_def = ToolDefinition(
+                name=name,
+                description=description,
+                parameters_json_schema=schema,
+                sequential=True,
+            )
+            tools[name] = ToolsetTool(
+                toolset=self,
+                tool_def=tool_def,
+                max_retries=1,
+                args_validator=_PASSTHROUGH_VALIDATOR,
+            )
+        return tools
+
+    async def call_tool(
+        self,
+        name: str,
+        tool_args: dict[str, Any],
+        ctx: RunContext[Any],
+        tool: ToolsetTool[Any],
+    ) -> Any:
+        if name == "list_tables":
+            return self._list_tables()
+        if name == "get_schema":
+            return self._get_schema(tool_args["table_name"])
+        if name == "query":
+            return self._query(tool_args["sql"])
+        raise ValueError(f"Unknown tool: {name!r}")
+
+    def _list_tables(self) -> str:
+        try:
+            engine = self._get_engine()
+            tables: list[str] = 
list(engine.session_context.catalog().schema().table_names())
+            return json.dumps(tables)
+        except QueryExecutionException as ex:
+            log.warning("list_tables failed: %s", ex)
+            return json.dumps({"error": str(ex)})
+
+    def _get_schema(self, table_name: str) -> str:
+        engine = self._get_engine()
+        # session_context lookup is required here instead of 
engine.registered_tables,
+        # because registered_tables only tracks tables registered via 
datasource config.
+        # When allow_writes is enabled, the agent may create temporary 
in-memory tables
+        # that would not be captured there.
+        if not engine.session_context.table_exist(table_name):
+            return json.dumps({"error": f"Table {table_name!r} is not 
available"})
+        # Intentionally using session_context instead of engine.get_schema() —
+        # the latter returns a pre-formatted string intended for other 
operators,
+        # not a JSON-compatible format.
+        # TODO: refactor engine.get_schema() to return JSON and update this 
accordingly
+        table = engine.session_context.table(table_name)
+        columns = [{"name": f.name, "type": str(f.type)} for f in 
table.schema()]
+        return json.dumps(columns)
+
+    def _query(self, sql: str) -> str:
+        try:
+            if not self._allow_writes:
+                _validate_sql(sql)
+
+            engine = self._get_engine()
+            pydict = engine.execute_query(sql)
+            col_names = list(pydict.keys())
+            num_rows = len(next(iter(pydict.values()), []))
+
+            result: list[dict[str, Any]] = [
+                {col: pydict[col][i] for col in col_names} for i in 
range(min(num_rows, self._max_rows))
+            ]
+
+            truncated = num_rows > self._max_rows
+            output: dict[str, Any] = {"rows": result, "count": num_rows}
+            if truncated:
+                output["truncated"] = True
+                output["max_rows"] = self._max_rows
+            return json.dumps(output, default=str)
+        except (SQLSafetyError, QueryExecutionException) as ex:

Review Comment:
   `_query` converts `SQLSafetyError` into a JSON `{"error": ...}` response, 
whereas `SQLToolset.query` lets `SQLSafetyError` propagate (so the tool call is 
treated as failed). This inconsistency can make it harder for agent runners to 
handle failures uniformly; consider aligning behavior (e.g., re-raise 
`SQLSafetyError` and only JSON-wrap execution errors, or update both toolsets 
to use the same error contract).
   ```suggestion
           except SQLSafetyError as ex:
               log.warning("query failed SQL safety validation: %s", ex)
               raise
           except QueryExecutionException as ex:
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.ai.utils.sql_validation import 
SQLSafetyError, validate_sql as _validate_sql
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+    from airflow.providers.common.sql.datafusion.exceptions import 
QueryExecutionException
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+log = logging.getLogger(__name__)
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the three DataFusion tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "table_name": {"type": "string", "description": "Name of the table to 
inspect."},
+    },
+    "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "sql": {"type": "string", "description": "SQL query to execute."},
+    },
+    "required": ["sql"],
+}
+
+
+class DataFusionToolset(AbstractToolset[Any]):
+    """
+    Curated toolset that gives an LLM agent SQL access to object-storage data 
via Apache DataFusion.
+
+    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
+    backed by
+    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.
+
+    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
+    local storage. Multiple configs can be registered so that SQL queries can
+    join across tables.
+
+    Requires the ``datafusion`` extra of 
``apache-airflow-providers-common-sql``.
+
+    :param datasource_configs: One or more DataFusion data-source 
configurations.
+    :param allow_writes: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
+        INSERT INTO, etc.). Default ``False`` — only SELECT-family statements
+        are permitted.
+    :param max_rows: Maximum number of rows returned from the ``query`` tool.
+        Default ``50``.
+    """
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        *,
+        allow_writes: bool = False,
+        max_rows: int = 50,
+    ) -> None:
+        self._datasource_configs = datasource_configs
+        self._allow_writes = allow_writes
+        self._max_rows = max_rows
+        self._engine: DataFusionEngine | None = None
+
+    @property
+    def id(self) -> str:
+        suffix = "_".join(config.table_name.replace("-", "_") for config in 
self._datasource_configs)
+        return f"sql_datafusion_{suffix}"
+
+    def _get_engine(self) -> DataFusionEngine:
+        """Lazily create and configure a DataFusionEngine from 
*datasource_configs*."""
+        if self._engine is None:
+            engine = DataFusionEngine()
+            for config in self._datasource_configs:
+                engine.register_datasource(config)
+            self._engine = engine
+        return self._engine
+
+    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, 
ToolsetTool[Any]]:
+        tools: dict[str, ToolsetTool[Any]] = {}
+
+        for name, description, schema in (
+            ("list_tables", "List available table names.", 
_LIST_TABLES_SCHEMA),
+            ("get_schema", "Get column names and types for a table.", 
_GET_SCHEMA_SCHEMA),
+            ("query", "Execute a SQL query and return rows as JSON.", 
_QUERY_SCHEMA),
+        ):
+            tool_def = ToolDefinition(
+                name=name,
+                description=description,
+                parameters_json_schema=schema,
+                sequential=True,
+            )
+            tools[name] = ToolsetTool(
+                toolset=self,
+                tool_def=tool_def,
+                max_retries=1,
+                args_validator=_PASSTHROUGH_VALIDATOR,
+            )
+        return tools
+
+    async def call_tool(
+        self,
+        name: str,
+        tool_args: dict[str, Any],
+        ctx: RunContext[Any],
+        tool: ToolsetTool[Any],
+    ) -> Any:
+        if name == "list_tables":
+            return self._list_tables()
+        if name == "get_schema":
+            return self._get_schema(tool_args["table_name"])
+        if name == "query":
+            return self._query(tool_args["sql"])
+        raise ValueError(f"Unknown tool: {name!r}")
+
+    def _list_tables(self) -> str:
+        try:
+            engine = self._get_engine()
+            tables: list[str] = 
list(engine.session_context.catalog().schema().table_names())
+            return json.dumps(tables)
+        except QueryExecutionException as ex:

Review Comment:
   `_list_tables` catches `QueryExecutionException`, but this method doesn’t 
call `DataFusionEngine.execute_query()` (the main place that wraps errors as 
`QueryExecutionException`). Most failures here will bypass this handler. Either 
remove the try/except or catch the exception types that can actually be raised 
here (or catch `Exception` and return a structured error consistently).
   ```suggestion
           except Exception as ex:
   ```



##########
providers/common/ai/tests/unit/common/ai/toolsets/test_datafusion.py:
##########
@@ -0,0 +1,300 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import json
+from unittest.mock import MagicMock, patch
+
+import pytest
+
+from airflow.providers.common.ai.toolsets.datafusion import DataFusionToolset
+
+
+def _make_mock_datasource_config(table_name: str = "sales_data"):
+    """Create a mock DataSourceConfig."""
+    from airflow.providers.common.sql.config import DataSourceConfig
+

Review Comment:
   Imports inside helper functions make dependency and linting behavior harder 
to reason about. Since this test module already depends on 
`airflow.providers.common.sql` via `DataFusionToolset`, consider moving the 
`DataSourceConfig` import to module scope (or add a comment explaining why it 
must be lazy).
   ```suggestion
   from airflow.providers.common.sql.config import DataSourceConfig
   
   
   def _make_mock_datasource_config(table_name: str = "sales_data"):
       """Create a mock DataSourceConfig."""
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -0,0 +1,203 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+import logging
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.ai.utils.sql_validation import 
SQLSafetyError, validate_sql as _validate_sql
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+    from airflow.providers.common.sql.datafusion.exceptions import 
QueryExecutionException
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+log = logging.getLogger(__name__)
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the three DataFusion tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "table_name": {"type": "string", "description": "Name of the table to 
inspect."},
+    },
+    "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "sql": {"type": "string", "description": "SQL query to execute."},
+    },
+    "required": ["sql"],
+}
+
+
+class DataFusionToolset(AbstractToolset[Any]):
+    """
+    Curated toolset that gives an LLM agent SQL access to object-storage data 
via Apache DataFusion.
+
+    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
+    backed by
+    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.
+
+    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
+    local storage. Multiple configs can be registered so that SQL queries can
+    join across tables.
+
+    Requires the ``datafusion`` extra of 
``apache-airflow-providers-common-sql``.
+
+    :param datasource_configs: One or more DataFusion data-source 
configurations.
+    :param allow_writes: Allow data-modifying SQL (CREATE TABLE, CREATE VIEW,
+        INSERT INTO, etc.). Default ``False`` — only SELECT-family statements
+        are permitted.
+    :param max_rows: Maximum number of rows returned from the ``query`` tool.
+        Default ``50``.
+    """
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        *,
+        allow_writes: bool = False,
+        max_rows: int = 50,
+    ) -> None:

Review Comment:
   `DataFusionToolset.__init__` documents `datasource_configs` as “one or 
more”, but there’s no validation for an empty list. With `[]`, `id` becomes 
`sql_datafusion_` and the toolset is effectively unusable. Consider raising a 
`ValueError` when `datasource_configs` is empty.
   ```suggestion
       ) -> None:
           if not datasource_configs:
               raise ValueError("datasource_configs must contain at least one 
DataSourceConfig")
   ```



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to