ferruzzi commented on code in PR #63888:
URL: https://github.com/apache/airflow/pull/63888#discussion_r2997305034


##########
providers/celery/src/airflow/providers/celery/executors/celery_executor.py:
##########
@@ -102,7 +107,7 @@ class CeleryExecutor(BaseExecutor):
     if TYPE_CHECKING:
         if AIRFLOW_V_3_0_PLUS:
             # TODO: TaskSDK: move this type change into BaseExecutor
-            queued_tasks: dict[TaskInstanceKey, workloads.All]  # type: 
ignore[assignment]
+            queued_tasks: dict[WorkloadKey, workloads.All]  # type: 
ignore[assignment]

Review Comment:
   I think there's a mismatch here.  WorkloadKey is only imported up above when 
the Airflow version is 3.2 or later, but it is used here whenever the Airflow 
version is 3.0 or later.  What about using this as the import block?
   
   ```
   # Remove this conditional once min version > 3.2
   try:
       from airflow.executors.workloads.types import WorkloadKey
   except ImportError:
       from airflow.models.taskinstancekey import TaskInstanceKey as WorkloadKey
   ```
   
   I know there's some community debate over using try/except on imports, but I 
think this feels like the right time to use one.
   
   (I think the same comment goes for celery_executor_utils.py as well.)



##########
providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py:
##########
@@ -369,27 +369,23 @@ def send_workload_to_executor(
     try:
         import redis.client  # noqa: F401
     except ImportError:
-        pass  # Redis not installed or not using Redis backend
+        pass  # Redis not installed or not using Redis backend.
 
     try:
         with timeout(seconds=OPERATION_TIMEOUT):
-            result = task_to_run.apply_async(args=args, queue=queue)
+            result = celery_task.apply_async(args=args, queue=queue)
     except (Exception, AirflowTaskTimeout) as e:
         exception_traceback = f"Celery Task ID: 
{key}\n{traceback.format_exc()}"
         result = ExceptionWithTraceback(e, exception_traceback)
 
     # The type is right for the version, but the type cannot be defined 
correctly for Airflow 2 and 3
-    # concurrently;
+    # concurrently.
     return key, args, result
 
 
-# Backward compatibility alias
-send_task_to_executor = send_workload_to_executor

Review Comment:
   I think we still need this?  If not, then we need to update references to it 
such as (but not limited to) celery/provider.yaml 



##########
providers/celery/tests/integration/celery/test_celery_executor.py:
##########
@@ -316,27 +316,27 @@ def test_retry_on_error_sending_task(self, caplog):
             key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1)
             executor.queued_tasks[key] = workload
 
-            # Test that when heartbeat is called again, task is published 
again to Celery Queue
+            # Test that when heartbeat is called again, workload is published 
again to Celery Queue.
             executor.heartbeat()
-            assert dict(executor.task_publish_retries) == {key: 1}
-            assert len(executor.queued_tasks) == 1, "Task should remain in 
queue"
+            assert dict(executor.workload_publish_retries) == {key: 1}
+            assert len(executor.queued_tasks) == 1, "Workload should remain in 
queue"
             assert executor.event_buffer == {}
-            assert f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in 
caplog.text
+            assert f"[Try 1 of 3] Celery Task Timeout Error for Workload: 
({key})." in caplog.text

Review Comment:
   Here and below, it looks like you changed the expected message in the tests 
but didn't make the matching change to the log message in the code?



##########
providers/celery/src/airflow/providers/celery/executors/celery_executor.py:
##########
@@ -136,149 +141,151 @@ def __init__(self, *args, **kwargs):
         from airflow.providers.celery.executors.celery_executor_utils import 
BulkStateFetcher
 
         self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism, 
celery_app=self.celery_app)
-        self.tasks = {}
-        self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
-        self.task_publish_max_retries = self.conf.getint("celery", 
"task_publish_max_retries", fallback=3)
+        self.workloads: dict[WorkloadKey, AsyncResult] = {}
+        self.workload_publish_retries: Counter[WorkloadKey] = Counter()
+        self.workload_publish_max_retries = self.conf.getint("celery", 
"task_publish_max_retries", fallback=3)
 
     def start(self) -> None:
         self.log.debug("Starting Celery Executor using %s processes for 
syncing", self._sync_parallelism)
 
-    def _num_tasks_per_send_process(self, to_send_count: int) -> int:
+    def _num_workloads_per_send_process(self, to_send_count: int) -> int:
         """
-        How many Celery tasks should each worker process send.
+        How many Celery workloads should each worker process send.
 
-        :return: Number of tasks that should be sent per process
+        :return: Number of workloads that should be sent per process
         """
         return max(1, math.ceil(to_send_count / self._sync_parallelism))
 
     def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None:
-        # Airflow V2 version
+        # Airflow V2 compatibility path — converts task tuples into 
workload-compatible tuples.
 
         task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for 
task_tuple in task_tuples]
 
-        self._send_tasks(task_tuples_to_send)
+        self._send_workloads(task_tuples_to_send)
 
     def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
-        # Airflow V3 version -- have to delay imports until we know we are on 
v3
+        # Airflow V3 version -- have to delay imports until we know we are on 
v3.
         from airflow.executors.workloads import ExecuteTask
 
         if AIRFLOW_V_3_2_PLUS:
             from airflow.executors.workloads import ExecuteCallback
 
-        tasks: list[WorkloadInCelery] = []
+        workloads_to_be_sent: list[WorkloadInCelery] = []
         for workload in workloads:
             if isinstance(workload, ExecuteTask):
-                tasks.append((workload.ti.key, workload, workload.ti.queue, 
self.team_name))
+                workloads_to_be_sent.append((workload.ti.key, workload, 
workload.ti.queue, self.team_name))
             elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):
-                # Use default queue for callbacks, or extract from callback 
data if available
+                # Use default queue for callbacks, or extract from callback 
data if available.
                 queue = "default"
                 if isinstance(workload.callback.data, dict) and "queue" in 
workload.callback.data:
                     queue = workload.callback.data["queue"]
-                tasks.append((workload.callback.key, workload, queue, 
self.team_name))
+                workloads_to_be_sent.append((workload.callback.key, workload, 
queue, self.team_name))
             else:
                 raise ValueError(f"{type(self)}._process_workloads cannot 
handle {type(workload)}")
 
-        self._send_tasks(tasks)
+        self._send_workloads(workloads_to_be_sent)
 
-    def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
+    def _send_workloads(self, workload_tuples_to_send: 
Sequence[WorkloadInCelery]):
         # Celery state queries will be stuck if we do not use one same backend
-        # for all tasks.
+        # for all workloads.
         cached_celery_backend = self.celery_app.backend
 
-        key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
-        self.log.debug("Sent all tasks.")
+        key_and_async_results = 
self._send_workloads_to_celery(workload_tuples_to_send)
+        self.log.debug("Sent all workloads.")
         from airflow.providers.celery.executors.celery_executor_utils import 
ExceptionWithTraceback
 
         for key, _, result in key_and_async_results:
             if isinstance(result, ExceptionWithTraceback) and isinstance(
                 result.exception, AirflowTaskTimeout
             ):
-                retries = self.task_publish_retries[key]
-                if retries < self.task_publish_max_retries:
+                retries = self.workload_publish_retries[key]
+                if retries < self.workload_publish_max_retries:
                     Stats.incr("celery.task_timeout_error")
                     self.log.info(
-                        "[Try %s of %s] Task Timeout Error for Task: (%s).",
-                        self.task_publish_retries[key] + 1,
-                        self.task_publish_max_retries,
+                        "[Try %s of %s] Task Timeout Error for Workload: 
(%s).",
+                        self.workload_publish_retries[key] + 1,
+                        self.workload_publish_max_retries,
                         tuple(key),
                     )
-                    self.task_publish_retries[key] = retries + 1
+                    self.workload_publish_retries[key] = retries + 1
                     continue
             if key in self.queued_tasks:
                 self.queued_tasks.pop(key)
             else:
                 self.queued_callbacks.pop(key, None)
-            self.task_publish_retries.pop(key, None)
+            self.workload_publish_retries.pop(key, None)
             if isinstance(result, ExceptionWithTraceback):
                 self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER, 
result.exception, result.traceback)
                 self.event_buffer[key] = (TaskInstanceState.FAILED, None)
             elif result is not None:
                 result.backend = cached_celery_backend
                 self.running.add(key)
-                self.tasks[key] = result
+                self.workloads[key] = result
 
-                # Store the Celery task_id in the event buffer. This will get 
"overwritten" if the task
+                # Store the Celery task_id (workload execution ID) in the 
event buffer. This will get "overwritten" if the task
                 # has another event, but that is fine, because the only other 
events are success/failed at
-                # which point we don't need the ID anymore anyway
+                # which point we don't need the ID anymore anyway.
                 self.event_buffer[key] = (TaskInstanceState.QUEUED, 
result.task_id)
 
-    def _send_tasks_to_celery(self, task_tuples_to_send: 
Sequence[WorkloadInCelery]):
-        from airflow.providers.celery.executors.celery_executor_utils import 
send_task_to_executor
+    def _send_workloads_to_celery(self, workload_tuples_to_send: 
Sequence[WorkloadInCelery]):
+        from airflow.providers.celery.executors.celery_executor_utils import 
send_workload_to_executor
 
-        if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
+        if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1:
             # One tuple, or max one process -> send it in the main thread.
-            return list(map(send_task_to_executor, task_tuples_to_send))
+            return list(map(send_workload_to_executor, 
workload_tuples_to_send))
 
         # Use chunks instead of a work queue to reduce context switching
-        # since tasks are roughly uniform in size
-        chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
-        num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+        # since workloads are roughly uniform in size.
+        chunksize = 
self._num_workloads_per_send_process(len(workload_tuples_to_send))
+        num_processes = min(len(workload_tuples_to_send), 
self._sync_parallelism)
 
-        # Use ProcessPoolExecutor with team_name instead of task objects to 
avoid pickling issues.
+        # Use ProcessPoolExecutor with team_name instead of workload objects 
to avoid pickling issues.
         # Subprocesses reconstruct the team-specific Celery app from the team 
name and existing config.
         with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
             key_and_async_results = list(
-                send_pool.map(send_task_to_executor, task_tuples_to_send, 
chunksize=chunksize)
+                send_pool.map(send_workload_to_executor, 
workload_tuples_to_send, chunksize=chunksize)
             )
         return key_and_async_results
 
     def sync(self) -> None:
-        if not self.tasks:
-            self.log.debug("No task to query celery, skipping sync")
+        if not self.workloads:
+            self.log.debug("No workload to query celery, skipping sync")
             return
-        self.update_all_task_states()
+        self.update_all_workload_states()
 
     def debug_dump(self) -> None:
         """Debug dump; called in response to SIGUSR2 by the scheduler."""
         super().debug_dump()
         self.log.info(
-            "executor.tasks (%d)\n\t%s", len(self.tasks), 
"\n\t".join(map(repr, self.tasks.items()))
+            "executor.workloads (%d)\n\t%s",
+            len(self.workloads),
+            "\n\t".join(map(repr, self.workloads.items())),
         )
 
-    def update_all_task_states(self) -> None:
-        """Update states of the tasks."""
-        self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
-        state_and_info_by_celery_task_id = 
self.bulk_state_fetcher.get_many(self.tasks.values())
+    def update_all_workload_states(self) -> None:
+        """Update states of the workloads."""
+        self.log.debug("Inquiring about %s celery workload(s)", 
len(self.workloads))
+        state_and_info_by_celery_task_id = 
self.bulk_state_fetcher.get_many(self.workloads.values())
 
         self.log.debug("Inquiries completed.")
-        for key, async_result in list(self.tasks.items()):
+        for key, async_result in list(self.workloads.items()):
             state, info = 
state_and_info_by_celery_task_id.get(async_result.task_id)
             if state:
-                self.update_task_state(key, state, info)
+                self.update_workload_state(key, state, info)
 
     def change_state(
         self, key: TaskInstanceKey, state: TaskInstanceState, info=None, 
remove_running=True
     ) -> None:
         super().change_state(key, state, info, remove_running=remove_running)
-        self.tasks.pop(key, None)
+        self.workloads.pop(key, None)
 
-    def update_task_state(self, key: TaskInstanceKey, state: str, info: Any) 
-> None:
-        """Update state of a single task."""
+    def update_workload_state(self, key: WorkloadKey, state: str, info: Any) 
-> None:

Review Comment:
   Double check me, but I don't believe callbacks support this part (yet?), do 
they?  If not, then maybe we should leave this one for now?  This change makes 
the method look like it properly handles both types when it doesn't.
   



-- 
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.

To unsubscribe, e-mail: [email protected]

For queries about this service, please contact Infrastructure at:
[email protected]

Reply via email to