This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new c57be2ffc5 Add ability to provide proxy for dbt Cloud connection (#42737) c57be2ffc5 is described below commit c57be2ffc533a3002c0c05d9045c77717c8a0e36 Author: Benoit Perigaud <8754100+b-...@users.noreply.github.com> AuthorDate: Tue Oct 8 20:31:25 2024 +0200 Add ability to provide proxy for dbt Cloud connection (#42737) * Add ability to provide proxy for dbt Cloud connection * Running pre-commit checks * Update current tests and add new test with proxy --- airflow/providers/dbt/cloud/hooks/dbt.py | 56 +++++++++---- .../connections.rst | 14 ++++ tests/providers/dbt/cloud/hooks/test_dbt.py | 95 ++++++++++++++++++---- 3 files changed, 134 insertions(+), 31 deletions(-) diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py b/airflow/providers/dbt/cloud/hooks/dbt.py index acffb47716..b13e1003b9 100644 --- a/airflow/providers/dbt/cloud/hooks/dbt.py +++ b/airflow/providers/dbt/cloud/hooks/dbt.py @@ -26,7 +26,6 @@ from inspect import signature from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, TypeVar, cast import aiohttp -from aiohttp import ClientResponseError from asgiref.sync import sync_to_async from requests.auth import AuthBase from requests.sessions import Session @@ -182,9 +181,12 @@ class DbtCloudHook(HttpHook): def get_ui_field_behaviour(cls) -> dict[str, Any]: """Build custom field behavior for the dbt Cloud connection form in the Airflow UI.""" return { - "hidden_fields": ["schema", "port", "extra"], + "hidden_fields": ["schema", "port"], "relabeling": {"login": "Account ID", "password": "API Token", "host": "Tenant"}, - "placeholders": {"host": "Defaults to 'cloud.getdbt.com'."}, + "placeholders": { + "host": "Defaults to 'cloud.getdbt.com'.", + "extra": "Optional JSON-formatted extra.", + }, } def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, **kwargs) -> None: @@ -195,6 +197,10 @@ class DbtCloudHook(HttpHook): def _get_tenant_domain(conn: Connection) -> 
str: return conn.host or "cloud.getdbt.com" + @staticmethod + def _get_proxies(conn: Connection) -> dict[str, str] | None: + return conn.extra_dejson.get("proxies", None) + @staticmethod def get_request_url_params( tenant: str, endpoint: str, include_related: list[str] | None = None, *, api_version: str = "v2" @@ -238,14 +244,26 @@ class DbtCloudHook(HttpHook): endpoint = f"{account_id}/runs/{run_id}/" headers, tenant = await self.get_headers_tenants_from_connection() url, params = self.get_request_url_params(tenant, endpoint, include_related) - async with aiohttp.ClientSession(headers=headers) as session, session.get( - url, params=params - ) as response: - try: - response.raise_for_status() - return await response.json() - except ClientResponseError as e: - raise AirflowException(f"{e.status}:{e.message}") + proxies = self._get_proxies(self.connection) + async with aiohttp.ClientSession(headers=headers) as session: + if proxies is not None: + if url.startswith("https"): + proxy = proxies.get("https") + else: + proxy = proxies.get("http") + async with session.get(url, params=params, proxy=proxy) as response: + try: + response.raise_for_status() + return await response.json() + except aiohttp.ClientResponseError as e: + raise AirflowException(f"{e.status}:{e.message}") + else: + async with session.get(url, params=params) as response: + try: + response.raise_for_status() + return await response.json() + except aiohttp.ClientResponseError as e: + raise AirflowException(f"{e.status}:{e.message}") async def get_job_status( self, run_id: int, account_id: int | None = None, include_related: list[str] | None = None @@ -280,8 +298,11 @@ class DbtCloudHook(HttpHook): return session - def _paginate(self, endpoint: str, payload: dict[str, Any] | None = None) -> list[Response]: - response = self.run(endpoint=endpoint, data=payload) + def _paginate( + self, endpoint: str, payload: dict[str, Any] | None = None, proxies: dict[str, str] | None = None + ) -> list[Response]: + 
extra_options = {"proxies": proxies} if proxies is not None else None + response = self.run(endpoint=endpoint, data=payload, extra_options=extra_options) resp_json = response.json() limit = resp_json["extra"]["filters"]["limit"] num_total_results = resp_json["extra"]["pagination"]["total_count"] @@ -292,7 +313,7 @@ class DbtCloudHook(HttpHook): _paginate_payload["offset"] = limit while num_current_results < num_total_results: - response = self.run(endpoint=endpoint, data=_paginate_payload) + response = self.run(endpoint=endpoint, data=_paginate_payload, extra_options=extra_options) resp_json = response.json() results.append(response) num_current_results += resp_json["extra"]["pagination"]["count"] @@ -310,17 +331,20 @@ class DbtCloudHook(HttpHook): ) -> Any: self.method = method full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint else None + proxies = self._get_proxies(self.connection) + extra_options = {"proxies": proxies} if proxies is not None else None if paginate: if isinstance(payload, str): raise ValueError("Payload cannot be a string to paginate a response.") if full_endpoint: - return self._paginate(endpoint=full_endpoint, payload=payload) + return self._paginate(endpoint=full_endpoint, payload=payload, proxies=proxies) raise ValueError("An endpoint is needed to paginate a response.") - return self.run(endpoint=full_endpoint, data=payload) + # breakpoint() + return self.run(endpoint=full_endpoint, data=payload, extra_options=extra_options) def list_accounts(self) -> list[Response]: """ diff --git a/docs/apache-airflow-providers-dbt-cloud/connections.rst b/docs/apache-airflow-providers-dbt-cloud/connections.rst index f3514cc83a..428c15d12e 100644 --- a/docs/apache-airflow-providers-dbt-cloud/connections.rst +++ b/docs/apache-airflow-providers-dbt-cloud/connections.rst @@ -77,6 +77,20 @@ Host (optional) If using the Connection form in the Airflow UI, the Tenant domain can also be stored in the "Tenant" field. 
+Extra (optional) + Specify extra parameters as a JSON dictionary. Currently, only ``proxies`` is supported; use it to connect to dbt Cloud through a proxy. + + ``proxies`` should be a dictionary mapping a URL scheme (``http`` or ``https``) to the proxy URL used for requests made with that scheme. + +.. code-block:: json + + { + "proxies": { + "http": "http://myproxy.mycompany.local:8080", + "https": "http://myproxy.mycompany.local:8080" + } + } + When specifying the connection as an environment variable, you should specify it following the standard syntax of a database connection. Note that all components of the URI should be URL-encoded. diff --git a/tests/providers/dbt/cloud/hooks/test_dbt.py b/tests/providers/dbt/cloud/hooks/test_dbt.py index 71ef75ba31..0d84189bc8 100644 --- a/tests/providers/dbt/cloud/hooks/test_dbt.py +++ b/tests/providers/dbt/cloud/hooks/test_dbt.py @@ -40,9 +40,11 @@ pytestmark = pytest.mark.db_test ACCOUNT_ID_CONN = "account_id_conn" NO_ACCOUNT_ID_CONN = "no_account_id_conn" SINGLE_TENANT_CONN = "single_tenant_conn" +PROXY_CONN = "proxy_conn" DEFAULT_ACCOUNT_ID = 11111 ACCOUNT_ID = 22222 SINGLE_TENANT_DOMAIN = "single.tenant.getdbt.com" +EXTRA_PROXIES = {"proxies": {"https": "http://myproxy:1234"}} TOKEN = "token" PROJECT_ID = 33333 JOB_ID = 4444 @@ -136,9 +138,20 @@ class TestDbtCloudHook: host=SINGLE_TENANT_DOMAIN, ) + # Connection with a proxy set in extra parameters + proxy_conn = Connection( + conn_id=PROXY_CONN, + conn_type=DbtCloudHook.conn_type, + login=DEFAULT_ACCOUNT_ID, + password=TOKEN, + host=SINGLE_TENANT_DOMAIN, + extra=EXTRA_PROXIES, + ) + db.merge_conn(account_id_conn) db.merge_conn(no_account_id_conn) db.merge_conn(host_conn) + db.merge_conn(proxy_conn) @pytest.mark.parametrize( argnames="conn_id, url", @@ -196,7 +209,7 @@ class TestDbtCloudHook: hook.list_accounts() assert hook.method == "GET" - hook.run.assert_called_once_with(endpoint=None, data=None) + hook.run.assert_called_once_with(endpoint=None, data=None, extra_options=None) hook._paginate.assert_not_called()
@pytest.mark.parametrize( @@ -213,7 +226,9 @@ class TestDbtCloudHook: assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/", data=None) + hook.run.assert_called_once_with( + endpoint=f"api/v2/accounts/{_account_id}/", data=None, extra_options=None + ) hook._paginate.assert_not_called() @pytest.mark.parametrize( @@ -232,7 +247,7 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_not_called() hook._paginate.assert_called_once_with( - endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None + endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None, proxies=None ) @pytest.mark.parametrize( @@ -250,7 +265,7 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", data=None + endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", data=None, extra_options=None ) hook._paginate.assert_not_called() @@ -269,7 +284,9 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook._paginate.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/jobs/", payload={"order_by": None, "project_id": None} + endpoint=f"api/v2/accounts/{_account_id}/jobs/", + payload={"order_by": None, "project_id": None}, + proxies=None, ) hook.run.assert_not_called() @@ -290,6 +307,7 @@ class TestDbtCloudHook: hook._paginate.assert_called_once_with( endpoint=f"api/v2/accounts/{_account_id}/jobs/", payload={"order_by": "-id", "project_id": PROJECT_ID}, + proxies=None, ) hook.run.assert_not_called() @@ -307,7 +325,9 @@ class TestDbtCloudHook: assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}", data=None) + hook.run.assert_called_once_with( + 
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}", data=None, extra_options=None + ) hook._paginate.assert_not_called() @pytest.mark.parametrize( @@ -328,6 +348,7 @@ class TestDbtCloudHook: hook.run.assert_called_once_with( endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/", data=json.dumps({"cause": cause, "steps_override": None, "schema_override": None}), + extra_options=None, ) hook._paginate.assert_not_called() @@ -359,6 +380,7 @@ class TestDbtCloudHook: data=json.dumps( {"cause": cause, "steps_override": steps_override, "schema_override": schema_override} ), + extra_options=None, ) hook._paginate.assert_not_called() @@ -393,6 +415,7 @@ class TestDbtCloudHook: "generate_docs_override": False, } ), + extra_options=None, ) hook._paginate.assert_not_called() @@ -422,6 +445,7 @@ class TestDbtCloudHook: hook.run.assert_called_once_with( endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/", data=json.dumps({"cause": expected_cause, "steps_override": None, "schema_override": None}), + extra_options=None, ) hook._paginate.assert_not_called() @@ -467,7 +491,9 @@ class TestDbtCloudHook: hook._paginate.assert_not_called() if should_use_rerun: hook.run.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/", data=None + endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/", + data=None, + extra_options=None, ) else: hook.run.assert_called_once_with( @@ -479,8 +505,31 @@ class TestDbtCloudHook: "schema_override": None, } ), + extra_options=None, ) + @pytest.mark.parametrize( + argnames="conn_id, account_id", + argvalues=[(PROXY_CONN, ACCOUNT_ID)], + ids=["proxy_connection"], + ) + @patch.object(DbtCloudHook, "run") + @patch.object(DbtCloudHook, "_paginate") + def test_trigger_job_run_with_proxy(self, mock_http_run, mock_paginate, conn_id, account_id): + hook = DbtCloudHook(conn_id) + cause = "" + hook.trigger_job_run(job_id=JOB_ID, cause=cause, account_id=account_id) + + assert hook.method == "POST" + + 
_account_id = account_id or DEFAULT_ACCOUNT_ID + hook.run.assert_called_once_with( + endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/", + data=json.dumps({"cause": cause, "steps_override": None, "schema_override": None}), + extra_options={"proxies": {"https": "http://myproxy:1234"}}, + ) + hook._paginate.assert_not_called() + @pytest.mark.parametrize( argnames="conn_id, account_id", argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)], @@ -503,6 +552,7 @@ class TestDbtCloudHook: "job_definition_id": None, "order_by": None, }, + proxies=None, ) @pytest.mark.parametrize( @@ -529,6 +579,7 @@ class TestDbtCloudHook: "job_definition_id": JOB_ID, "order_by": "id", }, + proxies=None, ) @pytest.mark.parametrize( @@ -544,7 +595,9 @@ class TestDbtCloudHook: assert hook.method == "GET" _account_id = account_id or DEFAULT_ACCOUNT_ID - hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/runs/", data=None) + hook.run.assert_called_once_with( + endpoint=f"api/v2/accounts/{_account_id}/runs/", data=None, extra_options=None + ) @pytest.mark.parametrize( argnames="conn_id, account_id", @@ -561,7 +614,9 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": None} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", + data={"include_related": None}, + extra_options=None, ) hook._paginate.assert_not_called() @@ -580,7 +635,9 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", data={"include_related": ["triggers"]} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", + data={"include_related": ["triggers"]}, + extra_options=None, ) hook._paginate.assert_not_called() @@ -645,7 +702,7 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID 
hook.run.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", data=None + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", data=None, extra_options=None ) hook._paginate.assert_not_called() @@ -664,7 +721,9 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": None} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", + data={"step": None}, + extra_options=None, ) hook._paginate.assert_not_called() @@ -683,7 +742,9 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", data={"step": 2} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", + data={"step": 2}, + extra_options=None, ) hook._paginate.assert_not_called() @@ -703,7 +764,9 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": None} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", + data={"step": None}, + extra_options=None, ) hook._paginate.assert_not_called() @@ -723,7 +786,9 @@ class TestDbtCloudHook: _account_id = account_id or DEFAULT_ACCOUNT_ID hook.run.assert_called_once_with( - endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", data={"step": 2} + endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", + data={"step": 2}, + extra_options=None, ) hook._paginate.assert_not_called()