This is an automated email from the ASF dual-hosted git repository.

uranusjr pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 6549b1798f6 Remove zombie ti from its executor (#42932)
6549b1798f6 is described below

commit 6549b1798f6e682b5db6060d9aa89dd2e155ffb5
Author: Tzu-ping Chung <uranu...@gmail.com>
AuthorDate: Wed Oct 16 10:32:57 2024 +0800

    Remove zombie ti from its executor (#42932)
---
 airflow/jobs/scheduler_job_runner.py        | 119 +++++++++++++++-------------
 airflow/models/taskinstance.py              |   7 +-
 docs/apache-airflow/core-concepts/tasks.rst |   4 +-
 tests/jobs/test_scheduler_job.py            |  68 ++++++++--------
 tests_common/test_utils/mock_executor.py    |   5 +-
 5 files changed, 105 insertions(+), 98 deletions(-)

diff --git a/airflow/jobs/scheduler_job_runner.py b/airflow/jobs/scheduler_job_runner.py
index 30f58885a9a..951602e14e6 100644
--- a/airflow/jobs/scheduler_job_runner.py
+++ b/airflow/jobs/scheduler_job_runner.py
@@ -1077,7 +1077,7 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
 
         timers.call_regular_interval(
             conf.getfloat("scheduler", "zombie_detection_interval", fallback=10.0),
-            self._find_zombies,
+            self._find_and_purge_zombies,
         )
 
         timers.call_regular_interval(60.0, self._update_dag_run_state_for_paused_dags)
@@ -1953,73 +1953,80 @@ class SchedulerJobRunner(BaseJobRunner, LoggingMixin):
                 if num_timed_out_tasks:
                     self.log.info("Timed out %i deferred tasks without fired 
triggers", num_timed_out_tasks)
 
-    # [START find_zombies]
-    def _find_zombies(self) -> None:
+    # [START find_and_purge_zombies]
+    def _find_and_purge_zombies(self) -> None:
         """
-        Find zombie task instances and create a TaskCallbackRequest to be handled by the DAG processor.
+        Find and purge zombie task instances.
 
-        Zombie instances are tasks haven't heartbeated for too long or have a no-longer-running LocalTaskJob.
+        Zombie instances are tasks that failed to heartbeat for too long, or
+        have a no-longer-running LocalTaskJob.
+
+        A TaskCallbackRequest is also created for the killed zombie to be
+        handled by the DAG processor, and the executor is informed to no longer
+        count the zombie as running when it calculates parallelism.
         """
+        with create_session() as session:
+            if zombies := self._find_zombies(session=session):
+                self._purge_zombies(zombies, session=session)
+
+    def _find_zombies(self, *, session: Session) -> list[tuple[TI, str, str]]:
         from airflow.jobs.job import Job
 
         self.log.debug("Finding 'running' jobs without a recent heartbeat")
         limit_dttm = timezone.utcnow() - timedelta(seconds=self._zombie_threshold_secs)
-
-        with create_session() as session:
-            zombies: list[tuple[TI, str, str]] = (
-                session.execute(
-                    select(TI, DM.fileloc, DM.processor_subdir)
-                    .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
-                    .join(Job, TI.job_id == Job.id)
-                    .join(DM, TI.dag_id == DM.dag_id)
-                    .where(TI.state == TaskInstanceState.RUNNING)
-                    .where(
-                        or_(
-                            Job.state != JobState.RUNNING,
-                            Job.latest_heartbeat < limit_dttm,
-                        )
-                    )
-                    .where(Job.job_type == "LocalTaskJob")
-                    .where(TI.queued_by_job_id == self.job.id)
-                )
-                .unique()
-                .all()
+        zombies = (
+            session.execute(
+                select(TI, DM.fileloc, DM.processor_subdir)
+                .with_hint(TI, "USE INDEX (ti_state)", dialect_name="mysql")
+                .join(Job, TI.job_id == Job.id)
+                .join(DM, TI.dag_id == DM.dag_id)
+                .where(TI.state == TaskInstanceState.RUNNING)
+                .where(or_(Job.state != JobState.RUNNING, Job.latest_heartbeat < limit_dttm))
+                .where(Job.job_type == "LocalTaskJob")
+                .where(TI.queued_by_job_id == self.job.id)
             )
-
+            .unique()
+            .all()
+        )
         if zombies:
             self.log.warning("Failing (%s) jobs without heartbeat after %s", 
len(zombies), limit_dttm)
-
-        with create_session() as session:
-            for ti, file_loc, processor_subdir in zombies:
-                zombie_message_details = self._generate_zombie_message_details(ti)
-                request = TaskCallbackRequest(
-                    full_filepath=file_loc,
-                    processor_subdir=processor_subdir,
-                    simple_task_instance=SimpleTaskInstance.from_ti(ti),
-                    msg=str(zombie_message_details),
-                )
-                session.add(
-                    Log(
-                        event="heartbeat timeout",
-                        task_instance=ti.key,
-                        extra=(
-                            f"Task did not emit heartbeat within time limit 
({self._zombie_threshold_secs} "
-                            "seconds) and will be terminated. "
-                            "See 
https://airflow.apache.org/docs/apache-airflow/";
-                            
"stable/core-concepts/tasks.html#zombie-undead-tasks"
-                        ),
-                    )
-                )
-                self.log.error(
-                    "Detected zombie job: %s "
-                    "(See https://airflow.apache.org/docs/apache-airflow/";
-                    "stable/core-concepts/tasks.html#zombie-undead-tasks)",
-                    request,
+        return zombies
+
+    def _purge_zombies(self, zombies: list[tuple[TI, str, str]], *, session: Session) -> None:
+        for ti, file_loc, processor_subdir in zombies:
+            zombie_message_details = self._generate_zombie_message_details(ti)
+            request = TaskCallbackRequest(
+                full_filepath=file_loc,
+                processor_subdir=processor_subdir,
+                simple_task_instance=SimpleTaskInstance.from_ti(ti),
+                msg=str(zombie_message_details),
+            )
+            session.add(
+                Log(
+                    event="heartbeat timeout",
+                    task_instance=ti.key,
+                    extra=(
+                        f"Task did not emit heartbeat within time limit 
({self._zombie_threshold_secs} "
+                        "seconds) and will be terminated. "
+                        "See https://airflow.apache.org/docs/apache-airflow/";
+                        "stable/core-concepts/tasks.html#zombie-undead-tasks"
+                    ),
                 )
-                self.job.executor.send_callback(request)
-                Stats.incr("zombies_killed", tags={"dag_id": ti.dag_id, 
"task_id": ti.task_id})
+            )
+            self.log.error(
+                "Detected zombie job: %s "
+                "(See https://airflow.apache.org/docs/apache-airflow/";
+                "stable/core-concepts/tasks.html#zombie-undead-tasks)",
+                request,
+            )
+            self.job.executor.send_callback(request)
+            if (executor := self._try_to_load_executor(ti.executor)) is None:
+                self.log.warning("Cannot clean up zombie %r with non-existent 
executor %s", ti, ti.executor)
+                continue
+            executor.change_state(ti.key, TaskInstanceState.FAILED, remove_running=True)
+            Stats.incr("zombies_killed", tags={"dag_id": ti.dag_id, "task_id": 
ti.task_id})
 
-    # [END find_zombies]
+    # [END find_and_purge_zombies]
 
     @staticmethod
     def _generate_zombie_message_details(ti: TI) -> dict[str, Any]:
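
For readers skimming the patch: the query in _find_zombies above encodes four
criteria for a zombie. Restated as a plain-Python sketch (illustrative names
only, not the actual ORM types or the code in this commit), a task instance is
a zombie when:

    from datetime import datetime, timedelta

    def is_zombie(ti, job, *, now: datetime, threshold_secs: float, scheduler_job_id: int) -> bool:
        # Illustrative predicate only; `ti` and `job` stand in for the
        # TaskInstance row and its LocalTaskJob row matched by the query.
        return (
            ti.state == "running"                        # TI.state == TaskInstanceState.RUNNING
            and job.job_type == "LocalTaskJob"           # only tasks run by a LocalTaskJob
            and ti.queued_by_job_id == scheduler_job_id  # only tasks queued by this scheduler
            and (
                job.state != "running"                   # the supervising job stopped, or
                or job.latest_heartbeat < now - timedelta(seconds=threshold_secs)  # heartbeat too old
            )
        )
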
diff --git a/airflow/models/taskinstance.py b/airflow/models/taskinstance.py
index c1373e5d6a1..1dbe299f25b 100644
--- a/airflow/models/taskinstance.py
+++ b/airflow/models/taskinstance.py
@@ -154,7 +154,6 @@ if TYPE_CHECKING:
     from sqlalchemy.sql.expression import ColumnOperators
 
     from airflow.models.abstractoperator import TaskStateChangeCallback
-    from airflow.models.asset import AssetEvent
     from airflow.models.baseoperator import BaseOperator
     from airflow.models.dag import DAG, DagModel
     from airflow.models.dagrun import DagRun
@@ -3925,7 +3924,11 @@ class SimpleTaskInstance:
         self.queue = queue
         self.key = key
 
-    def __eq__(self, other):
+    def __repr__(self) -> str:
+        attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
+        return f"SimpleTaskInstance({attrs})"
+
+    def __eq__(self, other) -> bool:
         if isinstance(other, self.__class__):
             return self.__dict__ == other.__dict__
         return NotImplemented
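
The new __repr__ exists so the "Detected zombie job: %s" log message above
prints something readable for the embedded SimpleTaskInstance. A minimal
standalone sketch of the same pattern, with a hypothetical class and made-up
values:

    class Sketch:
        """Illustration of the __repr__ pattern added above; not Airflow code."""

        def __init__(self, dag_id: str, task_id: str) -> None:
            self.dag_id = dag_id
            self.task_id = task_id

        def __repr__(self) -> str:
            # Dump every instance attribute, same as SimpleTaskInstance.__repr__.
            attrs = ", ".join(f"{k}={v!r}" for k, v in self.__dict__.items())
            return f"{type(self).__name__}({attrs})"

    print(Sketch("example_dag", "dummy"))  # Sketch(dag_id='example_dag', task_id='dummy')
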
diff --git a/docs/apache-airflow/core-concepts/tasks.rst b/docs/apache-airflow/core-concepts/tasks.rst
index 5adfe8be460..3f5a1a21b77 100644
--- a/docs/apache-airflow/core-concepts/tasks.rst
+++ b/docs/apache-airflow/core-concepts/tasks.rst
@@ -189,8 +189,8 @@ Below is the code snippet from the Airflow scheduler that runs periodically to d
 
 .. exampleinclude:: /../../airflow/jobs/scheduler_job_runner.py
     :language: python
-    :start-after: [START find_zombies]
-    :end-before: [END find_zombies]
+    :start-after: [START find_and_purge_zombies]
+    :end-before: [END find_and_purge_zombies]
 
 
 The explanation of the criteria used in the above snippet to detect zombie tasks is as below:
diff --git a/tests/jobs/test_scheduler_job.py b/tests/jobs/test_scheduler_job.py
index f35ed29e7b1..01d1c5fe7a3 100644
--- a/tests/jobs/test_scheduler_job.py
+++ b/tests/jobs/test_scheduler_job.py
@@ -2976,8 +2976,8 @@ class TestSchedulerJob:
             ti2s = tiq.filter(TaskInstance.task_id == "dummy2").all()
             assert len(ti1s) == 0
             assert len(ti2s) >= 2
-            for task in ti2s:
-                assert task.state == State.SUCCESS
+            for ti in ti2s:
+                assert ti.state == State.SUCCESS
 
     @pytest.mark.parametrize(
         "configs",
@@ -5539,36 +5539,37 @@ class TestSchedulerJob:
         with pytest.raises(OperationalError):
             check_if_trigger_timeout(1)
 
-    def test_find_zombies_nothing(self):
+    def test_find_and_purge_zombies_nothing(self):
         executor = MockExecutor(do_update=False)
         scheduler_job = Job(executor=executor)
-        self.job_runner = SchedulerJobRunner(scheduler_job)
-        self.job_runner.processor_agent = mock.MagicMock()
-
-        self.job_runner._find_zombies()
-
-        scheduler_job.executor.callback_sink.send.assert_not_called()
+        with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock:
+            loader_mock.return_value = executor
+            self.job_runner = SchedulerJobRunner(scheduler_job)
+            self.job_runner.processor_agent = mock.MagicMock()
+            self.job_runner._find_and_purge_zombies()
+        executor.callback_sink.send.assert_not_called()
 
-    def test_find_zombies(self, load_examples):
+    def test_find_and_purge_zombies(self, load_examples, session):
         dagbag = DagBag(TEST_DAG_FOLDER, read_dags_from_db=False)
-        with create_session() as session:
-            session.query(Job).delete()
-            dag = dagbag.get_dag("example_branch_operator")
-            dag.sync_to_db()
-            data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
-            triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
-            dag_run = dag.create_dagrun(
-                state=DagRunState.RUNNING,
-                execution_date=DEFAULT_DATE,
-                run_type=DagRunType.SCHEDULED,
-                session=session,
-                data_interval=data_interval,
-                **triggered_by_kwargs,
-            )
 
-            scheduler_job = Job()
+        dag = dagbag.get_dag("example_branch_operator")
+        dag.sync_to_db()
+        data_interval = dag.infer_automated_data_interval(DEFAULT_LOGICAL_DATE)
+        triggered_by_kwargs = {"triggered_by": DagRunTriggeredByType.TEST} if AIRFLOW_V_3_0_PLUS else {}
+        dag_run = dag.create_dagrun(
+            state=DagRunState.RUNNING,
+            execution_date=DEFAULT_DATE,
+            run_type=DagRunType.SCHEDULED,
+            session=session,
+            data_interval=data_interval,
+            **triggered_by_kwargs,
+        )
+
+        executor = MockExecutor()
+        scheduler_job = Job(executor=executor)
+        with mock.patch("airflow.executors.executor_loader.ExecutorLoader.load_executor") as loader_mock:
+            loader_mock.return_value = executor
             self.job_runner = SchedulerJobRunner(job=scheduler_job, subdir=os.devnull)
-            scheduler_job.executor = MockExecutor()
             self.job_runner.processor_agent = mock.MagicMock()
 
             # We will provision 2 tasks so we can check we only find zombies from this scheduler
@@ -5594,11 +5595,12 @@ class TestSchedulerJob:
 
             ti.queued_by_job_id = scheduler_job.id
             session.flush()
+            executor.running.add(ti.key)  # The executor normally does this during heartbeat.
+            self.job_runner._find_and_purge_zombies()
+            assert ti.key not in executor.running
 
-            self.job_runner._find_zombies()
-
-        scheduler_job.executor.callback_sink.send.assert_called_once()
-        callback_requests = scheduler_job.executor.callback_sink.send.call_args.args
+        executor.callback_sink.send.assert_called_once()
+        callback_requests = executor.callback_sink.send.call_args.args
         assert len(callback_requests) == 1
         callback_request = callback_requests[0]
         assert isinstance(callback_request.simple_task_instance, SimpleTaskInstance)
@@ -5610,10 +5612,6 @@ class TestSchedulerJob:
         assert callback_request.simple_task_instance.run_id == ti.run_id
         assert callback_request.simple_task_instance.map_index == ti.map_index
 
-        with create_session() as session:
-            session.query(TaskInstance).delete()
-            session.query(Job).delete()
-
     def test_zombie_message(self, load_examples):
         """
         Check that the zombie message comes out as expected
@@ -5729,7 +5727,7 @@ class TestSchedulerJob:
         scheduler_job.executor = MockExecutor()
         self.job_runner.processor_agent = mock.MagicMock()
 
-        self.job_runner._find_zombies()
+        self.job_runner._find_and_purge_zombies()
 
         scheduler_job.executor.callback_sink.send.assert_called_once()
 
diff --git a/tests_common/test_utils/mock_executor.py b/tests_common/test_utils/mock_executor.py
index 506c0447589..6d4791e8891 100644
--- a/tests_common/test_utils/mock_executor.py
+++ b/tests_common/test_utils/mock_executor.py
@@ -36,7 +36,6 @@ class MockExecutor(BaseExecutor):
 
     def __init__(self, do_update=True, *args, **kwargs):
         self.do_update = do_update
-        self._running = []
         self.callback_sink = MagicMock()
 
         # A list of "batches" of tasks
@@ -88,8 +87,8 @@ class MockExecutor(BaseExecutor):
     def end(self):
         self.sync()
 
-    def change_state(self, key, state, info=None):
-        super().change_state(key, state, info=info)
+    def change_state(self, key, state, info=None, remove_running=False):
+        super().change_state(key, state, info=info, remove_running=remove_running)
         # The normal event buffer is cleared after reading, we want to keep
         # a list of all events for testing
         self.sorted_tasks.append((key, (state, info)))
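
The remove_running flag threaded through this test double mirrors the point of
the whole patch: if a purged zombie's key stays in the executor's running set,
the executor keeps counting it against parallelism indefinitely. A simplified
sketch of that accounting, using a made-up SketchExecutor rather than the real
BaseExecutor:

    class SketchExecutor:
        """Simplified model of executor slot accounting; not the real BaseExecutor."""

        def __init__(self, parallelism: int = 32) -> None:
            self.parallelism = parallelism
            self.running: set = set()     # task keys the executor believes are running
            self.event_buffer: dict = {}  # state-change events the scheduler reads back

        @property
        def slots_available(self) -> int:
            # A zombie key left in `running` would occupy a slot forever.
            return self.parallelism - len(self.running)

        def change_state(self, key, state, info=None, remove_running: bool = False) -> None:
            self.event_buffer[key] = (state, info)
            if remove_running:
                # What the scheduler now requests when purging a zombie.
                self.running.discard(key)
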
