SameerMesiah97 commented on code in PR #62401:
URL: https://github.com/apache/airflow/pull/62401#discussion_r2899750284
##########
providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py:
##########
@@ -334,6 +361,100 @@ def hook(self) -> AsyncKubernetesHook:
def pod_manager(self) -> AsyncPodManager:
return AsyncPodManager(async_hook=self.hook)
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """Get the task instance for this trigger from the database
(Airflow 2.x only)."""
+ task_instance = session.scalar(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ )
+ if task_instance is None:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
Review Comment:
Avoid using `AirflowException` in provider packages, since it is a core
Airflow exception. Use `RuntimeError` instead.
##########
providers/cncf/kubernetes/src/airflow/providers/cncf/kubernetes/triggers/pod.py:
##########
@@ -334,6 +361,100 @@ def hook(self) -> AsyncKubernetesHook:
def pod_manager(self) -> AsyncPodManager:
return AsyncPodManager(async_hook=self.hook)
+ if not AIRFLOW_V_3_0_PLUS:
+
+ @provide_session
+ def get_task_instance(self, session: Session) -> TaskInstance:
+ """Get the task instance for this trigger from the database
(Airflow 2.x only)."""
+ task_instance = session.scalar(
+ select(TaskInstance).where(
+ TaskInstance.dag_id == self.task_instance.dag_id,
+ TaskInstance.task_id == self.task_instance.task_id,
+ TaskInstance.run_id == self.task_instance.run_id,
+ TaskInstance.map_index == self.task_instance.map_index,
+ )
+ )
+ if task_instance is None:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
+ self.task_instance.dag_id,
+ self.task_instance.task_id,
+ self.task_instance.run_id,
+ self.task_instance.map_index,
+ )
+ return task_instance
+
+ async def get_task_state(self):
+ """Get the current state of the task instance."""
+ if AIRFLOW_V_3_0_PLUS:
+ from airflow.sdk.execution_time.task_runner import
RuntimeTaskInstance
+
+ task_states_response = await
sync_to_async(RuntimeTaskInstance.get_task_states)(
+ dag_id=self.task_instance.dag_id,
+ task_ids=[self.task_instance.task_id],
+ run_ids=[self.task_instance.run_id],
+ map_index=self.task_instance.map_index,
+ )
+ try:
+ return
task_states_response[self.task_instance.run_id][self.task_instance.task_id]
+ except Exception:
+ raise AirflowException(
+ "TaskInstance with dag_id: %s, task_id: %s, run_id: %s and
map_index: %s is not found",
Review Comment:
Same here — prefer `RuntimeError` over the core `AirflowException`. Also,
would it make sense to narrow this `Exception` to the expected failure types
(e.g. API/DB errors)? Catching all exceptions could mask unrelated bugs.
##########
providers/cncf/kubernetes/tests/unit/cncf/kubernetes/triggers/test_pod.py:
##########
@@ -555,3 +558,121 @@ async def test__get_pod_retries(
with context:
await trigger._get_pod()
assert mock_hook.get_pod.call_count == call_count
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_PATH}.hook")
+ async def test_cleanup_does_not_delete_when_fired_event(self, mock_hook):
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ on_kill_action="delete_pod",
+ on_finish_action="delete_pod",
+ )
+ trigger._fired_event = True
+ await trigger.cleanup()
+ mock_hook.delete_pod.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_PATH}.hook")
+ async def test_cleanup_does_not_delete_when_on_kill_action_keep_pod(self,
mock_hook):
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ on_kill_action="keep_pod",
+ on_finish_action="delete_pod",
+ )
+ await trigger.cleanup()
+ mock_hook.delete_pod.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_PATH}.safe_to_cancel", new_callable=mock.AsyncMock,
return_value=False)
+ @mock.patch(f"{TRIGGER_PATH}.hook")
+ async def test_cleanup_does_not_delete_during_triggerer_restart(self,
mock_hook, mock_safe):
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ on_kill_action="delete_pod",
+ on_finish_action="delete_pod",
+ )
+ await trigger.cleanup()
+ mock_hook.delete_pod.assert_not_called()
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_PATH}.safe_to_cancel", new_callable=mock.AsyncMock,
return_value=True)
+ @mock.patch(f"{TRIGGER_PATH}.hook")
+ async def test_cleanup_deletes_pod_on_manual_mark(self, mock_hook,
mock_safe):
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ on_kill_action="delete_pod",
+ on_finish_action="delete_pod",
+ )
+ mock_hook.delete_pod = mock.AsyncMock()
+ await trigger.cleanup()
+ mock_hook.delete_pod.assert_called_once_with(
+ name=POD_NAME,
+ namespace=NAMESPACE,
+ grace_period_seconds=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_PATH}.safe_to_cancel", new_callable=mock.AsyncMock,
return_value=True)
+ @mock.patch(f"{TRIGGER_PATH}.hook")
+ async def
test_cleanup_deletes_pod_even_when_on_finish_action_keep_pod(self, mock_hook,
mock_safe):
+ """on_finish_action is not consulted during kill -- on_kill_action is
the sole control."""
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ on_kill_action="delete_pod",
+ on_finish_action="keep_pod",
+ )
+ mock_hook.delete_pod = mock.AsyncMock()
+ await trigger.cleanup()
+ mock_hook.delete_pod.assert_called_once_with(
+ name=POD_NAME,
+ namespace=NAMESPACE,
+ grace_period_seconds=None,
+ )
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_PATH}.get_task_state", new_callable=mock.AsyncMock)
+ async def test_safe_to_cancel_returns_true_when_task_not_deferred(self,
mock_get_state):
+ """safe_to_cancel should return True when the task is no longer
DEFERRED (e.g. user marked success)."""
+ mock_get_state.return_value = TaskInstanceState.SUCCESS
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ )
+ assert await trigger.safe_to_cancel() is True
+
+ @pytest.mark.asyncio
+ @mock.patch(f"{TRIGGER_PATH}.get_task_state", new_callable=mock.AsyncMock)
+ async def test_safe_to_cancel_returns_false_when_task_still_deferred(self,
mock_get_state):
+ """safe_to_cancel should return False when the task is still DEFERRED
(triggerer restart)."""
+ mock_get_state.return_value = TaskInstanceState.DEFERRED
+ trigger = KubernetesPodTrigger(
+ pod_name=POD_NAME,
+ pod_namespace=NAMESPACE,
+ base_container_name=BASE_CONTAINER_NAME,
+ trigger_start_time=TRIGGER_START_TIME,
+ schedule_timeout=STARTUP_TIMEOUT_SECS,
+ )
+ assert await trigger.safe_to_cancel() is False
Review Comment:
Tests look good overall. However, please add a test for the branch where
`safe_to_cancel()` raises an exception, to verify that cleanup skips pod
deletion as intended.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]