chinwobble commented on issue #18999:
URL: https://github.com/apache/airflow/issues/18999#issuecomment-953359692


   @eskarimov I have implemented a prototype like this:
   
   I'm sure many improvements could be made, but this should work.
   ```python
   
   # pylint: disable=abstract-method
   class DatabricksHookAsync(DatabricksHook):
       """Async version of the databricks hook"""
   
       async def get_run_state_async(
           self, run_id: str, session: ClientSession
       ) -> RunState:
           json = {"run_id": run_id}
           response = await self._do_api_call_async(GET_RUN_ENDPOINT, json, 
session)
           state = response["state"]
           life_cycle_state = state["life_cycle_state"]
           # result_state may not be in the state if not terminal
           result_state = state.get("result_state", None)
           state_message = state["state_message"]
           return RunState(life_cycle_state, result_state, state_message)
   
       async def _do_api_call_async(self, endpoint_info, json, session: 
ClientSession):
           """
           Utility function to perform an API call with retries
           :param endpoint_info: Tuple of method and endpoint
           :type endpoint_info: tuple[string, string]
           :param json: Parameters for this API call.
           :type json: dict
           :return: If the api call returns a OK status code,
               this function returns the response in JSON. Otherwise,
               we throw an AirflowException.
           :rtype: dict
           """
           method, endpoint = endpoint_info
   
           self.databricks_conn = self.get_connection(self.databricks_conn_id)
   
           if "token" in self.databricks_conn.extra_dejson:
               self.log.info("Using token auth. ")
               auth = {
                   "Authorization": "Bearer " + 
self.databricks_conn.extra_dejson["token"]
               }
               if "host" in self.databricks_conn.extra_dejson:
                   host = 
self._parse_host(self.databricks_conn.extra_dejson["host"])
               else:
                   host = self.databricks_conn.host
           else:
               raise AirflowException("DatabricksHookAsync only supports token 
Auth")
   
           url = f"https://{self._parse_host(host)}/{endpoint}"  # type: ignore
   
           if method == "GET":
               request_func = session.get
           elif method == "POST":
               request_func = session.post
           elif method == "PATCH":
               request_func = session.patch
           else:
               raise AirflowException("Unexpected HTTP Method: " + method)
   
           attempt_num = 1
           while True:
               try:
                   response = await request_func(
                       url,
                       json=json if method in ("POST", "PATCH") else None,
                       params=json if method == "GET" else None,
                       headers=auth,
                       timeout=self.timeout_seconds,
                   )
                   response.raise_for_status()
                   return await response.json()
               except ClientResponseError as err:
                   if err.status < 500:
                       # In this case, the user probably made a mistake.
                       # Don't retry.
                       # pylint: disable=raise-missing-from
                       raise AirflowException(
                           f"Response: {err.message}, Status Code: {err.status}"
                       )
   
               if attempt_num == self.retry_limit:
                   raise AirflowException(
                       (
                           "API requests to Databricks failed {} times. " + 
"Giving up."
                       ).format(self.retry_limit)
                   )
   
               attempt_num += 1
               await asyncio.sleep(self.retry_delay)
   
   
   class DatabricksJobTrigger(BaseTrigger):
       """A trigger that checks every 15 seconds whether the databricks job is 
finished"""
   
       def __init__(self, run_id: str, databricks_conn_id):
           super().__init__()
           self.run_id = run_id
           self.databricks_conn_id = databricks_conn_id
   
       def serialize(self) -> typing.Tuple[str, typing.Dict[str, typing.Any]]:
           return (
               "operators.submit_to_databricks_operator.DatabricksJobTrigger",
               {
                   "run_id": self.run_id,
                   "databricks_conn_id": self.databricks_conn_id,
               },
           )
   
       async def run(self):
           hook = DatabricksHookAsync(self.databricks_conn_id)
           async with aiohttp.ClientSession() as session:
               while True:
                   run_state = await hook.get_run_state_async(self.run_id, 
session)
                   if run_state.is_terminal:
                       if run_state.is_successful:
                           self.log.info("Run id: %s completed successfully.", 
self.run_id)
                       else:
                           self.log.info("Run id: %s completed and failed.", 
self.run_id)
                       yield TriggerEvent((self.run_id, run_state.result_state))
                   await asyncio.sleep(15)
   ```


-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]


Reply via email to