This is an automated email from the ASF dual-hosted git repository. potiuk pushed a commit to branch main in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push: new 2b101e2377 Feature: Added event_handler parameter in MSGraphAsyncOperator (#42539) 2b101e2377 is described below commit 2b101e2377f8d49a46aca6c219e4b38ee099a98d Author: David Blain <i...@dabla.be> AuthorDate: Tue Oct 15 02:41:36 2024 +0200 Feature: Added event_handler parameter in MSGraphAsyncOperator (#42539) * refactor: Added parameter in MSGraphAsyncOperator to allow overriding default event_handler * docs: Added docstring for event_handler parameter in MSGraphAsyncOperator * refactor: Fixed TestMSGraphAsyncOperator * refactor: Check if event is not None * refactor: Register the TextParseNodeFactory and JsonParseNodeFactory so error messages get handled correctly in RequestAdapter * refactor: Reorganized import for TestMSGraphAsyncOperator * refactor: Added missing kiota-serialization packages in azure provider * refactor: Updated provider dependencies * refactor: Reorganized import of TestKiotaRequestAdapterHook * refactor: Downgraded version of json kiota serialization * refactor: Updated provider dependencies * refactor: Put import of Context in TYPE_CHECKING block * refactor: Fixed lookup of tenant-id * refactor: Fixed kiota serialization dependencies to 1.0.0 to avoid pendulum dependency issues for backward compatibility * refactor: Updated provider dependencies * refactored: Fixed import of test_utils in test_dag_run --------- Co-authored-by: David Blain <david.bl...@infrabel.be> --- generated/provider_dependencies.json | 2 + .../providers/microsoft/azure/hooks/msgraph.py | 7 ++++ .../providers/microsoft/azure/operators/msgraph.py | 18 +++++++-- .../providers/microsoft/azure/provider.yaml | 2 + .../tests/microsoft/azure/hooks/test_msgraph.py | 45 +++++++++++++++++++++- .../microsoft/azure/operators/test_msgraph.py | 30 +++++++++++++++ 6 files changed, 99 insertions(+), 5 deletions(-) diff --git a/generated/provider_dependencies.json b/generated/provider_dependencies.json index 4d921fc1fb..8efdd5eae7 100644 --- a/generated/provider_dependencies.json +++ b/generated/provider_dependencies.json @@ -828,6 +828,8 @@ "azure-synapse-artifacts>=0.17.0", "azure-synapse-spark>=0.2.0", "microsoft-kiota-http>=1.3.0,!=1.3.4", + "microsoft-kiota-serialization-json==1.0.0", + "microsoft-kiota-serialization-text==1.0.0", "msgraph-core>=1.0.0" ], "devel-deps": [ diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py index 61e555f4ca..4ab3aaf3ba 100644 --- a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py @@ -32,11 +32,14 @@ from kiota_abstractions.api_error import APIError from kiota_abstractions.method import Method from kiota_abstractions.request_information import RequestInformation from kiota_abstractions.response_handler import ResponseHandler +from kiota_abstractions.serialization import ParseNodeFactoryRegistry from kiota_authentication_azure.azure_identity_authentication_provider import ( AzureIdentityAuthenticationProvider, ) from kiota_http.httpx_request_adapter import HttpxRequestAdapter from kiota_http.middleware.options import ResponseHandlerOption +from kiota_serialization_json.json_parse_node_factory import JsonParseNodeFactory +from kiota_serialization_text.text_parse_node_factory import TextParseNodeFactory from msgraph_core import APIVersion, GraphClientFactory from msgraph_core._enums import NationalClouds @@ -249,8 +252,12 @@ class KiotaRequestAdapterHook(BaseHook): scopes=scopes, allowed_hosts=allowed_hosts, ) + parse_node_factory = ParseNodeFactoryRegistry() + parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["text/plain"] = TextParseNodeFactory() + parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["application/json"] = JsonParseNodeFactory() request_adapter = HttpxRequestAdapter( authentication_provider=auth_provider, + parse_node_factory=parse_node_factory, http_client=http_client, base_url=base_url, ) diff --git a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py index b3d14b14a5..0d187ebd51 100644 --- a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py +++ b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py @@ -44,6 +44,14 @@ if TYPE_CHECKING: from airflow.utils.context import Context +def default_event_handler(context: Context, event: dict[Any, Any] | None = None) -> Any: + if event: + if event.get("status") == "failure": + raise AirflowException(event.get("message")) + + return event.get("response") + + class MSGraphAsyncOperator(BaseOperator): """ A Microsoft Graph API operator which allows you to execute REST call to the Microsoft Graph API. @@ -69,6 +77,9 @@ class MSGraphAsyncOperator(BaseOperator): :param result_processor: Function to further process the response from MS Graph API (default is lambda: context, response: response). When the response returned by the `KiotaRequestAdapterHook` are bytes, then those will be base64 encoded into a string. + :param event_handler: Function to process the event returned from `MSGraphTrigger`. By default, when the + event returned by the `MSGraphTrigger` has a failed status, an AirflowException is being raised with + the message from the event, otherwise the response from the event payload is returned. :param serializer: Class which handles response serialization (default is ResponseSerializer). Bytes will be base64 encoded into a string, so it can be stored as an XCom. """ @@ -102,6 +113,7 @@ class MSGraphAsyncOperator(BaseOperator): api_version: APIVersion | str | None = None, pagination_function: Callable[[MSGraphAsyncOperator, dict, Context], tuple[str, dict]] | None = None, result_processor: Callable[[Context, Any], Any] = lambda context, result: result, + event_handler: Callable[[Context, dict[Any, Any] | None], Any] | None = None, serializer: type[ResponseSerializer] = ResponseSerializer, **kwargs: Any, ): @@ -121,6 +133,7 @@ class MSGraphAsyncOperator(BaseOperator): self.api_version = api_version self.pagination_function = pagination_function or self.paginate self.result_processor = result_processor + self.event_handler = event_handler or default_event_handler self.serializer: ResponseSerializer = serializer() def execute(self, context: Context) -> None: @@ -158,10 +171,7 @@ class MSGraphAsyncOperator(BaseOperator): if event: self.log.debug("%s completed with %s: %s", self.task_id, event.get("status"), event) - if event.get("status") == "failure": - raise AirflowException(event.get("message")) - - response = event.get("response") + response = self.event_handler(context, event) self.log.debug("response: %s", response) diff --git a/providers/src/airflow/providers/microsoft/azure/provider.yaml b/providers/src/airflow/providers/microsoft/azure/provider.yaml index cf0b3f75ef..c4831a641b 100644 --- a/providers/src/airflow/providers/microsoft/azure/provider.yaml +++ b/providers/src/airflow/providers/microsoft/azure/provider.yaml @@ -111,6 +111,8 @@ dependencies: # msgraph-core has transient import failures with microsoft-kiota-http==1.3.4 # See https://github.com/microsoftgraph/msgraph-sdk-python-core/issues/706 - microsoft-kiota-http>=1.3.0,!=1.3.4 + - microsoft-kiota-serialization-json==1.0.0 + - microsoft-kiota-serialization-text==1.0.0 devel-dependencies: - pywinrm diff --git a/providers/tests/microsoft/azure/hooks/test_msgraph.py b/providers/tests/microsoft/azure/hooks/test_msgraph.py index 0ecad98548..aff5d0226a 100644 --- a/providers/tests/microsoft/azure/hooks/test_msgraph.py +++ b/providers/tests/microsoft/azure/hooks/test_msgraph.py @@ -19,11 +19,15 @@ from __future__ import annotations import asyncio from json import JSONDecodeError from typing import TYPE_CHECKING -from unittest.mock import patch +from unittest.mock import Mock, patch import pytest +from httpx import Response from kiota_http.httpx_request_adapter import HttpxRequestAdapter +from kiota_serialization_json.json_parse_node import JsonParseNode +from kiota_serialization_text.text_parse_node import TextParseNode from msgraph_core import APIVersion, NationalClouds +from opentelemetry.trace import Span from airflow.exceptions import AirflowBadRequest, AirflowException, AirflowNotFoundException from airflow.providers.microsoft.azure.hooks.msgraph import ( @@ -175,6 +179,45 @@ class TestKiotaRequestAdapterHook: assert actual == {"%24expand": "reports,users,datasets,dataflows,dashboards", "%24top": 5000} + @pytest.mark.asyncio + async def test_throw_failed_responses_with_text_plain_content_type(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + response = Mock(spec=Response) + response.headers = {"content-type": "text/plain"} + response.status_code = 429 + response.content = b"TenantThrottleThresholdExceeded" + response.is_success = False + span = Mock(spec=Span) + + actual = await hook.get_conn().get_root_parse_node(response, span, span) + + assert isinstance(actual, TextParseNode) + assert actual.get_str_value() == "TenantThrottleThresholdExceeded" + + @pytest.mark.asyncio + async def test_throw_failed_responses_with_application_json_content_type(self): + with patch( + "airflow.hooks.base.BaseHook.get_connection", + side_effect=get_airflow_connection, + ): + hook = KiotaRequestAdapterHook(conn_id="msgraph_api") + response = Mock(spec=Response) + response.headers = {"content-type": "application/json"} + response.status_code = 429 + response.content = b'{"error": {"code": "TenantThrottleThresholdExceeded"}}' + response.is_success = False + span = Mock(spec=Span) + + actual = await hook.get_conn().get_root_parse_node(response, span, span) + + assert isinstance(actual, JsonParseNode) + error_code = actual.get_child_node("error").get_child_node("code").get_str_value() + assert error_code == "TenantThrottleThresholdExceeded" + class TestResponseHandler: def test_default_response_handler_when_json(self): diff --git a/providers/tests/microsoft/azure/operators/test_msgraph.py b/providers/tests/microsoft/azure/operators/test_msgraph.py index 372152fe97..fe404e48e6 100644 --- a/providers/tests/microsoft/azure/operators/test_msgraph.py +++ b/providers/tests/microsoft/azure/operators/test_msgraph.py @@ -19,6 +19,7 @@ from __future__ import annotations import json import locale from base64 import b64encode +from typing import TYPE_CHECKING, Any import pytest @@ -35,6 +36,9 @@ from providers.tests.microsoft.conftest import ( mock_response, ) +if TYPE_CHECKING: + from airflow.utils.context import Context + class TestMSGraphAsyncOperator(Base): @pytest.mark.db_test @@ -101,6 +105,32 @@ class TestMSGraphAsyncOperator(Base): with pytest.raises(AirflowException): self.execute_operator(operator) + @pytest.mark.db_test + def test_execute_when_an_exception_occurs_on_custom_event_handler(self): + with self.patch_hook_and_request_adapter(AirflowException("An error occurred")): + + def custom_event_handler(context: Context, event: dict[Any, Any] | None = None): + if event: + if event.get("status") == "failure": + return None + + return event.get("response") + + operator = MSGraphAsyncOperator( + task_id="users_delta", + conn_id="msgraph_api", + url="users/delta", + event_handler=custom_event_handler, + ) + + results, events = self.execute_operator(operator) + + assert not results + assert len(events) == 1 + assert isinstance(events[0], TriggerEvent) + assert events[0].payload["status"] == "failure" + assert events[0].payload["message"] == "An error occurred" + @pytest.mark.db_test def test_execute_when_response_is_bytes(self): content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)