This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new c57be2ffc5 Add ability to provide proxy for dbt Cloud connection 
(#42737)
c57be2ffc5 is described below

commit c57be2ffc533a3002c0c05d9045c77717c8a0e36
Author: Benoit Perigaud <8754100+b-...@users.noreply.github.com>
AuthorDate: Tue Oct 8 20:31:25 2024 +0200

    Add ability to provide proxy for dbt Cloud connection (#42737)
    
    * Add ability to provide proxy for dbt Cloud connection
    
    * Running pre-commit checks
    
    * Update current tests and add new test with proxy
---
 airflow/providers/dbt/cloud/hooks/dbt.py           | 56 +++++++++----
 .../connections.rst                                | 14 ++++
 tests/providers/dbt/cloud/hooks/test_dbt.py        | 95 ++++++++++++++++++----
 3 files changed, 134 insertions(+), 31 deletions(-)

diff --git a/airflow/providers/dbt/cloud/hooks/dbt.py 
b/airflow/providers/dbt/cloud/hooks/dbt.py
index acffb47716..b13e1003b9 100644
--- a/airflow/providers/dbt/cloud/hooks/dbt.py
+++ b/airflow/providers/dbt/cloud/hooks/dbt.py
@@ -26,7 +26,6 @@ from inspect import signature
 from typing import TYPE_CHECKING, Any, Callable, Sequence, Set, TypeVar, cast
 
 import aiohttp
-from aiohttp import ClientResponseError
 from asgiref.sync import sync_to_async
 from requests.auth import AuthBase
 from requests.sessions import Session
@@ -182,9 +181,12 @@ class DbtCloudHook(HttpHook):
     def get_ui_field_behaviour(cls) -> dict[str, Any]:
         """Build custom field behavior for the dbt Cloud connection form in 
the Airflow UI."""
         return {
-            "hidden_fields": ["schema", "port", "extra"],
+            "hidden_fields": ["schema", "port"],
             "relabeling": {"login": "Account ID", "password": "API Token", 
"host": "Tenant"},
-            "placeholders": {"host": "Defaults to 'cloud.getdbt.com'."},
+            "placeholders": {
+                "host": "Defaults to 'cloud.getdbt.com'.",
+                "extra": "Optional JSON-formatted extra.",
+            },
         }
 
     def __init__(self, dbt_cloud_conn_id: str = default_conn_name, *args, 
**kwargs) -> None:
@@ -195,6 +197,10 @@ class DbtCloudHook(HttpHook):
     def _get_tenant_domain(conn: Connection) -> str:
         return conn.host or "cloud.getdbt.com"
 
+    @staticmethod
+    def _get_proxies(conn: Connection) -> dict[str, str] | None:
+        return conn.extra_dejson.get("proxies", None)
+
     @staticmethod
     def get_request_url_params(
         tenant: str, endpoint: str, include_related: list[str] | None = None, 
*, api_version: str = "v2"
@@ -238,14 +244,26 @@ class DbtCloudHook(HttpHook):
         endpoint = f"{account_id}/runs/{run_id}/"
         headers, tenant = await self.get_headers_tenants_from_connection()
         url, params = self.get_request_url_params(tenant, endpoint, 
include_related)
-        async with aiohttp.ClientSession(headers=headers) as session, 
session.get(
-            url, params=params
-        ) as response:
-            try:
-                response.raise_for_status()
-                return await response.json()
-            except ClientResponseError as e:
-                raise AirflowException(f"{e.status}:{e.message}")
+        proxies = self._get_proxies(self.connection)
+        async with aiohttp.ClientSession(headers=headers) as session:
+            if proxies is not None:
+                if url.startswith("https"):
+                    proxy = proxies.get("https")
+                else:
+                    proxy = proxies.get("http")
+                async with session.get(url, params=params, proxy=proxy) as 
response:
+                    try:
+                        response.raise_for_status()
+                        return await response.json()
+                    except aiohttp.ClientResponseError as e:
+                        raise AirflowException(f"{e.status}:{e.message}")
+            else:
+                async with session.get(url, params=params) as response:
+                    try:
+                        response.raise_for_status()
+                        return await response.json()
+                    except aiohttp.ClientResponseError as e:
+                        raise AirflowException(f"{e.status}:{e.message}")
 
     async def get_job_status(
         self, run_id: int, account_id: int | None = None, include_related: 
list[str] | None = None
@@ -280,8 +298,11 @@ class DbtCloudHook(HttpHook):
 
         return session
 
-    def _paginate(self, endpoint: str, payload: dict[str, Any] | None = None) 
-> list[Response]:
-        response = self.run(endpoint=endpoint, data=payload)
+    def _paginate(
+        self, endpoint: str, payload: dict[str, Any] | None = None, proxies: 
dict[str, str] | None = None
+    ) -> list[Response]:
+        extra_options = {"proxies": proxies} if proxies is not None else None
+        response = self.run(endpoint=endpoint, data=payload, 
extra_options=extra_options)
         resp_json = response.json()
         limit = resp_json["extra"]["filters"]["limit"]
         num_total_results = resp_json["extra"]["pagination"]["total_count"]
@@ -292,7 +313,7 @@ class DbtCloudHook(HttpHook):
             _paginate_payload["offset"] = limit
 
             while num_current_results < num_total_results:
-                response = self.run(endpoint=endpoint, data=_paginate_payload)
+                response = self.run(endpoint=endpoint, data=_paginate_payload, 
extra_options=extra_options)
                 resp_json = response.json()
                 results.append(response)
                 num_current_results += 
resp_json["extra"]["pagination"]["count"]
@@ -310,17 +331,20 @@ class DbtCloudHook(HttpHook):
     ) -> Any:
         self.method = method
         full_endpoint = f"api/{api_version}/accounts/{endpoint}" if endpoint 
else None
+        proxies = self._get_proxies(self.connection)
+        extra_options = {"proxies": proxies} if proxies is not None else None
 
         if paginate:
             if isinstance(payload, str):
                 raise ValueError("Payload cannot be a string to paginate a 
response.")
 
             if full_endpoint:
-                return self._paginate(endpoint=full_endpoint, payload=payload)
+                return self._paginate(endpoint=full_endpoint, payload=payload, 
proxies=proxies)
 
             raise ValueError("An endpoint is needed to paginate a response.")
 
-        return self.run(endpoint=full_endpoint, data=payload)
+        # breakpoint()
+        return self.run(endpoint=full_endpoint, data=payload, 
extra_options=extra_options)
 
     def list_accounts(self) -> list[Response]:
         """
diff --git a/docs/apache-airflow-providers-dbt-cloud/connections.rst 
b/docs/apache-airflow-providers-dbt-cloud/connections.rst
index f3514cc83a..428c15d12e 100644
--- a/docs/apache-airflow-providers-dbt-cloud/connections.rst
+++ b/docs/apache-airflow-providers-dbt-cloud/connections.rst
@@ -77,6 +77,20 @@ Host (optional)
     If using the Connection form in the Airflow UI, the Tenant domain can also 
be stored in the "Tenant"
     field.
 
+Extra (optional)
+    Specify extra parameters as JSON dictionary. As of now, only `proxies` is 
supported when wanting to connect to dbt Cloud via a proxy.
+
+    `proxies` should be a dictionary of proxies to be used by HTTP and HTTPS 
connections.
+
+.. code-block:: json
+
+    {
+      "proxies": {
+        "http": "http://myproxy.mycompany.local:8080",
+        "https": "http://myproxy.mycompany.local:8080"
+      }
+    }
+
 When specifying the connection as an environment variable, you should specify 
it following the standard syntax
 of a database connection. Note that all components of the URI should be 
URL-encoded.
 
diff --git a/tests/providers/dbt/cloud/hooks/test_dbt.py 
b/tests/providers/dbt/cloud/hooks/test_dbt.py
index 71ef75ba31..0d84189bc8 100644
--- a/tests/providers/dbt/cloud/hooks/test_dbt.py
+++ b/tests/providers/dbt/cloud/hooks/test_dbt.py
@@ -40,9 +40,11 @@ pytestmark = pytest.mark.db_test
 ACCOUNT_ID_CONN = "account_id_conn"
 NO_ACCOUNT_ID_CONN = "no_account_id_conn"
 SINGLE_TENANT_CONN = "single_tenant_conn"
+PROXY_CONN = "proxy_conn"
 DEFAULT_ACCOUNT_ID = 11111
 ACCOUNT_ID = 22222
 SINGLE_TENANT_DOMAIN = "single.tenant.getdbt.com"
+EXTRA_PROXIES = {"proxies": {"https": "http://myproxy:1234"}}
 TOKEN = "token"
 PROJECT_ID = 33333
 JOB_ID = 4444
@@ -136,9 +138,20 @@ class TestDbtCloudHook:
             host=SINGLE_TENANT_DOMAIN,
         )
 
+        # Connection with a proxy set in extra parameters
+        proxy_conn = Connection(
+            conn_id=PROXY_CONN,
+            conn_type=DbtCloudHook.conn_type,
+            login=DEFAULT_ACCOUNT_ID,
+            password=TOKEN,
+            host=SINGLE_TENANT_DOMAIN,
+            extra=EXTRA_PROXIES,
+        )
+
         db.merge_conn(account_id_conn)
         db.merge_conn(no_account_id_conn)
         db.merge_conn(host_conn)
+        db.merge_conn(proxy_conn)
 
     @pytest.mark.parametrize(
         argnames="conn_id, url",
@@ -196,7 +209,7 @@ class TestDbtCloudHook:
         hook.list_accounts()
 
         assert hook.method == "GET"
-        hook.run.assert_called_once_with(endpoint=None, data=None)
+        hook.run.assert_called_once_with(endpoint=None, data=None, 
extra_options=None)
         hook._paginate.assert_not_called()
 
     @pytest.mark.parametrize(
@@ -213,7 +226,9 @@ class TestDbtCloudHook:
         assert hook.method == "GET"
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
-        
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/", 
data=None)
+        hook.run.assert_called_once_with(
+            endpoint=f"api/v2/accounts/{_account_id}/", data=None, 
extra_options=None
+        )
         hook._paginate.assert_not_called()
 
     @pytest.mark.parametrize(
@@ -232,7 +247,7 @@ class TestDbtCloudHook:
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_not_called()
         hook._paginate.assert_called_once_with(
-            endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None
+            endpoint=f"api/v3/accounts/{_account_id}/projects/", payload=None, 
proxies=None
         )
 
     @pytest.mark.parametrize(
@@ -250,7 +265,7 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_called_once_with(
-            endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", 
data=None
+            endpoint=f"api/v3/accounts/{_account_id}/projects/{PROJECT_ID}/", 
data=None, extra_options=None
         )
         hook._paginate.assert_not_called()
 
@@ -269,7 +284,9 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook._paginate.assert_called_once_with(
-            endpoint=f"api/v2/accounts/{_account_id}/jobs/", 
payload={"order_by": None, "project_id": None}
+            endpoint=f"api/v2/accounts/{_account_id}/jobs/",
+            payload={"order_by": None, "project_id": None},
+            proxies=None,
         )
         hook.run.assert_not_called()
 
@@ -290,6 +307,7 @@ class TestDbtCloudHook:
         hook._paginate.assert_called_once_with(
             endpoint=f"api/v2/accounts/{_account_id}/jobs/",
             payload={"order_by": "-id", "project_id": PROJECT_ID},
+            proxies=None,
         )
         hook.run.assert_not_called()
 
@@ -307,7 +325,9 @@ class TestDbtCloudHook:
         assert hook.method == "GET"
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
-        
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}",
 data=None)
+        hook.run.assert_called_once_with(
+            endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}", 
data=None, extra_options=None
+        )
         hook._paginate.assert_not_called()
 
     @pytest.mark.parametrize(
@@ -328,6 +348,7 @@ class TestDbtCloudHook:
         hook.run.assert_called_once_with(
             endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
             data=json.dumps({"cause": cause, "steps_override": None, 
"schema_override": None}),
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -359,6 +380,7 @@ class TestDbtCloudHook:
             data=json.dumps(
                 {"cause": cause, "steps_override": steps_override, 
"schema_override": schema_override}
             ),
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -393,6 +415,7 @@ class TestDbtCloudHook:
                     "generate_docs_override": False,
                 }
             ),
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -422,6 +445,7 @@ class TestDbtCloudHook:
         hook.run.assert_called_once_with(
             endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
             data=json.dumps({"cause": expected_cause, "steps_override": None, 
"schema_override": None}),
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -467,7 +491,9 @@ class TestDbtCloudHook:
             hook._paginate.assert_not_called()
             if should_use_rerun:
                 hook.run.assert_called_once_with(
-                    
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/", data=None
+                    
endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/rerun/",
+                    data=None,
+                    extra_options=None,
                 )
             else:
                 hook.run.assert_called_once_with(
@@ -479,8 +505,31 @@ class TestDbtCloudHook:
                             "schema_override": None,
                         }
                     ),
+                    extra_options=None,
                 )
 
+    @pytest.mark.parametrize(
+        argnames="conn_id, account_id",
+        argvalues=[(PROXY_CONN, ACCOUNT_ID)],
+        ids=["proxy_connection"],
+    )
+    @patch.object(DbtCloudHook, "run")
+    @patch.object(DbtCloudHook, "_paginate")
+    def test_trigger_job_run_with_proxy(self, mock_http_run, mock_paginate, 
conn_id, account_id):
+        hook = DbtCloudHook(conn_id)
+        cause = ""
+        hook.trigger_job_run(job_id=JOB_ID, cause=cause, account_id=account_id)
+
+        assert hook.method == "POST"
+
+        _account_id = account_id or DEFAULT_ACCOUNT_ID
+        hook.run.assert_called_once_with(
+            endpoint=f"api/v2/accounts/{_account_id}/jobs/{JOB_ID}/run/",
+            data=json.dumps({"cause": cause, "steps_override": None, 
"schema_override": None}),
+            extra_options={"proxies": {"https": "http://myproxy:1234"}},
+        )
+        hook._paginate.assert_not_called()
+
     @pytest.mark.parametrize(
         argnames="conn_id, account_id",
         argvalues=[(ACCOUNT_ID_CONN, None), (NO_ACCOUNT_ID_CONN, ACCOUNT_ID)],
@@ -503,6 +552,7 @@ class TestDbtCloudHook:
                 "job_definition_id": None,
                 "order_by": None,
             },
+            proxies=None,
         )
 
     @pytest.mark.parametrize(
@@ -529,6 +579,7 @@ class TestDbtCloudHook:
                 "job_definition_id": JOB_ID,
                 "order_by": "id",
             },
+            proxies=None,
         )
 
     @pytest.mark.parametrize(
@@ -544,7 +595,9 @@ class TestDbtCloudHook:
         assert hook.method == "GET"
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
-        
hook.run.assert_called_once_with(endpoint=f"api/v2/accounts/{_account_id}/runs/",
 data=None)
+        hook.run.assert_called_once_with(
+            endpoint=f"api/v2/accounts/{_account_id}/runs/", data=None, 
extra_options=None
+        )
 
     @pytest.mark.parametrize(
         argnames="conn_id, account_id",
@@ -561,7 +614,9 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_called_once_with(
-            endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", 
data={"include_related": None}
+            endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/",
+            data={"include_related": None},
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -580,7 +635,9 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_called_once_with(
-            endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/", 
data={"include_related": ["triggers"]}
+            endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/",
+            data={"include_related": ["triggers"]},
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -645,7 +702,7 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_called_once_with(
-            endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", 
data=None
+            endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/cancel/", 
data=None, extra_options=None
         )
         hook._paginate.assert_not_called()
 
@@ -664,7 +721,9 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_called_once_with(
-            
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", 
data={"step": None}
+            endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/",
+            data={"step": None},
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -683,7 +742,9 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_called_once_with(
-            
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/", 
data={"step": 2}
+            endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/",
+            data={"step": 2},
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -703,7 +764,9 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_called_once_with(
-            
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", 
data={"step": None}
+            
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}",
+            data={"step": None},
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 
@@ -723,7 +786,9 @@ class TestDbtCloudHook:
 
         _account_id = account_id or DEFAULT_ACCOUNT_ID
         hook.run.assert_called_once_with(
-            
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}", 
data={"step": 2}
+            
endpoint=f"api/v2/accounts/{_account_id}/runs/{RUN_ID}/artifacts/{path}",
+            data={"step": 2},
+            extra_options=None,
         )
         hook._paginate.assert_not_called()
 

Reply via email to