ferruzzi commented on code in PR #63888:
URL: https://github.com/apache/airflow/pull/63888#discussion_r2997305034
##########
providers/celery/src/airflow/providers/celery/executors/celery_executor.py:
##########
@@ -102,7 +107,7 @@ class CeleryExecutor(BaseExecutor):
if TYPE_CHECKING:
if AIRFLOW_V_3_0_PLUS:
# TODO: TaskSDK: move this type change into BaseExecutor
- queued_tasks: dict[TaskInstanceKey, workloads.All] # type:
ignore[assignment]
+ queued_tasks: dict[WorkloadKey, workloads.All] # type:
ignore[assignment]
Review Comment:
I think there's a mistake here. You are importing WorkloadKey only when the
version is over 3.2 up above, but using it here when the Airflow version is
3.0. What about using this as the import block?
```
# Remove this conditional once min version > 3.2
try:
from airflow.executors.workloads.types import WorkloadKey
except ImportError:
from airflow.models.taskinstancekey import TaskInstanceKey as WorkloadKey
```
I know there's some community debate over using try/except on imports, but I
think this feels like the right time to use one.
((I think the same comment goes for celery_executor_utils.py as well))
##########
providers/celery/src/airflow/providers/celery/executors/celery_executor_utils.py:
##########
@@ -369,27 +369,23 @@ def send_workload_to_executor(
try:
import redis.client # noqa: F401
except ImportError:
- pass # Redis not installed or not using Redis backend
+ pass # Redis not installed or not using Redis backend.
try:
with timeout(seconds=OPERATION_TIMEOUT):
- result = task_to_run.apply_async(args=args, queue=queue)
+ result = celery_task.apply_async(args=args, queue=queue)
except (Exception, AirflowTaskTimeout) as e:
exception_traceback = f"Celery Task ID:
{key}\n{traceback.format_exc()}"
result = ExceptionWithTraceback(e, exception_traceback)
# The type is right for the version, but the type cannot be defined
correctly for Airflow 2 and 3
- # concurrently;
+ # concurrently.
return key, args, result
-# Backward compatibility alias
-send_task_to_executor = send_workload_to_executor
Review Comment:
I think we still need this? If not, then we need to update references to it
such as (but not limited to) celery/provider.yaml
##########
providers/celery/tests/integration/celery/test_celery_executor.py:
##########
@@ -316,27 +316,27 @@ def test_retry_on_error_sending_task(self, caplog):
key = (task.dag.dag_id, task.task_id, ti.run_id, 0, -1)
executor.queued_tasks[key] = workload
- # Test that when heartbeat is called again, task is published
again to Celery Queue
+ # Test that when heartbeat is called again, workload is published
again to Celery Queue.
executor.heartbeat()
- assert dict(executor.task_publish_retries) == {key: 1}
- assert len(executor.queued_tasks) == 1, "Task should remain in
queue"
+ assert dict(executor.workload_publish_retries) == {key: 1}
+ assert len(executor.queued_tasks) == 1, "Workload should remain in
queue"
assert executor.event_buffer == {}
- assert f"[Try 1 of 3] Task Timeout Error for Task: ({key})." in
caplog.text
+ assert f"[Try 1 of 3] Celery Task Timeout Error for Workload:
({key})." in caplog.text
Review Comment:
Here and below, it looks like you changed the expected message in the tests
but didn't actually change the log message in the code?
##########
providers/celery/src/airflow/providers/celery/executors/celery_executor.py:
##########
@@ -136,149 +141,151 @@ def __init__(self, *args, **kwargs):
from airflow.providers.celery.executors.celery_executor_utils import
BulkStateFetcher
self.bulk_state_fetcher = BulkStateFetcher(self._sync_parallelism,
celery_app=self.celery_app)
- self.tasks = {}
- self.task_publish_retries: Counter[TaskInstanceKey] = Counter()
- self.task_publish_max_retries = self.conf.getint("celery",
"task_publish_max_retries", fallback=3)
+ self.workloads: dict[WorkloadKey, AsyncResult] = {}
+ self.workload_publish_retries: Counter[WorkloadKey] = Counter()
+ self.workload_publish_max_retries = self.conf.getint("celery",
"task_publish_max_retries", fallback=3)
def start(self) -> None:
self.log.debug("Starting Celery Executor using %s processes for
syncing", self._sync_parallelism)
- def _num_tasks_per_send_process(self, to_send_count: int) -> int:
+ def _num_workloads_per_send_process(self, to_send_count: int) -> int:
"""
- How many Celery tasks should each worker process send.
+ How many Celery workloads should each worker process send.
- :return: Number of tasks that should be sent per process
+ :return: Number of workloads that should be sent per process
"""
return max(1, math.ceil(to_send_count / self._sync_parallelism))
def _process_tasks(self, task_tuples: Sequence[TaskTuple]) -> None:
- # Airflow V2 version
+ # Airflow V2 compatibility path — converts task tuples into
workload-compatible tuples.
task_tuples_to_send = [task_tuple[:3] + (self.team_name,) for
task_tuple in task_tuples]
- self._send_tasks(task_tuples_to_send)
+ self._send_workloads(task_tuples_to_send)
def _process_workloads(self, workloads: Sequence[workloads.All]) -> None:
- # Airflow V3 version -- have to delay imports until we know we are on
v3
+ # Airflow V3 version -- have to delay imports until we know we are on
v3.
from airflow.executors.workloads import ExecuteTask
if AIRFLOW_V_3_2_PLUS:
from airflow.executors.workloads import ExecuteCallback
- tasks: list[WorkloadInCelery] = []
+ workloads_to_be_sent: list[WorkloadInCelery] = []
for workload in workloads:
if isinstance(workload, ExecuteTask):
- tasks.append((workload.ti.key, workload, workload.ti.queue,
self.team_name))
+ workloads_to_be_sent.append((workload.ti.key, workload,
workload.ti.queue, self.team_name))
elif AIRFLOW_V_3_2_PLUS and isinstance(workload, ExecuteCallback):
- # Use default queue for callbacks, or extract from callback
data if available
+ # Use default queue for callbacks, or extract from callback
data if available.
queue = "default"
if isinstance(workload.callback.data, dict) and "queue" in
workload.callback.data:
queue = workload.callback.data["queue"]
- tasks.append((workload.callback.key, workload, queue,
self.team_name))
+ workloads_to_be_sent.append((workload.callback.key, workload,
queue, self.team_name))
else:
raise ValueError(f"{type(self)}._process_workloads cannot
handle {type(workload)}")
- self._send_tasks(tasks)
+ self._send_workloads(workloads_to_be_sent)
- def _send_tasks(self, task_tuples_to_send: Sequence[WorkloadInCelery]):
+ def _send_workloads(self, workload_tuples_to_send:
Sequence[WorkloadInCelery]):
# Celery state queries will be stuck if we do not use one same backend
- # for all tasks.
+ # for all workloads.
cached_celery_backend = self.celery_app.backend
- key_and_async_results = self._send_tasks_to_celery(task_tuples_to_send)
- self.log.debug("Sent all tasks.")
+ key_and_async_results =
self._send_workloads_to_celery(workload_tuples_to_send)
+ self.log.debug("Sent all workloads.")
from airflow.providers.celery.executors.celery_executor_utils import
ExceptionWithTraceback
for key, _, result in key_and_async_results:
if isinstance(result, ExceptionWithTraceback) and isinstance(
result.exception, AirflowTaskTimeout
):
- retries = self.task_publish_retries[key]
- if retries < self.task_publish_max_retries:
+ retries = self.workload_publish_retries[key]
+ if retries < self.workload_publish_max_retries:
Stats.incr("celery.task_timeout_error")
self.log.info(
- "[Try %s of %s] Task Timeout Error for Task: (%s).",
- self.task_publish_retries[key] + 1,
- self.task_publish_max_retries,
+ "[Try %s of %s] Task Timeout Error for Workload:
(%s).",
+ self.workload_publish_retries[key] + 1,
+ self.workload_publish_max_retries,
tuple(key),
)
- self.task_publish_retries[key] = retries + 1
+ self.workload_publish_retries[key] = retries + 1
continue
if key in self.queued_tasks:
self.queued_tasks.pop(key)
else:
self.queued_callbacks.pop(key, None)
- self.task_publish_retries.pop(key, None)
+ self.workload_publish_retries.pop(key, None)
if isinstance(result, ExceptionWithTraceback):
self.log.error("%s: %s\n%s\n", CELERY_SEND_ERR_MSG_HEADER,
result.exception, result.traceback)
self.event_buffer[key] = (TaskInstanceState.FAILED, None)
elif result is not None:
result.backend = cached_celery_backend
self.running.add(key)
- self.tasks[key] = result
+ self.workloads[key] = result
- # Store the Celery task_id in the event buffer. This will get
"overwritten" if the task
+ # Store the Celery task_id (workload execution ID) in the
event buffer. This will get "overwritten" if the task
# has another event, but that is fine, because the only other
events are success/failed at
- # which point we don't need the ID anymore anyway
+ # which point we don't need the ID anymore anyway.
self.event_buffer[key] = (TaskInstanceState.QUEUED,
result.task_id)
- def _send_tasks_to_celery(self, task_tuples_to_send:
Sequence[WorkloadInCelery]):
- from airflow.providers.celery.executors.celery_executor_utils import
send_task_to_executor
+ def _send_workloads_to_celery(self, workload_tuples_to_send:
Sequence[WorkloadInCelery]):
+ from airflow.providers.celery.executors.celery_executor_utils import
send_workload_to_executor
- if len(task_tuples_to_send) == 1 or self._sync_parallelism == 1:
+ if len(workload_tuples_to_send) == 1 or self._sync_parallelism == 1:
# One tuple, or max one process -> send it in the main thread.
- return list(map(send_task_to_executor, task_tuples_to_send))
+ return list(map(send_workload_to_executor,
workload_tuples_to_send))
# Use chunks instead of a work queue to reduce context switching
- # since tasks are roughly uniform in size
- chunksize = self._num_tasks_per_send_process(len(task_tuples_to_send))
- num_processes = min(len(task_tuples_to_send), self._sync_parallelism)
+ # since workloads are roughly uniform in size.
+ chunksize =
self._num_workloads_per_send_process(len(workload_tuples_to_send))
+ num_processes = min(len(workload_tuples_to_send),
self._sync_parallelism)
- # Use ProcessPoolExecutor with team_name instead of task objects to
avoid pickling issues.
+ # Use ProcessPoolExecutor with team_name instead of workload objects
to avoid pickling issues.
# Subprocesses reconstruct the team-specific Celery app from the team
name and existing config.
with ProcessPoolExecutor(max_workers=num_processes) as send_pool:
key_and_async_results = list(
- send_pool.map(send_task_to_executor, task_tuples_to_send,
chunksize=chunksize)
+ send_pool.map(send_workload_to_executor,
workload_tuples_to_send, chunksize=chunksize)
)
return key_and_async_results
def sync(self) -> None:
- if not self.tasks:
- self.log.debug("No task to query celery, skipping sync")
+ if not self.workloads:
+ self.log.debug("No workload to query celery, skipping sync")
return
- self.update_all_task_states()
+ self.update_all_workload_states()
def debug_dump(self) -> None:
"""Debug dump; called in response to SIGUSR2 by the scheduler."""
super().debug_dump()
self.log.info(
- "executor.tasks (%d)\n\t%s", len(self.tasks),
"\n\t".join(map(repr, self.tasks.items()))
+ "executor.workloads (%d)\n\t%s",
+ len(self.workloads),
+ "\n\t".join(map(repr, self.workloads.items())),
)
- def update_all_task_states(self) -> None:
- """Update states of the tasks."""
- self.log.debug("Inquiring about %s celery task(s)", len(self.tasks))
- state_and_info_by_celery_task_id =
self.bulk_state_fetcher.get_many(self.tasks.values())
+ def update_all_workload_states(self) -> None:
+ """Update states of the workloads."""
+ self.log.debug("Inquiring about %s celery workload(s)",
len(self.workloads))
+ state_and_info_by_celery_task_id =
self.bulk_state_fetcher.get_many(self.workloads.values())
self.log.debug("Inquiries completed.")
- for key, async_result in list(self.tasks.items()):
+ for key, async_result in list(self.workloads.items()):
state, info =
state_and_info_by_celery_task_id.get(async_result.task_id)
if state:
- self.update_task_state(key, state, info)
+ self.update_workload_state(key, state, info)
def change_state(
self, key: TaskInstanceKey, state: TaskInstanceState, info=None,
remove_running=True
) -> None:
super().change_state(key, state, info, remove_running=remove_running)
- self.tasks.pop(key, None)
+ self.workloads.pop(key, None)
- def update_task_state(self, key: TaskInstanceKey, state: str, info: Any)
-> None:
- """Update state of a single task."""
+ def update_workload_state(self, key: WorkloadKey, state: str, info: Any)
-> None:
Review Comment:
Double-check me, but I don't believe callbacks support this part (yet?), do
they? If not, then maybe we should leave this one for now? With this change,
the method looks like it properly handles both types when it doesn't.
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.
To unsubscribe, e-mail: [email protected]
For queries about this service, please contact Infrastructure at:
[email protected]