kaxil commented on code in PR #62816:
URL: https://github.com/apache/airflow/pull/62816#discussion_r2908398829


##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -157,13 +190,238 @@ def test_connection(self) -> tuple[bool, str]:
         """
         Test connection by resolving the model.
 
-        Validates that the model string is valid, the provider package is
-        installed, and the provider class can be instantiated. Does NOT make an
-        LLM API call — that would be expensive, flaky, and fail for reasons
-        unrelated to connectivity (quotas, billing, rate limits).
+        Validates that the model string is valid and the provider class can be
+        instantiated with the supplied credentials.  Does NOT make an LLM API
+        call — that would be expensive and fail for reasons unrelated to
+        connectivity (quotas, billing, rate limits).
         """
         try:
             self.get_conn()
             return True, "Model resolved successfully."
         except Exception as e:
             return False, str(e)
+
+
+class PydanticAIAzureHook(PydanticAIHook):
+    """
+    Hook for Azure OpenAI via pydantic-ai.
+
+    Connection fields:
+        - **password**: Azure API key
+        - **host**: Azure endpoint (e.g. 
``https://<resource>.openai.azure.com``)
+        - **extra** JSON::
+
+            {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}
+
+    :param llm_conn_id: Airflow connection ID.
+    :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
+    """
+
+    conn_type = "pydanticai_azure"
+    default_conn_name = "pydanticai_azure_default"
+    hook_name = "Pydantic AI (Azure OpenAI)"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour for the Airflow connection form."""
+        return {
+            "hidden_fields": ["schema", "port", "login"],
+            "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
+            "placeholders": {
+                "host": "https://<resource>.openai.azure.com",
+                "extra": '{"model": "azure:gpt-4o", "api_version": 
"2024-07-01-preview"}',
+            },
+        }
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        kwargs: dict[str, Any] = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if base_url:
+            kwargs["azure_endpoint"] = base_url
+        if extra.get("api_version"):
+            kwargs["api_version"] = extra["api_version"]
+        return kwargs
+
+
+class PydanticAIBedrockHook(PydanticAIHook):
+    """
+    Hook for AWS Bedrock via pydantic-ai.
+
+    Credentials are resolved in order:
+
+    1. IAM keys from ``extra`` (``aws_access_key_id`` + 
``aws_secret_access_key``,
+       optionally ``aws_session_token``).
+    2. Bearer token in ``extra`` (``api_key``, maps to env 
``AWS_BEARER_TOKEN_BEDROCK``).
+    3. Environment-variable / instance-role chain (``AWS_PROFILE``, IAM role, 
…)
+       when no explicit keys are provided.
+
+    Connection fields:
+        - **extra** JSON::
+
+            {
+              "model": "bedrock:us.anthropic.claude-opus-4-5",
+              "region_name": "us-east-1",
+              "aws_access_key_id": "AKIA...",
+              "aws_secret_access_key": "...",
+              "aws_session_token": "...",
+              "profile_name": "my-aws-profile",
+              "api_key": "bearer-token",
+              "base_url": "https://custom-bedrock-endpoint",
+              "aws_read_timeout": 60.0,
+              "aws_connect_timeout": 10.0
+            }
+
+          Leave ``aws_access_key_id`` / ``aws_secret_access_key`` and 
``api_key``
+          empty to use the default AWS credential chain.
+
+    :param llm_conn_id: Airflow connection ID.
+    :param model_id: Model identifier, e.g. 
``"bedrock:us.anthropic.claude-opus-4-5"``.
+    """
+
+    conn_type = "pydanticai_bedrock"
+    default_conn_name = "pydanticai_bedrock_default"
+    hook_name = "Pydantic AI (AWS Bedrock)"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour for the Airflow connection form."""
+        return {
+            "hidden_fields": ["schema", "port", "login", "host", "password"],
+            "relabeling": {},
+            "placeholders": {
+                "extra": (
+                    '{"model": "bedrock:us.anthropic.claude-opus-4-5", '
+                    '"region_name": "us-east-1"}'
+                    "  — leave aws_access_key_id empty for IAM role / env-var 
auth"
+                ),
+            },
+        }
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Return kwargs for ``BedrockProvider``.
+
+        .. note::
+            The ``api_key`` and ``base_url`` parameters (sourced from
+            ``conn.password`` and ``conn.host``) are intentionally ignored 
here.
+            Bedrock connections hide those fields in the UI; all config is
+            stored in ``extra`` instead.  The ``api_key`` and ``base_url``
+            keys below refer to *extra* fields, not the method parameters.
+        """
+        _str_keys = (
+            "aws_access_key_id",
+            "aws_secret_access_key",
+            "aws_session_token",
+            "region_name",
+            "profile_name",
+            # Bearer-token auth (alternative to IAM key/secret).
+            # Maps to AWS_BEARER_TOKEN_BEDROCK env var.
+            "api_key",
+            # Custom Bedrock runtime endpoint.
+            "base_url",
+        )
+        kwargs: dict[str, Any] = {k: extra[k] for k in _str_keys if 
extra.get(k) is not None}

Review Comment:
   The base class uses truthiness checks to filter empty values (`if api_key:`, 
`if base_url:`), but here `is not None` lets empty strings through. If a user 
clears a conn-field in the UI, `aws_access_key_id=""` would be passed to 
`BedrockProvider`, causing confusing auth failures instead of falling through 
to environment-based auth.
   
   Same pattern appears in `PydanticAIVertexHook._get_provider_kwargs` (line 
~409).
   
   Suggestion: use truthiness to match the base class:
   ```python
   kwargs = {k: extra[k] for k in _str_keys if extra.get(k)}
   ```



##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -75,62 +78,92 @@ def get_ui_field_behaviour() -> dict[str, Any]:
             "hidden_fields": ["schema", "port", "login"],
             "relabeling": {"password": "API Key"},
             "placeholders": {
-                "host": "https://api.openai.com/v1 (optional, for custom 
endpoints)",
+                "host": "https://api.openai.com/v1  (optional, for custom 
endpoints / Ollama)",
+                "extra": '{"model": "openai:gpt-5.3"}',
             },
         }
 
+    # ------------------------------------------------------------------
+    # Core connection / agent API
+    # ------------------------------------------------------------------
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Return the kwargs to pass to the provider constructor.
+
+        Subclasses override this method to map their connection fields to the
+        parameters expected by their specific provider class.  The base
+        implementation handles the common ``api_key`` / ``base_url`` pattern
+        used by OpenAI, Anthropic, Groq, Mistral, Ollama, and most other
+        providers.
+
+        :param api_key: Value of ``conn.password``.
+        :param base_url: Value of ``conn.host``.
+        :param extra: Deserialized ``conn.extra`` JSON.
+        :return: Kwargs forwarded to ``provider_cls(**kwargs)``.  Empty dict
+            signals that no explicit credentials are available and the hook
+            should fall back to environment-variable–based auth.
+        """
+        kwargs: dict[str, Any] = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if base_url:
+            kwargs["base_url"] = base_url
+        return kwargs
+
     def get_conn(self) -> Model:
         """
-        Return a configured pydantic-ai Model.
+        Return a configured pydantic-ai ``Model``.
 
-        Reads API key from connection password, base_url from connection host,
-        and model from (in priority order):
+        Resolution order:
 
-        1. ``model_id`` parameter on the hook
-        2. ``extra["model"]`` on the connection (set by the "Model" conn-field 
in the UI)
+        1. **Explicit credentials** — when :meth:`_get_provider_kwargs` returns
+           a non-empty dict the provider class is instantiated with those 
kwargs
+           and wrapped in a ``provider_factory``.
+        2. **Default resolution** — delegates to pydantic-ai ``infer_model``
+           which reads standard env vars (``OPENAI_API_KEY``, ``AWS_PROFILE``, 
…).
 
-        The result is cached for the lifetime of this hook instance.
+        The resolved model is cached for the lifetime of this hook instance.
         """
         if self._model is not None:
             return self._model
 
         conn = self.get_connection(self.llm_conn_id)
-        model_name: str | KnownModelName = self.model_id or 
conn.extra_dejson.get("model", "")
+
+        extra: dict[str, Any] = conn.extra_dejson
+        model_name: str | KnownModelName = self.model_id or extra.get("model", 
"")
         if not model_name:
             raise ValueError(
                 "No model specified. Set model_id on the hook or the Model 
field on the connection."
             )
-        api_key = conn.password
-        base_url = conn.host or None
 
-        if not api_key and not base_url:
-            # No credentials to inject — use default provider resolution
-            # (picks up env vars like OPENAI_API_KEY, AWS_PROFILE, etc.)
-            self._model = infer_model(model_name)
+        api_key: str | None = conn.password or None
+        base_url: str | None = conn.host or None
+
+        provider_kwargs = self._get_provider_kwargs(api_key, base_url, extra)
+        if provider_kwargs:
+            _kwargs = provider_kwargs  # capture for closure
+            self.log.info(
+                "Using explicit credentials for provider with model '%s': %s",
+                model_name,
+                list(provider_kwargs),
+            )
+
+            def _provider_factory(pname: str) -> Any:
+                try:
+                    return infer_provider_class(pname)(**_kwargs)
+                except TypeError:

Review Comment:
   This silently discards the explicit credentials from `_get_provider_kwargs` 
and falls back to env-var auth. If a subclass carefully builds provider kwargs 
but the provider class's constructor signature changes, the user would silently 
get env-var auth instead of the credentials they configured — hard to debug.
   
   Consider logging a warning here:
   ```python
   except TypeError:
       self.log.warning(
           "Provider '%s' rejected kwargs %s; falling back to env-var auth",
           pname, list(_kwargs),
       )
       return infer_provider(pname)
   ```
   
   (The closure can capture `self` — just define `_provider_factory` as a 
regular closure instead of standalone, or pass a logger reference.)



##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -157,13 +190,238 @@ def test_connection(self) -> tuple[bool, str]:
         """
         Test connection by resolving the model.
 
-        Validates that the model string is valid, the provider package is
-        installed, and the provider class can be instantiated. Does NOT make an
-        LLM API call — that would be expensive, flaky, and fail for reasons
-        unrelated to connectivity (quotas, billing, rate limits).
+        Validates that the model string is valid and the provider class can be
+        instantiated with the supplied credentials.  Does NOT make an LLM API
+        call — that would be expensive and fail for reasons unrelated to
+        connectivity (quotas, billing, rate limits).
         """
         try:
             self.get_conn()
             return True, "Model resolved successfully."
         except Exception as e:
             return False, str(e)
+
+
+class PydanticAIAzureHook(PydanticAIHook):
+    """
+    Hook for Azure OpenAI via pydantic-ai.
+
+    Connection fields:
+        - **password**: Azure API key
+        - **host**: Azure endpoint (e.g. 
``https://<resource>.openai.azure.com``)
+        - **extra** JSON::
+
+            {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}
+
+    :param llm_conn_id: Airflow connection ID.
+    :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
+    """
+
+    conn_type = "pydanticai_azure"
+    default_conn_name = "pydanticai_azure_default"
+    hook_name = "Pydantic AI (Azure OpenAI)"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour for the Airflow connection form."""
+        return {
+            "hidden_fields": ["schema", "port", "login"],
+            "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
+            "placeholders": {
+                "host": "https://<resource>.openai.azure.com",
+                "extra": '{"model": "azure:gpt-4o", "api_version": 
"2024-07-01-preview"}',
+            },
+        }
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        kwargs: dict[str, Any] = {}
+        if api_key:
+            kwargs["api_key"] = api_key
+        if base_url:
+            kwargs["azure_endpoint"] = base_url
+        if extra.get("api_version"):
+            kwargs["api_version"] = extra["api_version"]
+        return kwargs
+
+
+class PydanticAIBedrockHook(PydanticAIHook):
+    """
+    Hook for AWS Bedrock via pydantic-ai.
+
+    Credentials are resolved in order:
+
+    1. IAM keys from ``extra`` (``aws_access_key_id`` + 
``aws_secret_access_key``,
+       optionally ``aws_session_token``).
+    2. Bearer token in ``extra`` (``api_key``, maps to env 
``AWS_BEARER_TOKEN_BEDROCK``).
+    3. Environment-variable / instance-role chain (``AWS_PROFILE``, IAM role, 
…)
+       when no explicit keys are provided.
+
+    Connection fields:
+        - **extra** JSON::
+
+            {
+              "model": "bedrock:us.anthropic.claude-opus-4-5",
+              "region_name": "us-east-1",
+              "aws_access_key_id": "AKIA...",
+              "aws_secret_access_key": "...",
+              "aws_session_token": "...",
+              "profile_name": "my-aws-profile",
+              "api_key": "bearer-token",
+              "base_url": "https://custom-bedrock-endpoint",
+              "aws_read_timeout": 60.0,
+              "aws_connect_timeout": 10.0
+            }
+
+          Leave ``aws_access_key_id`` / ``aws_secret_access_key`` and 
``api_key``
+          empty to use the default AWS credential chain.
+
+    :param llm_conn_id: Airflow connection ID.
+    :param model_id: Model identifier, e.g. 
``"bedrock:us.anthropic.claude-opus-4-5"``.
+    """
+
+    conn_type = "pydanticai_bedrock"
+    default_conn_name = "pydanticai_bedrock_default"
+    hook_name = "Pydantic AI (AWS Bedrock)"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour for the Airflow connection form."""
+        return {
+            "hidden_fields": ["schema", "port", "login", "host", "password"],
+            "relabeling": {},
+            "placeholders": {
+                "extra": (
+                    '{"model": "bedrock:us.anthropic.claude-opus-4-5", '
+                    '"region_name": "us-east-1"}'
+                    "  — leave aws_access_key_id empty for IAM role / env-var 
auth"
+                ),
+            },
+        }
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        """
+        Return kwargs for ``BedrockProvider``.
+
+        .. note::
+            The ``api_key`` and ``base_url`` parameters (sourced from
+            ``conn.password`` and ``conn.host``) are intentionally ignored 
here.
+            Bedrock connections hide those fields in the UI; all config is
+            stored in ``extra`` instead.  The ``api_key`` and ``base_url``
+            keys below refer to *extra* fields, not the method parameters.
+        """
+        _str_keys = (
+            "aws_access_key_id",
+            "aws_secret_access_key",
+            "aws_session_token",
+            "region_name",
+            "profile_name",
+            # Bearer-token auth (alternative to IAM key/secret).
+            # Maps to AWS_BEARER_TOKEN_BEDROCK env var.
+            "api_key",
+            # Custom Bedrock runtime endpoint.
+            "base_url",
+        )
+        kwargs: dict[str, Any] = {k: extra[k] for k in _str_keys if 
extra.get(k) is not None}
+        # BedrockProvider expects float for timeout values; JSON integers must 
be coerced.
+        for _timeout_key in ("aws_read_timeout", "aws_connect_timeout"):
+            if extra.get(_timeout_key) is not None:
+                kwargs[_timeout_key] = float(extra[_timeout_key])
+        return kwargs
+
+
+class PydanticAIVertexHook(PydanticAIHook):
+    """
+    Hook for Google Vertex AI (or Generative Language API) via pydantic-ai.
+
+    Credentials are resolved in order:
+
+    1. ``service_account_info`` (JSON object) in ``extra``
+       — loaded into a ``google.auth.credentials.Credentials``
+       object and passed as ``credentials`` to ``GoogleProvider``.
+    2. ``api_key`` in ``extra`` — for Generative Language API (non-Vertex) or
+       Vertex API-key auth.
+    3. Application Default Credentials (``GOOGLE_APPLICATION_CREDENTIALS``,
+       ``gcloud auth application-default login``, Workload Identity, …) when
+       no explicit credentials are provided.
+
+    Connection fields:
+        - **extra** JSON::
+
+            {
+                "model": "google-vertex:gemini-2.0-flash",
+                "project": "my-gcp-project",
+                "location": "us-central1",
+                "service_account_info": {...},
+                "vertexai": true,
+            }
+
+        Use ``"service_account_info"`` to embed the service-account JSON 
directly
+        (as an object, not a string path).
+
+        Set ``"vertexai": true`` to force Vertex AI mode when only ``api_key`` 
is
+        provided.  Omit ``vertexai`` for the Generative Language API (GLA).
+
+    :param llm_conn_id: Airflow connection ID.
+    :param model_id: Model identifier, e.g. 
``"google-vertex:gemini-2.0-flash"``.
+    """
+
+    conn_type = "pydanticai_vertex"
+    default_conn_name = "pydanticai_vertex_default"
+    hook_name = "Pydantic AI (Google Vertex AI)"
+
+    @staticmethod
+    def get_ui_field_behaviour() -> dict[str, Any]:
+        """Return custom field behaviour for the Airflow connection form."""
+        return {
+            "hidden_fields": ["schema", "port", "login", "host", "password"],
+            "relabeling": {},
+            "placeholders": {
+                "extra": (
+                    '{"model": "google-vertex:gemini-2.0-flash", '
+                    '"project": "my-project", "location": "us-central1", 
"vertexai": true}'
+                    "  — add service_account_info (object) for SA auth;"
+                    " omit both to use Application Default Credentials"
+                ),
+            },
+        }
+
+    def _get_provider_kwargs(
+        self,
+        api_key: str | None,
+        base_url: str | None,
+        extra: dict[str, Any],
+    ) -> dict[str, Any]:
+        sa_info = extra.get("service_account_info")
+        kwargs: dict[str, Any] = {}
+
+        # Direct GoogleProvider scalar kwargs.
+        for _key in ("api_key", "project", "location", "base_url"):
+            if extra.get(_key) is not None:

Review Comment:
   Same empty-string concern as the Bedrock hook — `extra.get(_key) is not 
None` lets `""` through. Using `if extra.get(_key):` would be consistent with 
the base class's truthiness pattern.



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to