This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 2b101e2377 Feature: Added event_handler parameter in MSGraphAsyncOperator (#42539)
2b101e2377 is described below

commit 2b101e2377f8d49a46aca6c219e4b38ee099a98d
Author: David Blain <i...@dabla.be>
AuthorDate: Tue Oct 15 02:41:36 2024 +0200

    Feature: Added event_handler parameter in MSGraphAsyncOperator (#42539)
    
    * refactor: Added a parameter to MSGraphAsyncOperator that allows overriding the default event_handler
    
    * docs: Added docstring for event_handler parameter in MSGraphAsyncOperator
    
    * refactor: Fixed TestMSGraphAsyncOperator
    
    * refactor: Check if event is not None
    
    * refactor: Register the TextParseNodeFactory and JsonParseNodeFactory so that error messages are handled correctly by the RequestAdapter
    
    * refactor: Reorganized import for TestMSGraphAsyncOperator
    
    * refactor: Added missing kiota-serialization packages in azure provider
    
    * refactor: Updated provider dependencies
    
    * refactor: Reorganized import of TestKiotaRequestAdapterHook
    
    * refactor: Downgraded version of json kiota serialization
    
    * refactor: Updated provider dependencies
    
    * refactor: Put import of Context in TYPE_CHECKING block
    
    * refactor: Fixed lookup of tenant-id
    
    * refactor: Pinned the kiota serialization dependencies to 1.0.0 to avoid pendulum dependency issues, for backward compatibility
    
    * refactor: Updated provider dependencies
    
    * refactor: Fixed import of test_utils in test_dag_run
    
    ---------
    
    Co-authored-by: David Blain <david.bl...@infrabel.be>
---
 generated/provider_dependencies.json               |  2 +
 .../providers/microsoft/azure/hooks/msgraph.py     |  7 ++++
 .../providers/microsoft/azure/operators/msgraph.py | 18 +++++++--
 .../providers/microsoft/azure/provider.yaml        |  2 +
 .../tests/microsoft/azure/hooks/test_msgraph.py    | 45 +++++++++++++++++++++-
 .../microsoft/azure/operators/test_msgraph.py      | 30 +++++++++++++++
 6 files changed, 99 insertions(+), 5 deletions(-)

diff --git a/generated/provider_dependencies.json 
b/generated/provider_dependencies.json
index 4d921fc1fb..8efdd5eae7 100644
--- a/generated/provider_dependencies.json
+++ b/generated/provider_dependencies.json
@@ -828,6 +828,8 @@
       "azure-synapse-artifacts>=0.17.0",
       "azure-synapse-spark>=0.2.0",
       "microsoft-kiota-http>=1.3.0,!=1.3.4",
+      "microsoft-kiota-serialization-json==1.0.0",
+      "microsoft-kiota-serialization-text==1.0.0",
       "msgraph-core>=1.0.0"
     ],
     "devel-deps": [
diff --git a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py 
b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
index 61e555f4ca..4ab3aaf3ba 100644
--- a/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
+++ b/providers/src/airflow/providers/microsoft/azure/hooks/msgraph.py
@@ -32,11 +32,14 @@ from kiota_abstractions.api_error import APIError
 from kiota_abstractions.method import Method
 from kiota_abstractions.request_information import RequestInformation
 from kiota_abstractions.response_handler import ResponseHandler
+from kiota_abstractions.serialization import ParseNodeFactoryRegistry
 from kiota_authentication_azure.azure_identity_authentication_provider import (
     AzureIdentityAuthenticationProvider,
 )
 from kiota_http.httpx_request_adapter import HttpxRequestAdapter
 from kiota_http.middleware.options import ResponseHandlerOption
+from kiota_serialization_json.json_parse_node_factory import 
JsonParseNodeFactory
+from kiota_serialization_text.text_parse_node_factory import 
TextParseNodeFactory
 from msgraph_core import APIVersion, GraphClientFactory
 from msgraph_core._enums import NationalClouds
 
@@ -249,8 +252,12 @@ class KiotaRequestAdapterHook(BaseHook):
                 scopes=scopes,
                 allowed_hosts=allowed_hosts,
             )
+            parse_node_factory = ParseNodeFactoryRegistry()
+            parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["text/plain"] 
= TextParseNodeFactory()
+            
parse_node_factory.CONTENT_TYPE_ASSOCIATED_FACTORIES["application/json"] = 
JsonParseNodeFactory()
             request_adapter = HttpxRequestAdapter(
                 authentication_provider=auth_provider,
+                parse_node_factory=parse_node_factory,
                 http_client=http_client,
                 base_url=base_url,
             )
diff --git 
a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py 
b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
index b3d14b14a5..0d187ebd51 100644
--- a/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
+++ b/providers/src/airflow/providers/microsoft/azure/operators/msgraph.py
@@ -44,6 +44,14 @@ if TYPE_CHECKING:
     from airflow.utils.context import Context
 
 
+def default_event_handler(context: Context, event: dict[Any, Any] | None = 
None) -> Any:
+    if event:
+        if event.get("status") == "failure":
+            raise AirflowException(event.get("message"))
+
+        return event.get("response")
+
+
 class MSGraphAsyncOperator(BaseOperator):
     """
     A Microsoft Graph API operator which allows you to execute REST call to 
the Microsoft Graph API.
@@ -69,6 +77,9 @@ class MSGraphAsyncOperator(BaseOperator):
     :param result_processor: Function to further process the response from MS 
Graph API
         (default is lambda: context, response: response).  When the response 
returned by the
         `KiotaRequestAdapterHook` are bytes, then those will be base64 encoded 
into a string.
+    :param event_handler: Function to process the event returned from 
`MSGraphTrigger`.  By default, when the
+        event returned by the `MSGraphTrigger` has a failed status, an 
AirflowException is being raised with
+        the message from the event, otherwise the response from the event 
payload is returned.
     :param serializer: Class which handles response serialization (default is 
ResponseSerializer).
         Bytes will be base64 encoded into a string, so it can be stored as an 
XCom.
     """
@@ -102,6 +113,7 @@ class MSGraphAsyncOperator(BaseOperator):
         api_version: APIVersion | str | None = None,
         pagination_function: Callable[[MSGraphAsyncOperator, dict, Context], 
tuple[str, dict]] | None = None,
         result_processor: Callable[[Context, Any], Any] = lambda context, 
result: result,
+        event_handler: Callable[[Context, dict[Any, Any] | None], Any] | None 
= None,
         serializer: type[ResponseSerializer] = ResponseSerializer,
         **kwargs: Any,
     ):
@@ -121,6 +133,7 @@ class MSGraphAsyncOperator(BaseOperator):
         self.api_version = api_version
         self.pagination_function = pagination_function or self.paginate
         self.result_processor = result_processor
+        self.event_handler = event_handler or default_event_handler
         self.serializer: ResponseSerializer = serializer()
 
     def execute(self, context: Context) -> None:
@@ -158,10 +171,7 @@ class MSGraphAsyncOperator(BaseOperator):
         if event:
             self.log.debug("%s completed with %s: %s", self.task_id, 
event.get("status"), event)
 
-            if event.get("status") == "failure":
-                raise AirflowException(event.get("message"))
-
-            response = event.get("response")
+            response = self.event_handler(context, event)
 
             self.log.debug("response: %s", response)
 
diff --git a/providers/src/airflow/providers/microsoft/azure/provider.yaml 
b/providers/src/airflow/providers/microsoft/azure/provider.yaml
index cf0b3f75ef..c4831a641b 100644
--- a/providers/src/airflow/providers/microsoft/azure/provider.yaml
+++ b/providers/src/airflow/providers/microsoft/azure/provider.yaml
@@ -111,6 +111,8 @@ dependencies:
   # msgraph-core has transient import failures with microsoft-kiota-http==1.3.4
   # See https://github.com/microsoftgraph/msgraph-sdk-python-core/issues/706
   - microsoft-kiota-http>=1.3.0,!=1.3.4
+  - microsoft-kiota-serialization-json==1.0.0
+  - microsoft-kiota-serialization-text==1.0.0
 
 devel-dependencies:
   - pywinrm
diff --git a/providers/tests/microsoft/azure/hooks/test_msgraph.py 
b/providers/tests/microsoft/azure/hooks/test_msgraph.py
index 0ecad98548..aff5d0226a 100644
--- a/providers/tests/microsoft/azure/hooks/test_msgraph.py
+++ b/providers/tests/microsoft/azure/hooks/test_msgraph.py
@@ -19,11 +19,15 @@ from __future__ import annotations
 import asyncio
 from json import JSONDecodeError
 from typing import TYPE_CHECKING
-from unittest.mock import patch
+from unittest.mock import Mock, patch
 
 import pytest
+from httpx import Response
 from kiota_http.httpx_request_adapter import HttpxRequestAdapter
+from kiota_serialization_json.json_parse_node import JsonParseNode
+from kiota_serialization_text.text_parse_node import TextParseNode
 from msgraph_core import APIVersion, NationalClouds
+from opentelemetry.trace import Span
 
 from airflow.exceptions import AirflowBadRequest, AirflowException, 
AirflowNotFoundException
 from airflow.providers.microsoft.azure.hooks.msgraph import (
@@ -175,6 +179,45 @@ class TestKiotaRequestAdapterHook:
 
         assert actual == {"%24expand": 
"reports,users,datasets,dataflows,dashboards", "%24top": 5000}
 
+    @pytest.mark.asyncio
+    async def test_throw_failed_responses_with_text_plain_content_type(self):
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=get_airflow_connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            response = Mock(spec=Response)
+            response.headers = {"content-type": "text/plain"}
+            response.status_code = 429
+            response.content = b"TenantThrottleThresholdExceeded"
+            response.is_success = False
+            span = Mock(spec=Span)
+
+            actual = await hook.get_conn().get_root_parse_node(response, span, 
span)
+
+            assert isinstance(actual, TextParseNode)
+            assert actual.get_str_value() == "TenantThrottleThresholdExceeded"
+
+    @pytest.mark.asyncio
+    async def 
test_throw_failed_responses_with_application_json_content_type(self):
+        with patch(
+            "airflow.hooks.base.BaseHook.get_connection",
+            side_effect=get_airflow_connection,
+        ):
+            hook = KiotaRequestAdapterHook(conn_id="msgraph_api")
+            response = Mock(spec=Response)
+            response.headers = {"content-type": "application/json"}
+            response.status_code = 429
+            response.content = b'{"error": {"code": 
"TenantThrottleThresholdExceeded"}}'
+            response.is_success = False
+            span = Mock(spec=Span)
+
+            actual = await hook.get_conn().get_root_parse_node(response, span, 
span)
+
+            assert isinstance(actual, JsonParseNode)
+            error_code = 
actual.get_child_node("error").get_child_node("code").get_str_value()
+            assert error_code == "TenantThrottleThresholdExceeded"
+
 
 class TestResponseHandler:
     def test_default_response_handler_when_json(self):
diff --git a/providers/tests/microsoft/azure/operators/test_msgraph.py 
b/providers/tests/microsoft/azure/operators/test_msgraph.py
index 372152fe97..fe404e48e6 100644
--- a/providers/tests/microsoft/azure/operators/test_msgraph.py
+++ b/providers/tests/microsoft/azure/operators/test_msgraph.py
@@ -19,6 +19,7 @@ from __future__ import annotations
 import json
 import locale
 from base64 import b64encode
+from typing import TYPE_CHECKING, Any
 
 import pytest
 
@@ -35,6 +36,9 @@ from providers.tests.microsoft.conftest import (
     mock_response,
 )
 
+if TYPE_CHECKING:
+    from airflow.utils.context import Context
+
 
 class TestMSGraphAsyncOperator(Base):
     @pytest.mark.db_test
@@ -101,6 +105,32 @@ class TestMSGraphAsyncOperator(Base):
             with pytest.raises(AirflowException):
                 self.execute_operator(operator)
 
+    @pytest.mark.db_test
+    def test_execute_when_an_exception_occurs_on_custom_event_handler(self):
+        with self.patch_hook_and_request_adapter(AirflowException("An error 
occurred")):
+
+            def custom_event_handler(context: Context, event: dict[Any, Any] | 
None = None):
+                if event:
+                    if event.get("status") == "failure":
+                        return None
+
+                    return event.get("response")
+
+            operator = MSGraphAsyncOperator(
+                task_id="users_delta",
+                conn_id="msgraph_api",
+                url="users/delta",
+                event_handler=custom_event_handler,
+            )
+
+            results, events = self.execute_operator(operator)
+
+            assert not results
+            assert len(events) == 1
+            assert isinstance(events[0], TriggerEvent)
+            assert events[0].payload["status"] == "failure"
+            assert events[0].payload["message"] == "An error occurred"
+
     @pytest.mark.db_test
     def test_execute_when_response_is_bytes(self):
         content = load_file("resources", "dummy.pdf", mode="rb", encoding=None)

Reply via email to