gopidesupavan commented on code in PR #62850:
URL: https://github.com/apache/airflow/pull/62850#discussion_r2886174630


##########
providers/common/ai/src/airflow/providers/common/ai/toolsets/datafusion.py:
##########
@@ -0,0 +1,172 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Curated SQL toolset wrapping DataFusionEngine for agentic object-store 
workflows."""
+
+from __future__ import annotations
+
+import json
+from typing import TYPE_CHECKING, Any
+
+try:
+    from airflow.providers.common.sql.datafusion.engine import DataFusionEngine
+except ImportError as e:
+    from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+    raise AirflowOptionalProviderFeatureException(e)
+
+from pydantic_ai.tools import ToolDefinition
+from pydantic_ai.toolsets.abstract import AbstractToolset, ToolsetTool
+from pydantic_core import SchemaValidator, core_schema
+
+if TYPE_CHECKING:
+    from pydantic_ai._run_context import RunContext
+
+    from airflow.providers.common.sql.config import DataSourceConfig
+
+_PASSTHROUGH_VALIDATOR = SchemaValidator(core_schema.any_schema())
+
+# JSON Schemas for the three DataFusion tools.
+_LIST_TABLES_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {},
+}
+
+_GET_SCHEMA_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "table_name": {"type": "string", "description": "Name of the table to 
inspect."},
+    },
+    "required": ["table_name"],
+}
+
+_QUERY_SCHEMA: dict[str, Any] = {
+    "type": "object",
+    "properties": {
+        "sql": {"type": "string", "description": "SQL query to execute."},
+    },
+    "required": ["sql"],
+}
+
+
+class DataFusionToolset(AbstractToolset[Any]):
+    """
+    Curated toolset that gives an LLM agent SQL access to object-storage data 
via Apache DataFusion.
+
+    Provides three tools — ``list_tables``, ``get_schema``, and ``query`` —
+    backed by
+    :class:`~airflow.providers.common.sql.datafusion.engine.DataFusionEngine`.
+
+    Each :class:`~airflow.providers.common.sql.config.DataSourceConfig` entry
+    registers a table backed by Parquet, CSV, Avro, or Iceberg data on S3 or
+    local storage. Multiple configs can be registered so that SQL queries can
+    join across tables.
+
+    Requires the ``datafusion`` extra of 
``apache-airflow-providers-common-sql``.
+
+    :param datasource_configs: One or more DataFusion data-source 
configurations.
+    :param max_rows: Maximum number of rows returned from the ``query`` tool.
+        Default ``50``.
+    """
+
+    def __init__(
+        self,
+        datasource_configs: list[DataSourceConfig],
+        *,
+        max_rows: int = 50,
+    ) -> None:
+        self._datasource_configs = datasource_configs
+        self._max_rows = max_rows
+        self._engine: DataFusionEngine | None = None
+
+    @property
+    def id(self) -> str:
+        table_names = sorted(c.table_name for c in self._datasource_configs)
+        return f"sql-datafusion-{'-'.join(table_names)}"
+
+    def _get_engine(self) -> DataFusionEngine:
+        """Lazily create and configure a DataFusionEngine from 
*datasource_configs*."""
+        if self._engine is None:
+            engine = DataFusionEngine()
+            for config in self._datasource_configs:
+                engine.register_datasource(config)
+            self._engine = engine
+        return self._engine
+
+    async def get_tools(self, ctx: RunContext[Any]) -> dict[str, 
ToolsetTool[Any]]:
+        tools: dict[str, ToolsetTool[Any]] = {}
+
+        for name, description, schema in (
+            ("list_tables", "List available table names.", 
_LIST_TABLES_SCHEMA),
+            ("get_schema", "Get column names and types for a table.", 
_GET_SCHEMA_SCHEMA),
+            ("query", "Execute a SQL query and return rows as JSON.", 
_QUERY_SCHEMA),
+        ):
+            tool_def = ToolDefinition(
+                name=name,
+                description=description,
+                parameters_json_schema=schema,
+                sequential=True,
+            )
+            tools[name] = ToolsetTool(
+                toolset=self,
+                tool_def=tool_def,
+                max_retries=1,
+                args_validator=_PASSTHROUGH_VALIDATOR,
+            )
+        return tools
+
+    async def call_tool(
+        self,
+        name: str,
+        tool_args: dict[str, Any],
+        ctx: RunContext[Any],
+        tool: ToolsetTool[Any],
+    ) -> Any:
+        if name == "list_tables":
+            return self._list_tables()
+        if name == "get_schema":
+            return self._get_schema(tool_args["table_name"])
+        if name == "query":
+            return self._query(tool_args["sql"])
+        raise ValueError(f"Unknown tool: {name!r}")
+
+    def _list_tables(self) -> str:
+        engine = self._get_engine()
+        tables: list[str] = list(engine.registered_tables.keys())
+        return json.dumps(tables)
+

Review Comment:
   Query exceptions are handled now.
   
   Regarding the second point, I will handle it.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to