kaxil commented on code in PR #62793:
URL: https://github.com/apache/airflow/pull/62793#discussion_r2880993764


##########
providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py:
##########
@@ -0,0 +1,316 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for cross-system schema drift detection powered by LLM 
reasoning."""
+
+from __future__ import annotations
+
+import json
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Literal
+
+from pydantic import BaseModel, Field
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+if TYPE_CHECKING:
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.sdk import Context
+
+
+class SchemaMismatch(BaseModel):
+    """A single schema mismatch between data sources."""
+
+    source: str = Field(description="Source table")
+    target: str = Field(description="Target table")
+    column: str = Field(description="Column name where the mismatch was 
detected")
+    source_type: str = Field(description="Data type in the source system")
+    target_type: str = Field(description="Data type in the target system")
+    severity: Literal["critical", "warning", "info"] = 
Field(description="Mismatch severity")
+    description: str = Field(description="Human-readable description of the 
mismatch")
+    suggested_action: str = Field(description="Recommended action to resolve 
the mismatch")
+    migration_query: str = Field(description="Provide migration query to 
resolve the mismatch")
+
+
+class SchemaCompareResult(BaseModel):
+    """Structured output from schema comparison."""
+
+    compatible: bool = Field(description="Whether the schemas are compatible 
for data loading")
+    mismatches: list[SchemaMismatch] = Field(default_factory=list)
+    summary: str = Field(description="High-level summary of the comparison")
+
+
+DEFAULT_SYSTEM_PROMPT = (
+    "Consider cross-system type equivalences:\n"
+    "- varchar(n) / text / string / TEXT may be compatible\n"
+    "- int / integer / int4 / INT32 are equivalent\n"
+    "- bigint / int8 / int64 / BIGINT are equivalent\n"
+    "- timestamp / timestamptz / TIMESTAMP_NTZ / datetime may differ in 
timezone handling\n"
+    "- numeric(p,s) / decimal(p,s) / NUMBER — check precision and scale\n"
+    "- boolean / bool / BOOLEAN / tinyint(1) — check semantic equivalence\n\n"
+    "Severity levels:\n"
+    "- critical: Will cause data loading failures or data loss "
+    "(e.g., column missing in target, incompatible types)\n"
+    "- warning: May cause data quality issues "
+    "(e.g., precision loss, timezone mismatch)\n"
+    "- info: Cosmetic differences that won't affect data loading "
+    "(e.g., varchar length differences within safe range)\n\n"
+)
+
+
+class LLMSchemaCompareOperator(LLMOperator):
+    """
+    Compare schemas across different database systems and detect drift using 
LLM reasoning.
+
+    The LLM handles complex cross-system type mapping that simple equality 
checks
+    miss (e.g., ``varchar(255)`` vs ``string``, ``timestamp`` vs 
``timestamptz``).
+
+    Accepts data sources via two patterns:
+
+    1. **data_sources** — a list of
+       :class:`~airflow.providers.common.sql.config.DataSourceConfig` for each
+       system. If the connection resolves to a
+       :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`, schema is
+       introspected via SQLAlchemy; otherwise DataFusion is used.
+    2. **db_conn_ids + table_names** — shorthand for comparing the same table
+       across multiple database connections (all must resolve to 
``DbApiHook``).
+
+    :param prompt: Instructions for the LLM on what to compare and flag.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-5"``).
+    :param system_prompt: Additional instructions appended to the built-in
+        schema comparison prompt.
+    :param agent_params: Extra keyword arguments for the pydantic-ai ``Agent``.
+    :param data_sources: List of DataSourceConfig objects, one per system.
+    :param db_conn_ids: Connection IDs for databases to compare (used with
+        ``table_names``).
+    :param table_names: Tables to introspect from each ``db_conn_id``.
+    :param context_strategy: ``"basic"`` for column names and types only;
+        ``"full"`` to include primary keys, foreign keys, and indexes.
+        Default ``"full"``.
+    """
+
+    template_fields: Sequence[str] = (
+        *LLMOperator.template_fields,
+        "data_sources",
+        "db_conn_ids",
+        "table_names",
+        "context_strategy",
+    )
+
+    def __init__(
+        self,
+        *,
+        data_sources: list[DataSourceConfig] | None = None,
+        db_conn_ids: list[str] | None = None,
+        table_names: list[str] | None = None,
+        context_strategy: Literal["basic", "full"] = "full",
+        system_prompt: str = DEFAULT_SYSTEM_PROMPT,

Review Comment:
   Design question: since `system_prompt` defaults to `DEFAULT_SYSTEM_PROMPT` 
(type equivalences + severity definitions), passing `system_prompt="focus on 
timestamp columns"` replaces those defaults entirely rather than extending them.
   
   A user who wants both the type-equivalence hints AND custom instructions has 
to manually concatenate:
   ```python
   from airflow.providers.common.ai.operators.llm_schema_compare import DEFAULT_SYSTEM_PROMPT
   
   LLMSchemaCompareOperator(
       system_prompt=DEFAULT_SYSTEM_PROMPT + "\nAlso focus on timestamp columns.",
       ...
   )
   ```
   
   This is a reasonable trade-off (simple, explicit), but worth documenting in 
the docstring so users know `system_prompt` overrides the default rather than 
appending to it. The current docstring says "Additional instructions appended 
to the built-in schema comparison prompt" which implies additive behavior.



##########
providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py:
##########
@@ -0,0 +1,316 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for cross-system schema drift detection powered by LLM 
reasoning."""
+
+from __future__ import annotations
+
+import json
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Literal
+
+from pydantic import BaseModel, Field
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+if TYPE_CHECKING:
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.sdk import Context
+
+
+class SchemaMismatch(BaseModel):
+    """A single schema mismatch between data sources."""
+
+    source: str = Field(description="Source table")
+    target: str = Field(description="Target table")
+    column: str = Field(description="Column name where the mismatch was 
detected")
+    source_type: str = Field(description="Data type in the source system")
+    target_type: str = Field(description="Data type in the target system")
+    severity: Literal["critical", "warning", "info"] = 
Field(description="Mismatch severity")
+    description: str = Field(description="Human-readable description of the 
mismatch")
+    suggested_action: str = Field(description="Recommended action to resolve 
the mismatch")
+    migration_query: str = Field(description="Provide migration query to 
resolve the mismatch")
+
+
+class SchemaCompareResult(BaseModel):
+    """Structured output from schema comparison."""
+
+    compatible: bool = Field(description="Whether the schemas are compatible 
for data loading")
+    mismatches: list[SchemaMismatch] = Field(default_factory=list)
+    summary: str = Field(description="High-level summary of the comparison")
+
+
+DEFAULT_SYSTEM_PROMPT = (
+    "Consider cross-system type equivalences:\n"
+    "- varchar(n) / text / string / TEXT may be compatible\n"
+    "- int / integer / int4 / INT32 are equivalent\n"
+    "- bigint / int8 / int64 / BIGINT are equivalent\n"
+    "- timestamp / timestamptz / TIMESTAMP_NTZ / datetime may differ in 
timezone handling\n"
+    "- numeric(p,s) / decimal(p,s) / NUMBER — check precision and scale\n"
+    "- boolean / bool / BOOLEAN / tinyint(1) — check semantic equivalence\n\n"
+    "Severity levels:\n"
+    "- critical: Will cause data loading failures or data loss "
+    "(e.g., column missing in target, incompatible types)\n"
+    "- warning: May cause data quality issues "
+    "(e.g., precision loss, timezone mismatch)\n"
+    "- info: Cosmetic differences that won't affect data loading "
+    "(e.g., varchar length differences within safe range)\n\n"
+)
+
+
+class LLMSchemaCompareOperator(LLMOperator):
+    """
+    Compare schemas across different database systems and detect drift using 
LLM reasoning.
+
+    The LLM handles complex cross-system type mapping that simple equality 
checks
+    miss (e.g., ``varchar(255)`` vs ``string``, ``timestamp`` vs 
``timestamptz``).
+
+    Accepts data sources via two patterns:
+
+    1. **data_sources** — a list of
+       :class:`~airflow.providers.common.sql.config.DataSourceConfig` for each
+       system. If the connection resolves to a
+       :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`, schema is
+       introspected via SQLAlchemy; otherwise DataFusion is used.
+    2. **db_conn_ids + table_names** — shorthand for comparing the same table
+       across multiple database connections (all must resolve to 
``DbApiHook``).
+
+    :param prompt: Instructions for the LLM on what to compare and flag.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-5"``).
+    :param system_prompt: Additional instructions appended to the built-in
+        schema comparison prompt.
+    :param agent_params: Extra keyword arguments for the pydantic-ai ``Agent``.
+    :param data_sources: List of DataSourceConfig objects, one per system.
+    :param db_conn_ids: Connection IDs for databases to compare (used with
+        ``table_names``).
+    :param table_names: Tables to introspect from each ``db_conn_id``.
+    :param context_strategy: ``"basic"`` for column names and types only;
+        ``"full"`` to include primary keys, foreign keys, and indexes.
+        Default ``"full"``.
+    """
+
+    template_fields: Sequence[str] = (
+        *LLMOperator.template_fields,
+        "data_sources",
+        "db_conn_ids",
+        "table_names",
+        "context_strategy",
+    )
+
+    def __init__(
+        self,
+        *,
+        data_sources: list[DataSourceConfig] | None = None,
+        db_conn_ids: list[str] | None = None,
+        table_names: list[str] | None = None,
+        context_strategy: Literal["basic", "full"] = "full",
+        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+        **kwargs: Any,
+    ) -> None:
+        kwargs.pop("output_type", None)
+        super().__init__(**kwargs)
+        self.data_sources = data_sources or []
+        self.db_conn_ids = db_conn_ids or []
+        self.table_names = table_names or []
+        self.context_strategy = context_strategy
+        self.system_prompt = system_prompt
+
+        if not self.data_sources and not self.db_conn_ids:
+            raise ValueError("Provide at least one of 'data_sources' or 
'db_conn_ids'.")
+
+        if self.db_conn_ids and not self.table_names:
+            raise ValueError("'table_names' is required when using 
'db_conn_ids'.")
+
+        total_sources = len(self.db_conn_ids) + len(self.data_sources)
+        if total_sources < 2:
+            raise ValueError(
+                "Provide at-least two combinations of 'db_conn_ids' and 
'table_names' or 'data_sources' "
+                "to compare."
+            )
+
+    @staticmethod
+    def _get_db_hook(conn_id: str) -> DbApiHook:
+        """Resolve a connection ID to a DbApiHook."""
+        from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+        connection = BaseHook.get_connection(conn_id)
+        hook = connection.get_hook()
+        if not isinstance(hook, DbApiHook):
+            raise ValueError(
+                f"Connection {conn_id!r} does not provide a DbApiHook. Got 
{type(hook).__name__}."
+            )
+        return hook
+
+    def _is_dbapi_connection(self, conn_id: str) -> bool:
+        """Check whether a connection resolves to a DbApiHook."""
+        from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+        try:
+            connection = BaseHook.get_connection(conn_id)
+            hook = connection.get_hook()
+            return isinstance(hook, DbApiHook)
+        except (AirflowException, ValueError) as exc:
+            self.log.debug("Connection %s does not resolve to a DbApiHook: 
%s", conn_id, exc, exc_info=True)
+            return False
+
+    def _introspect_db_schema(self, hook: DbApiHook, table_name: str) -> str:
+        """Introspect schema from a database connection via DbApiHook."""
+        columns = hook.get_table_schema(table_name)
+        if not columns:
+            self.log.warning("Table %r returned no columns — it may not 
exist.", table_name)
+            return ""
+
+        col_info = ", ".join(f"{c['name']} {c['type']}" for c in columns)
+        parts = [f"Columns: {col_info}"]
+
+        if self.context_strategy == "full":
+            try:
+                pks = hook.dialect.get_primary_keys(table_name)
+                if pks:
+                    parts.append(f"Primary Key: {', '.join(pks)}")
+            except NotImplementedError:
+                self.log.warning(
+                    "primary key introspection not implemented for dialect 
%s", hook.dialect_name
+                )
+            except Exception as ex:
+                self.log.warning("Could not retrieve PK %s", ex)

Review Comment:
   nit: These warnings would be more useful with the table name for context. 
When comparing multiple tables, "Could not retrieve PK" doesn't tell you which 
one failed.
   
   ```python
   self.log.warning("Could not retrieve PK for %r: %s", table_name, ex)
   ```
   
   Same for FK (line 207) and index (line 219) warnings.



##########
providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py:
##########
@@ -0,0 +1,316 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Operator for cross-system schema drift detection powered by LLM 
reasoning."""
+
+from __future__ import annotations
+
+import json
+from collections.abc import Sequence
+from functools import cached_property
+from typing import TYPE_CHECKING, Any, Literal
+
+from pydantic import BaseModel, Field
+
+from airflow.providers.common.ai.operators.llm import LLMOperator
+from airflow.providers.common.compat.sdk import AirflowException, BaseHook
+
+if TYPE_CHECKING:
+    from airflow.providers.common.sql.config import DataSourceConfig
+    from airflow.providers.common.sql.hooks.sql import DbApiHook
+    from airflow.sdk import Context
+
+
+class SchemaMismatch(BaseModel):
+    """A single schema mismatch between data sources."""
+
+    source: str = Field(description="Source table")
+    target: str = Field(description="Target table")
+    column: str = Field(description="Column name where the mismatch was 
detected")
+    source_type: str = Field(description="Data type in the source system")
+    target_type: str = Field(description="Data type in the target system")
+    severity: Literal["critical", "warning", "info"] = 
Field(description="Mismatch severity")
+    description: str = Field(description="Human-readable description of the 
mismatch")
+    suggested_action: str = Field(description="Recommended action to resolve 
the mismatch")
+    migration_query: str = Field(description="Provide migration query to 
resolve the mismatch")
+
+
+class SchemaCompareResult(BaseModel):
+    """Structured output from schema comparison."""
+
+    compatible: bool = Field(description="Whether the schemas are compatible 
for data loading")
+    mismatches: list[SchemaMismatch] = Field(default_factory=list)
+    summary: str = Field(description="High-level summary of the comparison")
+
+
+DEFAULT_SYSTEM_PROMPT = (
+    "Consider cross-system type equivalences:\n"
+    "- varchar(n) / text / string / TEXT may be compatible\n"
+    "- int / integer / int4 / INT32 are equivalent\n"
+    "- bigint / int8 / int64 / BIGINT are equivalent\n"
+    "- timestamp / timestamptz / TIMESTAMP_NTZ / datetime may differ in 
timezone handling\n"
+    "- numeric(p,s) / decimal(p,s) / NUMBER — check precision and scale\n"
+    "- boolean / bool / BOOLEAN / tinyint(1) — check semantic equivalence\n\n"
+    "Severity levels:\n"
+    "- critical: Will cause data loading failures or data loss "
+    "(e.g., column missing in target, incompatible types)\n"
+    "- warning: May cause data quality issues "
+    "(e.g., precision loss, timezone mismatch)\n"
+    "- info: Cosmetic differences that won't affect data loading "
+    "(e.g., varchar length differences within safe range)\n\n"
+)
+
+
+class LLMSchemaCompareOperator(LLMOperator):
+    """
+    Compare schemas across different database systems and detect drift using 
LLM reasoning.
+
+    The LLM handles complex cross-system type mapping that simple equality 
checks
+    miss (e.g., ``varchar(255)`` vs ``string``, ``timestamp`` vs 
``timestamptz``).
+
+    Accepts data sources via two patterns:
+
+    1. **data_sources** — a list of
+       :class:`~airflow.providers.common.sql.config.DataSourceConfig` for each
+       system. If the connection resolves to a
+       :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`, schema is
+       introspected via SQLAlchemy; otherwise DataFusion is used.
+    2. **db_conn_ids + table_names** — shorthand for comparing the same table
+       across multiple database connections (all must resolve to 
``DbApiHook``).
+
+    :param prompt: Instructions for the LLM on what to compare and flag.
+    :param llm_conn_id: Connection ID for the LLM provider.
+    :param model_id: Model identifier (e.g. ``"openai:gpt-5"``).
+    :param system_prompt: Additional instructions appended to the built-in
+        schema comparison prompt.
+    :param agent_params: Extra keyword arguments for the pydantic-ai ``Agent``.
+    :param data_sources: List of DataSourceConfig objects, one per system.
+    :param db_conn_ids: Connection IDs for databases to compare (used with
+        ``table_names``).
+    :param table_names: Tables to introspect from each ``db_conn_id``.
+    :param context_strategy: ``"basic"`` for column names and types only;
+        ``"full"`` to include primary keys, foreign keys, and indexes.
+        Default ``"full"``.
+    """
+
+    template_fields: Sequence[str] = (
+        *LLMOperator.template_fields,
+        "data_sources",
+        "db_conn_ids",
+        "table_names",
+        "context_strategy",
+    )
+
+    def __init__(
+        self,
+        *,
+        data_sources: list[DataSourceConfig] | None = None,
+        db_conn_ids: list[str] | None = None,
+        table_names: list[str] | None = None,
+        context_strategy: Literal["basic", "full"] = "full",
+        system_prompt: str = DEFAULT_SYSTEM_PROMPT,
+        **kwargs: Any,
+    ) -> None:
+        kwargs.pop("output_type", None)
+        super().__init__(**kwargs)
+        self.data_sources = data_sources or []
+        self.db_conn_ids = db_conn_ids or []
+        self.table_names = table_names or []
+        self.context_strategy = context_strategy
+        self.system_prompt = system_prompt
+
+        if not self.data_sources and not self.db_conn_ids:
+            raise ValueError("Provide at least one of 'data_sources' or 
'db_conn_ids'.")
+
+        if self.db_conn_ids and not self.table_names:
+            raise ValueError("'table_names' is required when using 
'db_conn_ids'.")
+
+        total_sources = len(self.db_conn_ids) + len(self.data_sources)
+        if total_sources < 2:
+            raise ValueError(
+                "Provide at-least two combinations of 'db_conn_ids' and 
'table_names' or 'data_sources' "
+                "to compare."
+            )
+
+    @staticmethod
+    def _get_db_hook(conn_id: str) -> DbApiHook:
+        """Resolve a connection ID to a DbApiHook."""
+        from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+        connection = BaseHook.get_connection(conn_id)
+        hook = connection.get_hook()
+        if not isinstance(hook, DbApiHook):
+            raise ValueError(
+                f"Connection {conn_id!r} does not provide a DbApiHook. Got 
{type(hook).__name__}."
+            )
+        return hook
+
+    def _is_dbapi_connection(self, conn_id: str) -> bool:
+        """Check whether a connection resolves to a DbApiHook."""
+        from airflow.providers.common.sql.hooks.sql import DbApiHook
+
+        try:
+            connection = BaseHook.get_connection(conn_id)
+            hook = connection.get_hook()
+            return isinstance(hook, DbApiHook)
+        except (AirflowException, ValueError) as exc:
+            self.log.debug("Connection %s does not resolve to a DbApiHook: 
%s", conn_id, exc, exc_info=True)
+            return False
+
+    def _introspect_db_schema(self, hook: DbApiHook, table_name: str) -> str:
+        """Introspect schema from a database connection via DbApiHook."""
+        columns = hook.get_table_schema(table_name)
+        if not columns:
+            self.log.warning("Table %r returned no columns — it may not 
exist.", table_name)
+            return ""
+
+        col_info = ", ".join(f"{c['name']} {c['type']}" for c in columns)
+        parts = [f"Columns: {col_info}"]
+
+        if self.context_strategy == "full":
+            try:
+                pks = hook.dialect.get_primary_keys(table_name)
+                if pks:
+                    parts.append(f"Primary Key: {', '.join(pks)}")
+            except NotImplementedError:
+                self.log.warning(
+                    "primary key introspection not implemented for dialect 
%s", hook.dialect_name
+                )
+            except Exception as ex:
+                self.log.warning("Could not retrieve PK %s", ex)
+
+            try:
+                fks = hook.inspector.get_foreign_keys(table_name)
+                for fk in fks:
+                    cols = ", ".join(fk.get("constrained_columns", []))
+                    ref = fk.get("referred_table", "?")
+                    ref_cols = ", ".join(fk.get("referred_columns", []))
+                    parts.append(f"Foreign Key: ({cols}) -> {ref}({ref_cols})")
+            except NotImplementedError:
+                self.log.warning(
+                    "foreign key introspection not implemented for dialect 
%s", hook.dialect_name
+                )
+            except Exception as ex:
+                self.log.warning("Could not retrieve FK %s", ex)
+
+            try:
+                indexes = hook.inspector.get_indexes(table_name)
+                for idx in indexes:
+                    column_names = [c for c in idx.get("column_names", []) if 
c is not None]
+                    idx_cols = ", ".join(column_names)
+                    unique = " UNIQUE" if idx.get("unique") else ""
+                    parts.append(f"Index{unique}: {idx.get('name', '?')} 
({idx_cols})")
+            except NotImplementedError:
+                self.log.warning("index introspection not implemented for 
dialect %s", hook.dialect_name)
+            except Exception as ex:
+                self.log.warning("Could not retrieve index %s", ex)
+
+            return "\n".join(parts)
+
+        if self.context_strategy == "basic":
+            return "\n".join(parts)
+
+        raise ValueError(f"Invalid context_strategy: {self.context_strategy}")
+
+    def _introspect_datasource_schema(self, ds_config: DataSourceConfig) -> 
str:
+        """Introspect schema from a DataSourceConfig, choosing DbApiHook or 
DataFusion."""
+        if self._is_dbapi_connection(ds_config.conn_id):
+            hook = self._get_db_hook(ds_config.conn_id)
+            dialect_name = getattr(hook, "dialect_name", "unknown")
+            schema_text = self._introspect_db_schema(hook, 
ds_config.table_name)
+            return (
+                f"Source: {ds_config.conn_id} ({dialect_name})\nTable: 
{ds_config.table_name}\n{schema_text}"
+            )
+
+        return self._introspect_schema_from_datafusion(ds_config)
+
+    @cached_property
+    def _df_engine(self):
+        try:
+            from airflow.providers.common.sql.datafusion.engine import 
DataFusionEngine
+        except ImportError as e:
+            from airflow.providers.common.compat.sdk import 
AirflowOptionalProviderFeatureException
+
+            raise AirflowOptionalProviderFeatureException(e)
+        engine = DataFusionEngine()
+        return engine
+
+    def _introspect_schema_from_datafusion(self, ds_config: DataSourceConfig):
+        self._df_engine.register_datasource(ds_config)
+        schema_text = self._df_engine.get_schema(ds_config.table_name)
+
+        return f"Source: {ds_config.conn_id} \nFormat: 
({ds_config.format})\nTable: {ds_config.table_name}\nColumns: {schema_text}"
+
+    def _build_schema_context(self) -> str:
+        """Collect schemas from all configured sources each clearly."""
+        sections: list[str] = []
+
+        for conn_id in self.db_conn_ids:
+            hook = self._get_db_hook(conn_id)
+            dialect_name = getattr(hook, "dialect_name", "unknown")
+            for table in self.table_names:
+                schema_text = self._introspect_db_schema(hook, table)
+                if schema_text:
+                    sections.append(f"Source: {conn_id} 
({dialect_name}\nTable: {table}\n{schema_text}")

Review Comment:
   Bug: Missing closing `)` after `{dialect_name}`.
   
   ```python
   # Current (broken):
   f"Source: {conn_id} ({dialect_name}\nTable: {table}\n{schema_text}"
   # LLM sees: "Source: postgres_default (postgresql\nTable: orders\n..."
   
   # Fix:
   f"Source: {conn_id} ({dialect_name})\nTable: {table}\n{schema_text}"
   ```
   
   The tests pass because they assert on the prefix
   `"Source: postgres_default (postgresql"`, which matches either way.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to