gopidesupavan commented on code in PR #62793: URL: https://github.com/apache/airflow/pull/62793#discussion_r2881069280
########## providers/common/ai/src/airflow/providers/common/ai/operators/llm_schema_compare.py: ########## @@ -0,0 +1,316 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. +"""Operator for cross-system schema drift detection powered by LLM reasoning.""" + +from __future__ import annotations + +import json +from collections.abc import Sequence +from functools import cached_property +from typing import TYPE_CHECKING, Any, Literal + +from pydantic import BaseModel, Field + +from airflow.providers.common.ai.operators.llm import LLMOperator +from airflow.providers.common.compat.sdk import AirflowException, BaseHook + +if TYPE_CHECKING: + from airflow.providers.common.sql.config import DataSourceConfig + from airflow.providers.common.sql.hooks.sql import DbApiHook + from airflow.sdk import Context + + +class SchemaMismatch(BaseModel): + """A single schema mismatch between data sources.""" + + source: str = Field(description="Source table") + target: str = Field(description="Target table") + column: str = Field(description="Column name where the mismatch was detected") + source_type: str = Field(description="Data type in the source system") + target_type: str = Field(description="Data type in the target system") + severity: Literal["critical", "warning", "info"] = Field(description="Mismatch severity") + description: str = Field(description="Human-readable description of the mismatch") + suggested_action: str = Field(description="Recommended action to resolve the mismatch") + migration_query: str = Field(description="Provide migration query to resolve the mismatch") + + +class SchemaCompareResult(BaseModel): + """Structured output from schema comparison.""" + + compatible: bool = Field(description="Whether the schemas are compatible for data loading") + mismatches: list[SchemaMismatch] = Field(default_factory=list) + summary: str = Field(description="High-level summary of the comparison") + + +DEFAULT_SYSTEM_PROMPT = ( + "Consider cross-system type equivalences:\n" + "- varchar(n) / text / string / TEXT may be compatible\n" + "- int / integer / int4 / INT32 are equivalent\n" + "- bigint / int8 / int64 / BIGINT are equivalent\n" + "- timestamp / timestamptz / TIMESTAMP_NTZ / datetime may differ in timezone handling\n" + "- numeric(p,s) / decimal(p,s) / NUMBER — check precision and scale\n" + "- boolean / bool / BOOLEAN / tinyint(1) — check semantic equivalence\n\n" + "Severity levels:\n" + "- critical: Will cause data loading failures or data loss " + "(e.g., column missing in target, incompatible types)\n" + "- warning: May cause data quality issues " + "(e.g., precision loss, timezone mismatch)\n" + "- info: Cosmetic differences that won't affect data loading " + "(e.g., varchar length differences within safe range)\n\n" +) + + +class LLMSchemaCompareOperator(LLMOperator): + """ + Compare schemas across different database systems and detect drift using LLM reasoning. + + The LLM handles complex cross-system type mapping that simple equality checks + miss (e.g., ``varchar(255)`` vs ``string``, ``timestamp`` vs ``timestamptz``). + + Accepts data sources via two patterns: + + 1. **data_sources** — a list of + :class:`~airflow.providers.common.sql.config.DataSourceConfig` for each + system. If the connection resolves to a + :class:`~airflow.providers.common.sql.hooks.sql.DbApiHook`, schema is + introspected via SQLAlchemy; otherwise DataFusion is used. + 2. **db_conn_ids + table_names** — shorthand for comparing the same table + across multiple database connections (all must resolve to ``DbApiHook``). + + :param prompt: Instructions for the LLM on what to compare and flag. + :param llm_conn_id: Connection ID for the LLM provider. + :param model_id: Model identifier (e.g. ``"openai:gpt-5"``). + :param system_prompt: Additional instructions appended to the built-in + schema comparison prompt. + :param agent_params: Extra keyword arguments for the pydantic-ai ``Agent``. + :param data_sources: List of DataSourceConfig objects, one per system. + :param db_conn_ids: Connection IDs for databases to compare (used with + ``table_names``). + :param table_names: Tables to introspect from each ``db_conn_id``. + :param context_strategy: ``"basic"`` for column names and types only; + ``"full"`` to include primary keys, foreign keys, and indexes. + Default ``"full"``. + """ + + template_fields: Sequence[str] = ( + *LLMOperator.template_fields, + "data_sources", + "db_conn_ids", + "table_names", + "context_strategy", + ) + + def __init__( + self, + *, + data_sources: list[DataSourceConfig] | None = None, + db_conn_ids: list[str] | None = None, + table_names: list[str] | None = None, + context_strategy: Literal["basic", "full"] = "full", + system_prompt: str = DEFAULT_SYSTEM_PROMPT, + **kwargs: Any, + ) -> None: + kwargs.pop("output_type", None) + super().__init__(**kwargs) + self.data_sources = data_sources or [] + self.db_conn_ids = db_conn_ids or [] + self.table_names = table_names or [] + self.context_strategy = context_strategy + self.system_prompt = system_prompt + + if not self.data_sources and not self.db_conn_ids: + raise ValueError("Provide at least one of 'data_sources' or 'db_conn_ids'.") + + if self.db_conn_ids and not self.table_names: + raise ValueError("'table_names' is required when using 'db_conn_ids'.") + + total_sources = len(self.db_conn_ids) + len(self.data_sources) + if total_sources < 2: + raise ValueError( + "Provide at-least two combinations of 'db_conn_ids' and 'table_names' or 'data_sources' " + "to compare." + ) + + @staticmethod + def _get_db_hook(conn_id: str) -> DbApiHook: + """Resolve a connection ID to a DbApiHook.""" + from airflow.providers.common.sql.hooks.sql import DbApiHook + + connection = BaseHook.get_connection(conn_id) + hook = connection.get_hook() + if not isinstance(hook, DbApiHook): + raise ValueError( + f"Connection {conn_id!r} does not provide a DbApiHook. Got {type(hook).__name__}." + ) + return hook + + def _is_dbapi_connection(self, conn_id: str) -> bool: + """Check whether a connection resolves to a DbApiHook.""" + from airflow.providers.common.sql.hooks.sql import DbApiHook + + try: + connection = BaseHook.get_connection(conn_id) + hook = connection.get_hook() + return isinstance(hook, DbApiHook) + except (AirflowException, ValueError) as exc: + self.log.debug("Connection %s does not resolve to a DbApiHook: %s", conn_id, exc, exc_info=True) + return False + + def _introspect_db_schema(self, hook: DbApiHook, table_name: str) -> str: + """Introspect schema from a database connection via DbApiHook.""" + columns = hook.get_table_schema(table_name) + if not columns: + self.log.warning("Table %r returned no columns — it may not exist.", table_name) + return "" + + col_info = ", ".join(f"{c['name']} {c['type']}" for c in columns) + parts = [f"Columns: {col_info}"] + + if self.context_strategy == "full": + try: + pks = hook.dialect.get_primary_keys(table_name) + if pks: + parts.append(f"Primary Key: {', '.join(pks)}") + except NotImplementedError: + self.log.warning( + "primary key introspection not implemented for dialect %s", hook.dialect_name + ) + except Exception as ex: + self.log.warning("Could not retrieve PK %s", ex) Review Comment: updated. -- This is an automated message from the Apache Git Service. To respond to the message, please log on to GitHub and use the URL above to go to the specific comment. To unsubscribe, e-mail: [email protected] For queries about this service, please contact Infrastructure at: [email protected]
