cetingokhan commented on code in PR #62816:
URL: https://github.com/apache/airflow/pull/62816#discussion_r2900080636
##########
providers/common/ai/src/airflow/providers/common/ai/hooks/pydantic_ai.py:
##########
@@ -157,13 +194,266 @@ def test_connection(self) -> tuple[bool, str]:
"""
Test connection by resolving the model.
- Validates that the model string is valid, the provider package is
- installed, and the provider class can be instantiated. Does NOT make an
- LLM API call — that would be expensive, flaky, and fail for reasons
- unrelated to connectivity (quotas, billing, rate limits).
+ Validates that the model string is valid and the provider class can be
+ instantiated with the supplied credentials. Does NOT make an LLM API
+ call — that would be expensive and fail for reasons unrelated to
+ connectivity (quotas, billing, rate limits).
"""
try:
self.get_conn()
return True, "Model resolved successfully."
except Exception as e:
return False, str(e)
+
+ @classmethod
+ def for_connection(cls, conn_id: str, model_id: str | None = None) ->
PydanticAIHook:
+ """
+ Return the correct :class:`PydanticAIHook` subclass for *conn_id*.
+
+ Looks up the connection's ``conn_type`` in the registered hook map and
+ instantiates the matching subclass. Falls back to
+ :class:`PydanticAIHook` for unknown types.
+
+ :param conn_id: Airflow connection ID.
+ :param model_id: Optional model override forwarded to the hook.
+ """
+ conn = cls.get_connection(conn_id)
+ hook_cls = _CONN_TYPE_TO_HOOK.get(conn.conn_type or "", cls)
+ return hook_cls(llm_conn_id=conn_id, model_id=model_id)
+
+
class PydanticAIAzureHook(PydanticAIHook):
    """
    Hook for Azure OpenAI via pydantic-ai.

    Connection fields:
    - **password**: Azure API key
    - **host**: Azure endpoint (e.g. ``https://<resource>.openai.azure.com``)
    - **extra** JSON::

        {"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}

    :param llm_conn_id: Airflow connection ID.
    :param model_id: Model identifier, e.g. ``"azure:gpt-4o"``.
    """

    conn_type = "pydanticai_azure"
    hook_name = "Pydantic AI (Azure OpenAI)"

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour for the Airflow connection form."""
        placeholders = {
            "host": "https://<resource>.openai.azure.com",
            "extra": '{"model": "azure:gpt-4o", "api_version": "2024-07-01-preview"}',
        }
        return {
            "hidden_fields": ["schema", "port", "login"],
            "relabeling": {"password": "API Key", "host": "Azure Endpoint"},
            "placeholders": placeholders,
        }

    def _get_provider_kwargs(
        self,
        api_key: str | None,
        base_url: str | None,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        # Map the connection fields onto the Azure provider's constructor
        # arguments, skipping anything the user left empty.
        provider_kwargs: dict[str, Any] = {}
        if api_key:
            provider_kwargs["api_key"] = api_key
        if base_url:
            # The connection "host" is the Azure resource endpoint.
            provider_kwargs["azure_endpoint"] = base_url
        api_version = extra.get("api_version")
        if api_version:
            provider_kwargs["api_version"] = api_version
        return provider_kwargs
+
+
class PydanticAIBedrockHook(PydanticAIHook):
    """
    Hook for AWS Bedrock via pydantic-ai.

    Credential resolution order:

    1. IAM keys from ``extra`` (``aws_access_key_id`` + ``aws_secret_access_key``,
       optionally ``aws_session_token``).
    2. Bearer token in ``extra`` (``api_key``, maps to env
       ``AWS_BEARER_TOKEN_BEDROCK``).
    3. Environment-variable / instance-role chain (``AWS_PROFILE``, IAM role, …)
       when no explicit keys are provided.

    Connection fields:
    - **extra** JSON::

        {
            "model": "bedrock:us.anthropic.claude-opus-4-5",
            "region_name": "us-east-1",
            "aws_access_key_id": "AKIA...",
            "aws_secret_access_key": "...",
            "aws_session_token": "...",
            "profile_name": "my-aws-profile",
            "api_key": "bearer-token",
            "base_url": "https://custom-bedrock-endpoint",
            "aws_read_timeout": 60.0,
            "aws_connect_timeout": 10.0
        }

    Leave ``aws_access_key_id`` / ``aws_secret_access_key`` and ``api_key``
    empty to fall back to the default AWS credential chain.

    :param llm_conn_id: Airflow connection ID.
    :param model_id: Model identifier, e.g. ``"bedrock:us.anthropic.claude-opus-4-5"``.
    """

    conn_type = "pydanticai_bedrock"
    hook_name = "Pydantic AI (AWS Bedrock)"

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour for the Airflow connection form."""
        extra_placeholder = (
            '{"model": "bedrock:us.anthropic.claude-opus-4-5", '
            '"region_name": "us-east-1"}'
            " — leave aws_access_key_id empty for IAM role / env-var auth"
        )
        return {
            "hidden_fields": ["schema", "port", "login", "host", "password"],
            "relabeling": {},
            "placeholders": {"extra": extra_placeholder},
        }

    def _get_provider_kwargs(
        self,
        api_key: str | None,
        base_url: str | None,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        provider_kwargs: dict[str, Any] = {}
        # String options copied through verbatim when present.
        for option in (
            "aws_access_key_id",
            "aws_secret_access_key",
            "aws_session_token",
            "region_name",
            "profile_name",
            # Bearer-token auth (alternative to IAM key/secret).
            # Maps to AWS_BEARER_TOKEN_BEDROCK env var.
            "api_key",
            # Custom Bedrock runtime endpoint.
            "base_url",
        ):
            value = extra.get(option)
            if value is not None:
                provider_kwargs[option] = value
        # BedrockProvider expects float for timeout values; JSON integers
        # must be coerced.
        for timeout_option in ("aws_read_timeout", "aws_connect_timeout"):
            timeout_value = extra.get(timeout_option)
            if timeout_value is not None:
                provider_kwargs[timeout_option] = float(timeout_value)
        return provider_kwargs
+
+
class PydanticAIVertexHook(PydanticAIHook):
    """
    Hook for Google Vertex AI (or Generative Language API) via pydantic-ai.

    Credentials are resolved in order:

    1. ``service_account_file`` (path string) or ``service_account_info`` (JSON
       object) in ``extra`` — loaded into a ``google.auth.credentials.Credentials``
       object and passed as ``credentials`` to ``GoogleProvider``.
    2. ``api_key`` in ``extra`` — for Generative Language API (non-Vertex) or
       Vertex API-key auth.
    3. Application Default Credentials (``GOOGLE_APPLICATION_CREDENTIALS``,
       ``gcloud auth application-default login``, Workload Identity, …) when
       no explicit credentials are provided.

    Connection fields:
    - **extra** JSON::

        {
            "model": "google-vertex:gemini-2.0-flash",
            "project": "my-gcp-project",
            "location": "us-central1",
            "service_account_file": "/path/to/sa.json",
            "vertexai": true
        }

    Use ``"service_account_info"`` instead of ``"service_account_file"`` to
    embed the service-account JSON directly (as an object, not a string path).
    Setting both at the same time raises ``ValueError``.

    Set ``"vertexai": true`` to force Vertex AI mode when only ``api_key`` is
    provided. Omit ``vertexai`` for the Generative Language API (GLA).

    :param llm_conn_id: Airflow connection ID.
    :param model_id: Model identifier, e.g. ``"google-vertex:gemini-2.0-flash"``.
    """

    conn_type = "pydanticai_vertex"
    hook_name = "Pydantic AI (Google Vertex AI)"

    @staticmethod
    def get_ui_field_behaviour() -> dict[str, Any]:
        """Return custom field behaviour for the Airflow connection form."""
        return {
            "hidden_fields": ["schema", "port", "login", "host", "password"],
            "relabeling": {},
            "placeholders": {
                "extra": (
                    '{"model": "google-vertex:gemini-2.0-flash", '
                    '"project": "my-project", "location": "us-central1", "vertexai": true}'
                    " — add service_account_file (path) or service_account_info (object) for SA auth;"
                    " omit both to use Application Default Credentials"
                ),
            },
        }

    def _get_provider_kwargs(
        self,
        api_key: str | None,
        base_url: str | None,
        extra: dict[str, Any],
    ) -> dict[str, Any]:
        """
        Build ``GoogleProvider`` constructor kwargs from the connection extra.

        :param api_key: API key from the connection (Google auth reads
            ``api_key`` from ``extra`` instead; this parameter is unused).
        :param base_url: Base URL from the connection (read from ``extra``
            instead; this parameter is unused).
        :param extra: Parsed connection ``extra`` JSON.
        :raises ValueError: If both ``service_account_file`` and
            ``service_account_info`` are set.
        """
        sa_file = extra.get("service_account_file")
        sa_info = extra.get("service_account_info")
        if sa_file and sa_info:
            raise ValueError(
                "Specify 'service_account_file' or 'service_account_info' in the connection extra, not both."
            )

        kwargs: dict[str, Any] = {}

        # Direct GoogleProvider scalar kwargs.
        for _key in ("api_key", "project", "location", "base_url"):
            if extra.get(_key) is not None:
                kwargs[_key] = extra[_key]

        # Optional vertexai bool flag (force Vertex AI mode for API-key auth).
        _vertexai = extra.get("vertexai")
        if _vertexai is not None:
            kwargs["vertexai"] = bool(_vertexai)

        # Service-account credentials — loaded lazily to avoid importing
        # google-auth on non-Vertex code paths (optional heavy dependency).
        if sa_file or sa_info:
            from google.oauth2 import service_account  # lazy: optional dep

            scopes = ["https://www.googleapis.com/auth/cloud-platform"]
            if sa_file:
                kwargs["credentials"] = service_account.Credentials.from_service_account_file(
                    sa_file, scopes=scopes
                )
            else:
                kwargs["credentials"] = service_account.Credentials.from_service_account_info(
                    sa_info, scopes=scopes
                )

        return kwargs
+
+
+# ---------------------------------------------------------------------------
+# Hook registry — maps conn_type → hook class for use by for_connection()
+# ---------------------------------------------------------------------------
+_CONN_TYPE_TO_HOOK: dict[str, type[PydanticAIHook]] = {
Review Comment:
During development, I hadn't defined the hooks in the provider.yaml file, so
the tests couldn't discover them. That's why I ended up adding the
for_connection() method as a workaround.
I didn't know ProviderManager handled it automatically under the hood like
that—thanks for the heads-up! :)
I've updated it as you suggested: adding them to the yaml file made
_CONN_TYPE_TO_HOOK and for_connection() completely unnecessary, so I've removed
them. ;)
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]