kaxil commented on code in PR #62922:
URL: https://github.com/apache/airflow/pull/62922#discussion_r2890440469


##########
task-sdk/src/airflow/sdk/bases/iterableoperator.py:
##########
@@ -0,0 +1,432 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from collections import deque
+from collections.abc import Iterable, Sequence
+from concurrent.futures import Future
+from math import ceil
+from time import sleep
+from typing import TYPE_CHECKING, Any
+
+from more_itertools import ichunked
+
+from airflow.exceptions import (

Review Comment:
   `more-itertools` is not a dependency of `task-sdk` — it's only in 
`providers/common/sql/pyproject.toml`. This will be an `ImportError` at runtime.
   
   Either add it to task-sdk's dependencies or replace `ichunked` with 
`itertools.batched` (Python 3.12+) or a simple chunking generator.
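
   A stdlib-only chunking generator could look like this (sketch, not tested against the PR; unlike `ichunked` it materializes each chunk as a list, which is fine here since chunks are bounded by `max_workers`):

   ```python
   from itertools import islice


   def chunked(iterable, size):
       """Yield successive lists of at most ``size`` items from ``iterable``."""
       iterator = iter(iterable)
       while chunk := list(islice(iterator, size)):
           yield chunk
   ```

   The `next(chunked_tasks, [])` call sites in `_run_tasks` would keep working unchanged.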



##########
task-sdk/src/airflow/sdk/definitions/mappedoperator.py:
##########
@@ -336,19 +364,20 @@ def __repr__(self):
         return f"<Mapped({self.task_type}): {self.task_id}>"
 
     def __attrs_post_init__(self):
-        from airflow.sdk.definitions.xcom_arg import XComArg
-
-        if self.get_closest_mapped_task_group() is not None:
-        raise NotImplementedError("operator expansion in an expanded task group is not yet supported")
-
-        if self.task_group:
-            self.task_group.add(self)
-        if self.dag:
-            self.dag.add_task(self)
-        XComArg.apply_upstream_relationship(self, self._get_specified_expand_input().value)
-        for k, v in self.partial_kwargs.items():
-            if k in self.template_fields:
-                XComArg.apply_upstream_relationship(self, v)
+        if self._apply_upstream_relationship:
+            from airflow.sdk.definitions.xcom_arg import XComArg
+
+            if self.get_closest_mapped_task_group() is not None:
+                raise NotImplementedError("operator expansion in an expanded task group is not yet supported")
+
+            if self.task_group:
+                self.task_group.add(self)
+            if self.dag:

Review Comment:
   When `apply_upstream_relationship=False`, this skips *everything* — not just 
upstream relationships, but also task group registration 
(`self.task_group.add(self)`) and DAG registration (`self.dag.add_task(self)`). 
That means `IterableOperator` creates `MappedOperator` instances that are 
invisible to the DAG, which breaks serialization, the UI, and dependency 
tracking.
   
   Should we separate these concerns? The flag name suggests it only controls 
upstream relationships, but it actually controls all post-init behavior.



##########
task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py:
##########
@@ -184,7 +228,22 @@ def iter_references(self) -> Iterable[tuple[Operator, str]]:
             if isinstance(x, XComArg):
                 yield from x.iter_references()
 
-    def resolve(self, context: Mapping[str, Any]) -> tuple[Mapping[str, Any], set[int]]:
+    def iter_values(self, context: Context) -> Iterable[Any]:
+        def resolve(value: Any) -> Any:
+            if isinstance(value, XComArg):
+                return value.iter_values(context=context)
+            return value
+

Review Comment:
   This `iter_values` doesn't produce the right output for `expand()` semantics.
   
   For `expand(a=[1, 2], b=[3, 4])`, this yields `{"a": 1}, {"a": 2}, {"b": 3}, 
{"b": 4}` — 4 separate single-key dicts. But dynamic task mapping produces the 
cartesian product: `{"a": 1, "b": 3}, {"a": 1, "b": 4}, {"a": 2, "b": 3}, {"a": 
2, "b": 4}` — 4 dicts with all keys combined.
   
   Tasks receiving only `{"a": 1}` without `b` would fail with missing 
arguments.
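
   For reference, the cartesian-product semantics can be sketched with `itertools.product` (hypothetical helper, not the PR's code):

   ```python
   from itertools import product


   def iter_expand_kwargs(expand_kwargs):
       """Yield one kwargs dict per mapped task instance: the cartesian
       product over all expanded argument lists, matching expand() semantics."""
       keys = list(expand_kwargs)
       for values in product(*(expand_kwargs[key] for key in keys)):
           yield dict(zip(keys, values))
   ```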



##########
task-sdk/src/airflow/sdk/definitions/_internal/expandinput.py:
##########
@@ -79,8 +79,45 @@ def _needs_run_time_resolution(v: OperatorExpandArgument) -> TypeGuard[MappedArg
     return isinstance(v, (MappedArgument, XComArg))
 
 
+class ExpandInput(ABC, ResolveMixin):

Review Comment:
   Changing `ExpandInput` from a Union type alias to an ABC is a significant 
refactor. A couple issues:
   
   1. `DecoratedExpandInput` and `MappedArgument` inherit from `ExpandInput` 
but don't set `EXPAND_INPUT_TYPE`. The serializer in 
`airflow-core/src/airflow/serialization/encoders.py` (`encode_expand_input`) 
accesses `var.EXPAND_INPUT_TYPE` — this will crash with `AttributeError`.
   
   2. `MappedArgument` previously inherited from `ResolveMixin`, now from 
`ExpandInput`. Is it really an expand input? It's a stand-in for 
task-group-mapping arguments — making it an `ExpandInput` subclass seems like 
it conflates two different concepts.
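
   If the ABC stays, one way to catch a missing `EXPAND_INPUT_TYPE` at class-definition time rather than deep inside the serializer (sketch; the guard and class names here are hypothetical, not the PR's code):

   ```python
   class ExpandInput:
       """Sketch of an __init_subclass__ guard for a required class attribute."""

       EXPAND_INPUT_TYPE: str  # declared for typing, intentionally not assigned

       def __init_subclass__(cls, **kwargs):
           super().__init_subclass__(**kwargs)
           # Any class in the MRO actually assigning the attribute satisfies this.
           defined = any("EXPAND_INPUT_TYPE" in vars(klass) for klass in cls.__mro__)
           if not defined:
               raise TypeError(f"{cls.__name__} must define EXPAND_INPUT_TYPE")
   ```

   That turns the `AttributeError` in `encode_expand_input` into an error at import time of the offending subclass.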



##########
task-sdk/src/airflow/sdk/execution_time/lazy_sequence.py:
##########
@@ -166,6 +166,60 @@ def __getitem__(self, key: int | slice) -> T | Sequence[T]:
         return XCom.deserialize_value(_XComWrapper(msg.root))
 
 
[email protected]
+class XComIterable(Sequence):
+    """An iterable that lazily fetches XCom values one by one instead of 
loading all at once."""

Review Comment:
   `XComIterable` implements both `Sequence` (via inheritance) and `Iterator` 
(via `__iter__` returning `self` + `__next__`). A `Sequence.__iter__` should 
return a fresh iterator on each call, not reuse the object itself — otherwise 
you can't iterate twice concurrently, and multiple consumers sharing the same 
instance will interfere with each other.
   
   Consider having `__iter__` return a separate iterator object, similar to how 
`LazyXComSequence` uses `LazyXComIterator`.
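
   A minimal standalone illustration of the pattern, with the XCom fetch stubbed out:

   ```python
   from collections.abc import Iterator, Sequence


   class LazyFetchSequence(Sequence):
       """A Sequence whose __iter__ returns an independent iterator per call,
       so two consumers never share cursor state."""

       def __init__(self, length: int):
           self._length = length

       def __len__(self) -> int:
           return self._length

       def __getitem__(self, index: int):
           if not 0 <= index < self._length:
               raise IndexError(index)
           return self._fetch(index)

       def _fetch(self, index: int):
           # Stand-in for fetching XCom value number `index` from the API server.
           return index * 10

       def __iter__(self) -> Iterator:
           return (self._fetch(i) for i in range(self._length))
   ```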



##########
task-sdk/src/airflow/sdk/bases/iterableoperator.py:
##########
@@ -0,0 +1,432 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from collections import deque
+from collections.abc import Iterable, Sequence
+from concurrent.futures import Future
+from math import ceil
+from time import sleep
+from typing import TYPE_CHECKING, Any
+
+from more_itertools import ichunked
+
+from airflow.exceptions import (
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk import timezone
+from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException
+from airflow.sdk.execution_time.executor import HybridExecutor, _execute_async_task, collect_futures
+from airflow.sdk.execution_time.lazy_sequence import XComIterable
+from airflow.sdk.execution_time.task_runner import MappedTaskInstance, RuntimeTaskInstance, _execute_task
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.definitions.mappedoperator import MappedOperator
+
+
+class TaskExecutor(LoggingMixin):
+    """Base class to run an operator or trigger with given task context and 
task instance."""
+
+    def __init__(
+        self,
+        task_instance: RuntimeTaskInstance,
+    ):
+        super().__init__()
+        self._task_instance = task_instance
+        self._result: Any | None = None
+        self._start_time: float | None = None
+
+    @property
+    def task_instance(self) -> RuntimeTaskInstance:
+        return self._task_instance
+
+    @property
+    def dag_id(self) -> str:
+        return self.task_instance.dag_id
+
+    @property
+    def task_id(self) -> str:
+        return self.task_instance.task_id
+
+    @property
+    def task_index(self) -> int:
+        return self.task_instance.map_index
+
+    @property
+    def key(self):
+        return self.task_instance.xcom_key
+
+    @property
+    def operator(self) -> BaseOperator:
+        return self.task_instance.task
+
+    @property
+    def is_async(self) -> bool:
+        return self.task_instance.is_async
+
+    def run(self, context: Context):
+        return _execute_task(context, self.task_instance, self.log)
+
+    async def arun(self, context: Context):
+        return await _execute_async_task(context, self.task_instance, self.log)
+
+    def __enter__(self):
+        self._start_time = time.monotonic()
+
+        if self.log.isEnabledFor(logging.INFO):
+            self.log.info(
+                "Attempting running task %s of %s for %s with map_index %s in 
%s mode.",
+                self.task_instance.try_number,
+                self.operator.retries,
+                self.task_instance.task_id,
+                self.task_index,
+                "async" if self.is_async else "sync",
+            )
+        return self
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):

Review Comment:
   `self._start_time` is initialized to `None` and only set in `__enter__`. If 
`__exit__` gets called without a matching `__enter__` (edge case, but 
possible), `time.monotonic() - self._start_time` will raise a `TypeError`.



##########
task-sdk/src/airflow/sdk/bases/operator.py:
##########
@@ -1657,7 +1662,14 @@ def resume_execution(self, next_method: str, next_kwargs: dict[str, Any] | None,
                 raise TaskDeferralTimeout(error)
             raise TaskDeferralError(error)
         # Grab the callable off the Operator/Task and add in any kwargs
-        execute_callable = getattr(self, next_method)
+        return getattr(self, next_method)
+

Review Comment:
   `next_callable` takes 3 positional args after `self`: `next_method`, 
`next_kwargs`, `context`. But `DecoratedDeferredAsyncOperator.aexecute` calls 
it with only 2:
   
   ```python
   next_method = self._operator.next_callable(
       self._task_deferred.method_name,
       self._task_deferred.kwargs,
   )
   ```
   
   That will raise `TypeError` for the missing `context` argument.
   
   Also — the original `resume_execution` handled `next_kwargs is None` 
defaulting before the `__fail__` check. Now `next_callable` receives the raw 
kwargs without that defaulting, so if `DecoratedDeferredAsyncOperator` passes 
`None` kwargs, `next_kwargs.get("traceback")` will `AttributeError`.
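
   Both fixes can be sketched in one place (toy class standing in for the real `BaseOperator` method): default `next_kwargs` before any access, and give `context` a default so the existing two-argument call site keeps working.

   ```python
   class ResumableOperator:
       """Toy model of resume_execution's callable lookup."""

       def next_callable(self, next_method: str, next_kwargs=None, context=None):
           next_kwargs = next_kwargs or {}  # default before any .get() access
           if next_method == "__fail__":
               # Safe even when the caller passed next_kwargs=None.
               raise RuntimeError(next_kwargs.get("error", "task failed"))
           return getattr(self, next_method)

       def execute_complete(self, **kwargs):
           return "resumed"
   ```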



##########
task-sdk/src/airflow/sdk/bases/iterableoperator.py:
##########
@@ -0,0 +1,432 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from collections import deque
+from collections.abc import Iterable, Sequence
+from concurrent.futures import Future
+from math import ceil
+from time import sleep
+from typing import TYPE_CHECKING, Any
+
+from more_itertools import ichunked
+
+from airflow.exceptions import (
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk import timezone
+from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException
+from airflow.sdk.execution_time.executor import HybridExecutor, _execute_async_task, collect_futures
+from airflow.sdk.execution_time.lazy_sequence import XComIterable
+from airflow.sdk.execution_time.task_runner import MappedTaskInstance, RuntimeTaskInstance, _execute_task
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.definitions.mappedoperator import MappedOperator
+
+
+class TaskExecutor(LoggingMixin):
+    """Base class to run an operator or trigger with given task context and 
task instance."""
+
+    def __init__(
+        self,
+        task_instance: RuntimeTaskInstance,
+    ):
+        super().__init__()
+        self._task_instance = task_instance
+        self._result: Any | None = None
+        self._start_time: float | None = None
+
+    @property
+    def task_instance(self) -> RuntimeTaskInstance:
+        return self._task_instance
+
+    @property
+    def dag_id(self) -> str:
+        return self.task_instance.dag_id
+
+    @property
+    def task_id(self) -> str:
+        return self.task_instance.task_id
+
+    @property
+    def task_index(self) -> int:
+        return self.task_instance.map_index
+
+    @property
+    def key(self):
+        return self.task_instance.xcom_key
+
+    @property
+    def operator(self) -> BaseOperator:
+        return self.task_instance.task
+
+    @property
+    def is_async(self) -> bool:
+        return self.task_instance.is_async
+
+    def run(self, context: Context):
+        return _execute_task(context, self.task_instance, self.log)
+
+    async def arun(self, context: Context):
+        return await _execute_async_task(context, self.task_instance, self.log)
+
+    def __enter__(self):
+        self._start_time = time.monotonic()
+
+        if self.log.isEnabledFor(logging.INFO):
+            self.log.info(
+                "Attempting running task %s of %s for %s with map_index %s in 
%s mode.",
+                self.task_instance.try_number,
+                self.operator.retries,
+                self.task_instance.task_id,
+                self.task_index,
+                "async" if self.is_async else "sync",
+            )
+        return self
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        elapsed = time.monotonic() - self._start_time
+
+        if exc_value:
+            if not isinstance(exc_value, TaskDeferred):
+                if self.task_instance.next_try_number > self.task_instance.max_tries:
+                    self.log.error(
+                        "Task instance %s for %s failed after %s attempts in 
%.2f seconds due to: %s",
+                        self.task_index,
+                        self.task_instance.task_id,
+                        self.task_instance.max_tries,
+                        elapsed,
+                        exc_value,
+                    )
+                    self.task_instance.state = TaskInstanceState.FAILED
+                    raise exc_value
+                self.task_instance.try_number += 1
+                self.task_instance.end_date = timezone.utcnow()
+                self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE
+                raise AirflowRescheduleTaskInstanceException(task=self.task_instance)
+            raise exc_value
+
+        self.task_instance.state = TaskInstanceState.SUCCESS
+        if self.log.isEnabledFor(logging.INFO):
+            self.log.info(
+                "Task instance %s for %s finished successfully in %s attempts 
in %.2f seconds",
+                self.task_index,
+                self.task_instance.task_id,
+                self.task_instance.next_try_number,
+                elapsed,
+            )
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit__(exc_type, exc_value, traceback)
+
+
+class IterableOperator(BaseOperator):
+    """Object representing an iterable operator in a DAG."""
+
+    _operator: MappedOperator
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+    # each operator should override this class attr for shallow copy attrs.
+    shallow_copy_attrs: Sequence[str] = (
+        "_operator",
+        "expand_input",
+        "partial_kwargs",
+        "_log",
+    )
+
+    def __init__(
+        self,
+        *,
+        operator: MappedOperator,
+        expand_input: ExpandInput,
+        **kwargs,
+    ):
+        super().__init__(
+            **{
+                **kwargs,
+                "task_id": operator.task_id,
+                "owner": operator.owner,
+                "email": operator.email,
+                "email_on_retry": operator.email_on_retry,
+                "email_on_failure": operator.email_on_failure,
+                "retries": 0,  # We should not retry the IterableOperator, 
only the mapped ti's should be retried
+                "retry_delay": operator.retry_delay,
+                "retry_exponential_backoff": 
operator.retry_exponential_backoff,
+                "max_retry_delay": operator.max_retry_delay,
+                "start_date": operator.start_date,
+                "end_date": operator.end_date,
+                "depends_on_past": operator.depends_on_past,
+                "ignore_first_depends_on_past": 
operator.ignore_first_depends_on_past,
+                "wait_for_past_depends_before_skipping": 
operator.wait_for_past_depends_before_skipping,
+                "wait_for_downstream": operator.wait_for_downstream,
+                "dag": operator.dag,
+                "priority_weight": operator.priority_weight,
+                "queue": operator.queue,
+                "pool": operator.pool,
+                "pool_slots": operator.pool_slots,
+                "execution_timeout": None,
+                "trigger_rule": operator.trigger_rule,
+                "resources": operator.resources,
+                "run_as_user": operator.run_as_user,
+                "map_index_template": operator.map_index_template,
+                "max_active_tis_per_dag": operator.max_active_tis_per_dag,
+                "max_active_tis_per_dagrun": 
operator.max_active_tis_per_dagrun,
+                "executor": operator.executor,
+                "executor_config": operator.executor_config,
+                "inlets": operator.inlets,
+                "outlets": operator.outlets,
+                "task_group": operator.task_group,
+                "doc": operator.doc,
+                "doc_md": operator.doc_md,
+                "doc_json": operator.doc_json,
+                "doc_yaml": operator.doc_yaml,
+                "doc_rst": operator.doc_rst,
+                "task_display_name": operator.task_display_name,
+                "allow_nested_operators": operator.allow_nested_operators,
+            }
+        )
+        self._operator = operator
+        self.expand_input = expand_input
+        self.partial_kwargs = operator.partial_kwargs or {}
+        self._number_of_tasks: int = 0
+        XComArg.apply_upstream_relationship(self, self.expand_input.value)
+
+    @property
+    def task_type(self) -> str:
+        """@property: type of the task."""
+        return self._operator.__class__.__name__
+
+    @property
+    def max_workers(self):
+        return self.max_active_tis_per_dag or os.cpu_count() or 1
+
+    @property
+    def timeout(self) -> float | None:
+        if self.execution_timeout:
+            return self.execution_timeout.total_seconds()
+        return None
+
+    def _get_specified_expand_input(self) -> ExpandInput:
+        return self.expand_input
+
+    def _unmap_operator(self, mapped_kwargs: dict):
+        self._number_of_tasks += 1
+        return self._operator.unmap(mapped_kwargs)
+
+    def _xcom_push(self, context: Context, task: RuntimeTaskInstance, value: Any) -> None:
+        self.log.debug("Pushing XCom %s", task.map_index)
+
+        context["ti"].xcom_push(key=task.xcom_key, value=value)
+
+    def _run_tasks(
+        self,
+        context: Context,
+        tasks: Iterable[RuntimeTaskInstance],
+    ) -> None:
+        exception: BaseException | None = None
+        reschedule_date = timezone.utcnow()
+        prev_futures_count = 0
+        futures: dict[Future, RuntimeTaskInstance] = {}
+        failed_tasks: deque[RuntimeTaskInstance] = deque()
+        chunked_tasks = ichunked(tasks, self.max_workers)
+
+        self.log.info("Running tasks with %d workers", self.max_workers)
+
+        with event_loop() as loop:
+            with HybridExecutor(loop=loop, max_workers=self.max_workers) as executor:
+                for task in next(chunked_tasks, []):
+                    if task.is_async:
+                        future = executor.submit(self._run_async_operator, context, task)
+                    else:
+                        future = executor.submit(self._run_operator, context, task)
+                    futures[future] = task
+
+                while futures:
+                    futures_count = len(futures)
+
+                    if futures_count != prev_futures_count:
+                        self.log.info("Number of remaining futures: %s", 
futures_count)
+                        prev_futures_count = futures_count
+
+                    ready_futures = False
+
+                    for future in collect_futures(loop, futures.keys()):
+                        task = futures.pop(future)
+                        ready_futures = True
+
+                        try:
+                            if isinstance(future, asyncio.futures.Future):
+                                result = future.result()
+                            else:
+                                result = future.result(timeout=self.timeout)
+
+                            self.log.debug("result: %s", result)
+
+                            if result and task.task.do_xcom_push:
+                                self._xcom_push(
+                                    context=context,
+                                    task=task,
+                                    value=result,
+                                )
+                        except TaskDeferred as task_deferred:
+                            operator = DecoratedDeferredAsyncOperator(
+                                operator=task.task, task_deferred=task_deferred
+                            )
+                            failed_tasks.append(
+                                self._create_mapped_task(task.run_id, task.map_index, operator)
+                            )
+                        except asyncio.TimeoutError as e:
+                            self.log.warning("A timeout occurred for task_id 
%s", task.task_id)
+                            if task.next_try_number > self.retries:
+                                exception = AirflowTaskTimeout(e)
+                            else:
+                                reschedule_date = min(reschedule_date, task.next_retry_datetime())
+                                failed_tasks.append(task)
+                        except AirflowRescheduleTaskInstanceException as e:
+                            reschedule_date = min(reschedule_date, e.reschedule_date)
+                            self.log.warning(
+                                "An exception occurred for task_id %s with 
map_index %s, it has been rescheduled at %s",
+                                task.task_id,
+                                task.map_index,
+                                reschedule_date,
+                            )
+                            failed_tasks.append(e.task)
+                        except Exception as e:
+                            self.log.error(
+                                "An exception occurred for task_id %s with 
map_index %s",
+                                task.task_id,
+                                task.map_index,
+                            )
+                            exception = e
+
+                    if len(futures) < self.max_workers:
+                        for task in next(chunked_tasks, []):
+                            if task.is_async:
+                                future = executor.submit(self._run_async_operator, context, task)
+                            else:
+                                future = executor.submit(self._run_operator, context, task)
+                            futures[future] = task
+                    elif not ready_futures and futures:
+                        sleep(len(futures) * 0.1)
+
+        if not failed_tasks:
+            if exception:
+                raise exception
+            if self.do_xcom_push:
+                return XComIterable(
+                    task_id=self.task_id,
+                    dag_id=self.dag_id,
+                    run_id=context["run_id"],
+                    length=self._number_of_tasks,
+                )
+
+        now = timezone.utcnow()
+
+        # Calculate delay before the next retry
+        if reschedule_date > now:
+            delay_seconds = ceil((reschedule_date - now).total_seconds())
+
+            self.log.info(
+                "Attempting to run %s failed tasks within %s seconds...",
+                len(failed_tasks),
+                delay_seconds,
+            )
+
+            sleep(delay_seconds)
+
+        return self._run_tasks(context=context, tasks=list(failed_tasks))
+
+    def _run_operator(self, context: Context, task_instance: RuntimeTaskInstance):
+        with TaskExecutor(task_instance=task_instance) as executor:
+            return executor.run(
+                context={

Review Comment:
   `_run_operator` creates a new dict merging `context` with task-specific 
overrides, but the base `context` dict is shared across all threads. If any 
task modifies values in `context` during execution (which tasks commonly do), 
that's a race condition.
   
   Same issue with `_run_async_operator` below. Probably need to deep-copy 
`context` per task.
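
   The isolation could be sketched like this, assuming the context values are deep-copyable (hypothetical helper; real `Context` objects may hold handles like loggers or connections that need special-casing):

   ```python
   import copy


   def make_task_context(base_context: dict, **overrides) -> dict:
       """Give each mapped task its own deep copy of the shared context,
       then layer the per-task overrides (map_index, ti, ...) on top."""
       task_context = copy.deepcopy(base_context)
       task_context.update(overrides)
       return task_context
   ```

   Each `executor.submit` call would then pass its own `make_task_context(context, ...)` result instead of a dict built over the shared `context`.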



##########
task-sdk/src/airflow/sdk/bases/iterableoperator.py:
##########
@@ -0,0 +1,432 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from collections import deque
+from collections.abc import Iterable, Sequence
+from concurrent.futures import Future
+from math import ceil
+from time import sleep
+from typing import TYPE_CHECKING, Any
+
+from more_itertools import ichunked
+
+from airflow.exceptions import (
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk import timezone
+from airflow.sdk.bases.operator import BaseOperator, DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException
+from airflow.sdk.execution_time.executor import HybridExecutor, _execute_async_task, collect_futures
+from airflow.sdk.execution_time.lazy_sequence import XComIterable
+from airflow.sdk.execution_time.task_runner import MappedTaskInstance, RuntimeTaskInstance, _execute_task
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.definitions.mappedoperator import MappedOperator
+
+
+class TaskExecutor(LoggingMixin):
+    """Base class to run an operator or trigger with given task context and 
task instance."""
+
+    def __init__(
+        self,
+        task_instance: RuntimeTaskInstance,
+    ):
+        super().__init__()
+        self._task_instance = task_instance
+        self._result: Any | None = None
+        self._start_time: float | None = None
+
+    @property
+    def task_instance(self) -> RuntimeTaskInstance:
+        return self._task_instance
+
+    @property
+    def dag_id(self) -> str:
+        return self.task_instance.dag_id
+
+    @property
+    def task_id(self) -> str:
+        return self.task_instance.task_id
+
+    @property
+    def task_index(self) -> int:
+        return self.task_instance.map_index
+
+    @property
+    def key(self):
+        return self.task_instance.xcom_key
+
+    @property
+    def operator(self) -> BaseOperator:
+        return self.task_instance.task
+
+    @property
+    def is_async(self) -> bool:
+        return self.task_instance.is_async
+
+    def run(self, context: Context):
+        return _execute_task(context, self.task_instance, self.log)
+
+    async def arun(self, context: Context):
+        return await _execute_async_task(context, self.task_instance, self.log)
+
+    def __enter__(self):
+        self._start_time = time.monotonic()
+
+        if self.log.isEnabledFor(logging.INFO):
+            self.log.info(
+                "Attempting running task %s of %s for %s with map_index %s in 
%s mode.",
+                self.task_instance.try_number,
+                self.operator.retries,
+                self.task_instance.task_id,
+                self.task_index,
+                "async" if self.is_async else "sync",
+            )
+        return self
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        elapsed = time.monotonic() - self._start_time
+
+        if exc_value:
+            if not isinstance(exc_value, TaskDeferred):
+                if self.task_instance.next_try_number > self.task_instance.max_tries:
+                    self.log.error(
+                        "Task instance %s for %s failed after %s attempts in 
%.2f seconds due to: %s",
+                        self.task_index,
+                        self.task_instance.task_id,
+                        self.task_instance.max_tries,
+                        elapsed,
+                        exc_value,
+                    )
+                    self.task_instance.state = TaskInstanceState.FAILED
+                    raise exc_value
+                self.task_instance.try_number += 1
+                self.task_instance.end_date = timezone.utcnow()
+                self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE
+                raise AirflowRescheduleTaskInstanceException(task=self.task_instance)
+            raise exc_value
+
+        self.task_instance.state = TaskInstanceState.SUCCESS
+        if self.log.isEnabledFor(logging.INFO):
+            self.log.info(
+                "Task instance %s for %s finished successfully in %s attempts 
in %.2f seconds",
+                self.task_index,
+                self.task_instance.task_id,
+                self.task_instance.next_try_number,
+                elapsed,
+            )
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit__(exc_type, exc_value, traceback)
+
+
+class IterableOperator(BaseOperator):
+    """Object representing an iterable operator in a DAG."""
+
+    _operator: MappedOperator
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+    # each operator should override this class attr for shallow copy attrs.
+    shallow_copy_attrs: Sequence[str] = (
+        "_operator",
+        "expand_input",
+        "partial_kwargs",
+        "_log",
+    )
+
+    def __init__(
+        self,
+        *,
+        operator: MappedOperator,
+        expand_input: ExpandInput,
+        **kwargs,
+    ):
+        super().__init__(
+            **{
+                **kwargs,
+                "task_id": operator.task_id,
+                "owner": operator.owner,
+                "email": operator.email,
+                "email_on_retry": operator.email_on_retry,
+                "email_on_failure": operator.email_on_failure,
+                "retries": 0,  # We should not retry the IterableOperator, 
only the mapped ti's should be retried
+                "retry_delay": operator.retry_delay,
+                "retry_exponential_backoff": 
operator.retry_exponential_backoff,
+                "max_retry_delay": operator.max_retry_delay,
+                "start_date": operator.start_date,
+                "end_date": operator.end_date,
+                "depends_on_past": operator.depends_on_past,
+                "ignore_first_depends_on_past": 
operator.ignore_first_depends_on_past,
+                "wait_for_past_depends_before_skipping": 
operator.wait_for_past_depends_before_skipping,
+                "wait_for_downstream": operator.wait_for_downstream,
+                "dag": operator.dag,
+                "priority_weight": operator.priority_weight,
+                "queue": operator.queue,
+                "pool": operator.pool,
+                "pool_slots": operator.pool_slots,
+                "execution_timeout": None,
+                "trigger_rule": operator.trigger_rule,
+                "resources": operator.resources,
+                "run_as_user": operator.run_as_user,
+                "map_index_template": operator.map_index_template,
+                "max_active_tis_per_dag": operator.max_active_tis_per_dag,
+                "max_active_tis_per_dagrun": 
operator.max_active_tis_per_dagrun,
+                "executor": operator.executor,
+                "executor_config": operator.executor_config,
+                "inlets": operator.inlets,
+                "outlets": operator.outlets,
+                "task_group": operator.task_group,
+                "doc": operator.doc,
+                "doc_md": operator.doc_md,
+                "doc_json": operator.doc_json,
+                "doc_yaml": operator.doc_yaml,
+                "doc_rst": operator.doc_rst,
+                "task_display_name": operator.task_display_name,
+                "allow_nested_operators": operator.allow_nested_operators,
+            }
+        )
+        self._operator = operator
+        self.expand_input = expand_input
+        self.partial_kwargs = operator.partial_kwargs or {}
+        self._number_of_tasks: int = 0
+        XComArg.apply_upstream_relationship(self, self.expand_input.value)
+
+    @property
+    def task_type(self) -> str:
+        """@property: type of the task."""
+        return self._operator.__class__.__name__
+
+    @property
+    def max_workers(self):
+        return self.max_active_tis_per_dag or os.cpu_count() or 1
+
+    @property
+    def timeout(self) -> float | None:
+        if self.execution_timeout:
+            return self.execution_timeout.total_seconds()
+        return None
+
+    def _get_specified_expand_input(self) -> ExpandInput:
+        return self.expand_input
+
+    def _unmap_operator(self, mapped_kwargs: dict):
+        self._number_of_tasks += 1
+        return self._operator.unmap(mapped_kwargs)
+
+    def _xcom_push(self, context: Context, task: RuntimeTaskInstance, value: Any) -> None:
+        self.log.debug("Pushing XCom %s", task.map_index)
+
+        context["ti"].xcom_push(key=task.xcom_key, value=value)
+
+    def _run_tasks(
+        self,
+        context: Context,
+        tasks: Iterable[RuntimeTaskInstance],
+    ) -> None:
+        exception: BaseException | None = None
+        reschedule_date = timezone.utcnow()
+        prev_futures_count = 0
+        futures: dict[Future, RuntimeTaskInstance] = {}
+        failed_tasks: deque[RuntimeTaskInstance] = deque()
+        chunked_tasks = ichunked(tasks, self.max_workers)
+
+        self.log.info("Running tasks with %d workers", self.max_workers)
+
+        with event_loop() as loop:
+            with HybridExecutor(loop=loop, max_workers=self.max_workers) as executor:
+                for task in next(chunked_tasks, []):
+                    if task.is_async:
+                        future = executor.submit(self._run_async_operator, context, task)
+                    else:
+                        future = executor.submit(self._run_operator, context, task)
+                    futures[future] = task
+
+                while futures:
+                    futures_count = len(futures)
+
+                    if futures_count != prev_futures_count:
+                        self.log.info("Number of remaining futures: %s", 
futures_count)
+                        prev_futures_count = futures_count
+
+                    ready_futures = False
+
+                    for future in collect_futures(loop, futures.keys()):
+                        task = futures.pop(future)
+                        ready_futures = True
+
+                        try:
+                            if isinstance(future, asyncio.futures.Future):
+                                result = future.result()
+                            else:
+                                result = future.result(timeout=self.timeout)
+
+                            self.log.debug("result: %s", result)
+
+                            if result and task.task.do_xcom_push:
+                                self._xcom_push(
+                                    context=context,
+                                    task=task,
+                                    value=result,
+                                )
+                        except TaskDeferred as task_deferred:
+                            operator = DecoratedDeferredAsyncOperator(
+                                operator=task.task, task_deferred=task_deferred
+                            )
+                            failed_tasks.append(
+                                self._create_mapped_task(task.run_id, task.map_index, operator)
+                            )
+                        except asyncio.TimeoutError as e:
+                            self.log.warning("A timeout occurred for task_id 
%s", task.task_id)
+                            if task.next_try_number > self.retries:
+                                exception = AirflowTaskTimeout(e)
+                            else:
+                                reschedule_date = min(reschedule_date, task.next_retry_datetime())
+                                failed_tasks.append(task)
+                        except AirflowRescheduleTaskInstanceException as e:
+                            reschedule_date = min(reschedule_date, e.reschedule_date)
+                            self.log.warning(
+                                "An exception occurred for task_id %s with map_index %s, it has been rescheduled at %s",
+                                task.task_id,
+                                task.map_index,
+                                reschedule_date,
+                            )
+                            failed_tasks.append(e.task)
+                        except Exception as e:
+                            self.log.error(
+                                "An exception occurred for task_id %s with 
map_index %s",
+                                task.task_id,
+                                task.map_index,
+                            )
+                            exception = e
+
+                    if len(futures) < self.max_workers:
+                        for task in next(chunked_tasks, []):
+                            if task.is_async:
+                                future = executor.submit(self._run_async_operator, context, task)
+                            else:
+                                future = executor.submit(self._run_operator, context, task)
+                            futures[future] = task
+                    elif not ready_futures and futures:
+                        sleep(len(futures) * 0.1)
+
+        if not failed_tasks:
+            if exception:
+                raise exception
+            if self.do_xcom_push:
+                return XComIterable(
+                    task_id=self.task_id,
+                    dag_id=self.dag_id,
+                    run_id=context["run_id"],
+                    length=self._number_of_tasks,
+                )
+
+        now = timezone.utcnow()
+
+        # Calculate delay before the next retry
+        if reschedule_date > now:
+            delay_seconds = ceil((reschedule_date - now).total_seconds())
+
+            self.log.info(
+                "Attempting to run %s failed tasks within %s seconds...",
+                len(failed_tasks),
+                delay_seconds,
+            )
+
+            sleep(delay_seconds)
+

Review Comment:
   `_run_tasks` calls itself recursively for retries:
   
   ```python
   return self._run_tasks(context=context, tasks=list(failed_tasks))
   ```
   
   If tasks keep failing (e.g., a sensor-like pattern with many retries), this 
will hit Python's recursion limit. Should be a loop instead.
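   As a sketch of the loop shape (illustrative only — `run_batch` here stands in for the body of `_run_tasks`, it is not a function in this PR):
   
   ```python
   def run_with_retries(tasks, run_batch):
       """Loop-based retry driver: constant stack depth however many
       rounds the tasks need.

       `run_batch` runs one round and returns the tasks that must be
       retried; tasks that exhaust max_tries should raise inside
       run_batch rather than be returned, which bounds the rounds.
       """
       pending = list(tasks)
       rounds = 0
       while pending:
           pending = run_batch(pending)
           rounds += 1
       return rounds
   ```
   
   Unlike the recursive version, this survives arbitrarily many retry rounds without touching `sys.setrecursionlimit`.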



##########
task-sdk/src/airflow/sdk/execution_time/executor.py:
##########
@@ -0,0 +1,141 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import contextvars
+import inspect
+import os
+from asyncio import AbstractEventLoop, Semaphore
+from collections.abc import Callable
+from concurrent.futures import Future, ThreadPoolExecutor, as_completed
+from logging import Logger
+from typing import TYPE_CHECKING, Any
+
+from airflow.sdk import BaseAsyncOperator, BaseOperator
+from airflow.sdk.bases.operator import ExecutorSafeguard
+from airflow.sdk.execution_time.callback_runner import create_executable_runner
+from airflow.sdk.execution_time.context import (
+    context_get_outlet_events,
+    context_to_airflow_vars,
+)
+from airflow.sdk.execution_time.task_runner import (
+    RuntimeTaskInstance,
+    _run_task_state_change_callbacks,
+)
+
+if TYPE_CHECKING:
+    from airflow.sdk import Context
+
+
+def collect_futures(loop: AbstractEventLoop, futures: list[Any]):
+    """Yield futures as they complete (sync or async)."""
+    yield from as_completed(f for f in futures if isinstance(f, Future))
+
+    async_tasks = [f for f in futures if isinstance(f, asyncio.Task)]
+
+    if async_tasks:
+        for task, _ in zip(
+            async_tasks,
+            loop.run_until_complete(asyncio.gather(*async_tasks, return_exceptions=True)),
+        ):
+            yield task
+
+    return []
+

Review Comment:
   `os.environ.update(airflow_context_vars)` — `os.environ` is process-global. 
Multiple tasks running concurrently in threads via `HybridExecutor` will 
clobber each other's environment variables.
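   One thread-safe alternative (a sketch — these helper names are illustrative, not from this PR) is to keep the per-task vars in a `contextvars.ContextVar`, which is isolated per thread/coroutine, and only merge with `os.environ` at the point where an environment mapping is actually consumed (e.g. `subprocess.run(env=...)`):
   
   ```python
   import contextvars
   import os

   # Each task records its context vars here; ContextVar values are
   # per-thread/per-task, so concurrent tasks no longer mutate the
   # process-global os.environ.
   _airflow_vars = contextvars.ContextVar("airflow_context_vars", default={})

   def set_task_env(vars_):
       """Record this task's context vars in the current context."""
       _airflow_vars.set(dict(vars_))

   def effective_env():
       """Merge only where an env dict is needed, e.g. subprocess.run(env=...)."""
       return {**os.environ, **_airflow_vars.get()}
   ```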



##########
task-sdk/src/airflow/sdk/bases/iterableoperator.py:
##########
@@ -0,0 +1,432 @@
+                    if len(futures) < self.max_workers:
+                        for task in next(chunked_tasks, []):
+                            if task.is_async:
+                                future = executor.submit(self._run_async_operator, context, task)
+                            else:
+                                future = executor.submit(self._run_operator, context, task)
+                            futures[future] = task
+                    elif not ready_futures and futures:

Review Comment:
   A couple concerns with `sleep()` here:
   
   1. `sleep(len(futures) * 0.1)` blocks the worker, preventing heartbeats. 
With many futures this could be a significant pause.
   
   2. The retry delay `sleep(delay_seconds)` below blocks the entire worker 
process. The scheduler might consider it dead and kill it.
   
   What's the reasoning behind the `0.1` multiplier per future?
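   For comparison, a capped poll interval (hypothetical constants — the point is the cap, not the numbers) would at least bound the pause from concern 1:
   
   ```python
   def poll_interval(pending: int, base: float = 0.1, cap: float = 2.0) -> float:
       """Scale the wait with the number of pending futures, but never
       sleep longer than `cap` seconds so heartbeats keep flowing."""
       return min(cap, base * max(pending, 1))
   ```
   
   The long retry delay in concern 2 probably wants deferral to the triggerer rather than any in-process sleep.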



##########
task-sdk/src/airflow/sdk/bases/iterableoperator.py:
##########
@@ -0,0 +1,432 @@
+#
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+from __future__ import annotations
+
+import asyncio
+import logging
+import os
+import time
+from collections import deque
+from collections.abc import Iterable, Sequence
+from concurrent.futures import Future
+from math import ceil
+from time import sleep
+from typing import TYPE_CHECKING, Any
+
+from more_itertools import ichunked
+
+from airflow.exceptions import (
+    AirflowTaskTimeout,
+    TaskDeferred,
+)
+from airflow.sdk import timezone
+from airflow.sdk.bases.operator import BaseOperator, 
DecoratedDeferredAsyncOperator, event_loop
+from airflow.sdk.definitions.xcom_arg import MapXComArg, XComArg  # noqa: F401
+from airflow.sdk.exceptions import AirflowRescheduleTaskInstanceException
+from airflow.sdk.execution_time.executor import HybridExecutor, 
_execute_async_task, collect_futures
+from airflow.sdk.execution_time.lazy_sequence import XComIterable
+from airflow.sdk.execution_time.task_runner import MappedTaskInstance, 
RuntimeTaskInstance, _execute_task
+from airflow.utils.log.logging_mixin import LoggingMixin
+from airflow.utils.state import TaskInstanceState
+
+if TYPE_CHECKING:
+    from airflow.sdk.definitions._internal.expandinput import ExpandInput
+    from airflow.sdk.definitions.context import Context
+    from airflow.sdk.definitions.mappedoperator import MappedOperator
+
+
+class TaskExecutor(LoggingMixin):
+    """Base class to run an operator or trigger with given task context and 
task instance."""
+
+    def __init__(
+        self,
+        task_instance: RuntimeTaskInstance,
+    ):
+        super().__init__()
+        self._task_instance = task_instance
+        self._result: Any | None = None
+        self._start_time: float | None = None
+
+    @property
+    def task_instance(self) -> RuntimeTaskInstance:
+        return self._task_instance
+
+    @property
+    def dag_id(self) -> str:
+        return self.task_instance.dag_id
+
+    @property
+    def task_id(self) -> str:
+        return self.task_instance.task_id
+
+    @property
+    def task_index(self) -> int:
+        return self.task_instance.map_index
+
+    @property
+    def key(self):
+        return self.task_instance.xcom_key
+
+    @property
+    def operator(self) -> BaseOperator:
+        return self.task_instance.task
+
+    @property
+    def is_async(self) -> bool:
+        return self.task_instance.is_async
+
+    def run(self, context: Context):
+        return _execute_task(context, self.task_instance, self.log)
+
+    async def arun(self, context: Context):
+        return await _execute_async_task(context, self.task_instance, self.log)
+
+    def __enter__(self):
+        self._start_time = time.monotonic()
+
+        if self.log.isEnabledFor(logging.INFO):
+            self.log.info(
+                "Attempting running task %s of %s for %s with map_index %s in 
%s mode.",
+                self.task_instance.try_number,
+                self.operator.retries,
+                self.task_instance.task_id,
+                self.task_index,
+                "async" if self.is_async else "sync",
+            )
+        return self
+
+    async def __aenter__(self):
+        return self.__enter__()
+
+    def __exit__(self, exc_type, exc_value, traceback):
+        elapsed = time.monotonic() - self._start_time
+
+        if exc_value:
+            if not isinstance(exc_value, TaskDeferred):
+                if self.task_instance.next_try_number > self.task_instance.max_tries:
+                    self.log.error(
+                        "Task instance %s for %s failed after %s attempts in %.2f seconds due to: %s",
+                        self.task_index,
+                        self.task_instance.task_id,
+                        self.task_instance.max_tries,
+                        elapsed,
+                        exc_value,
+                    )
+                    self.task_instance.state = TaskInstanceState.FAILED
+                    raise exc_value
+                self.task_instance.try_number += 1
+                self.task_instance.end_date = timezone.utcnow()
+                self.task_instance.state = TaskInstanceState.UP_FOR_RESCHEDULE
+                raise AirflowRescheduleTaskInstanceException(task=self.task_instance)
+            raise exc_value
+
+        self.task_instance.state = TaskInstanceState.SUCCESS
+        if self.log.isEnabledFor(logging.INFO):
+            self.log.info(
+                "Task instance %s for %s finished successfully in %s attempts in %.2f seconds",
+                self.task_index,
+                self.task_instance.task_id,
+                self.task_instance.next_try_number,
+                elapsed,
+            )
+
+    async def __aexit__(self, exc_type, exc_value, traceback):
+        self.__exit__(exc_type, exc_value, traceback)
+
+
+class IterableOperator(BaseOperator):
+    """Object representing an iterable operator in a DAG."""
+
+    _operator: MappedOperator
+    expand_input: ExpandInput
+    partial_kwargs: dict[str, Any]
+    # each operator should override this class attr for shallow copy attrs.
+    shallow_copy_attrs: Sequence[str] = (
+        "_operator",
+        "expand_input",
+        "partial_kwargs",
+        "_log",
+    )
+
+    def __init__(
+        self,
+        *,
+        operator: MappedOperator,
+        expand_input: ExpandInput,
+        **kwargs,
+    ):
+        super().__init__(
+            **{
+                **kwargs,
+                "task_id": operator.task_id,
+                "owner": operator.owner,
+                "email": operator.email,
+                "email_on_retry": operator.email_on_retry,
+                "email_on_failure": operator.email_on_failure,
+                "retries": 0,  # We should not retry the IterableOperator, only the mapped TIs should be retried
+                "retry_delay": operator.retry_delay,
+                "retry_exponential_backoff": operator.retry_exponential_backoff,
+                "max_retry_delay": operator.max_retry_delay,
+                "start_date": operator.start_date,
+                "end_date": operator.end_date,
+                "depends_on_past": operator.depends_on_past,
+                "ignore_first_depends_on_past": operator.ignore_first_depends_on_past,
+                "wait_for_past_depends_before_skipping": operator.wait_for_past_depends_before_skipping,
+                "wait_for_downstream": operator.wait_for_downstream,
+                "dag": operator.dag,
+                "priority_weight": operator.priority_weight,
+                "queue": operator.queue,
+                "pool": operator.pool,
+                "pool_slots": operator.pool_slots,
+                "execution_timeout": None,
+                "trigger_rule": operator.trigger_rule,
+                "resources": operator.resources,
+                "run_as_user": operator.run_as_user,
+                "map_index_template": operator.map_index_template,
+                "max_active_tis_per_dag": operator.max_active_tis_per_dag,
+                "max_active_tis_per_dagrun": operator.max_active_tis_per_dagrun,
+                "executor": operator.executor,
+                "executor_config": operator.executor_config,
+                "inlets": operator.inlets,
+                "outlets": operator.outlets,
+                "task_group": operator.task_group,
+                "doc": operator.doc,
+                "doc_md": operator.doc_md,
+                "doc_json": operator.doc_json,
+                "doc_yaml": operator.doc_yaml,
+                "doc_rst": operator.doc_rst,
+                "task_display_name": operator.task_display_name,
+                "allow_nested_operators": operator.allow_nested_operators,
+            }
+        )
+        self._operator = operator
+        self.expand_input = expand_input
+        self.partial_kwargs = operator.partial_kwargs or {}
+        self._number_of_tasks: int = 0
+        XComArg.apply_upstream_relationship(self, self.expand_input.value)
+
+    @property
+    def task_type(self) -> str:
+        """Return the task type of the wrapped operator."""
+        return self._operator.__class__.__name__
+
+    @property
+    def max_workers(self):
+        return self.max_active_tis_per_dag or os.cpu_count() or 1
+
+    @property
+    def timeout(self) -> float | None:
+        if self.execution_timeout:
+            return self.execution_timeout.total_seconds()
+        return None
+
+    def _get_specified_expand_input(self) -> ExpandInput:
+        return self.expand_input
+
+    def _unmap_operator(self, mapped_kwargs: dict):
+        self._number_of_tasks += 1
+        return self._operator.unmap(mapped_kwargs)
+
+    def _xcom_push(self, context: Context, task: RuntimeTaskInstance, value: Any) -> None:
+        self.log.debug("Pushing XCom %s", task.map_index)
+
+        context["ti"].xcom_push(key=task.xcom_key, value=value)
+
+    def _run_tasks(
+        self,
+        context: Context,
+        tasks: Iterable[RuntimeTaskInstance],
+    ) -> None:
+        exception: BaseException | None = None
+        reschedule_date = timezone.utcnow()
+        prev_futures_count = 0
+        futures: dict[Future, RuntimeTaskInstance] = {}
+        failed_tasks: deque[RuntimeTaskInstance] = deque()
+        chunked_tasks = ichunked(tasks, self.max_workers)
+
+        self.log.info("Running tasks with %d workers", self.max_workers)
+
+        with event_loop() as loop:
+            with HybridExecutor(loop=loop, max_workers=self.max_workers) as executor:
+                for task in next(chunked_tasks, []):
+                    if task.is_async:
+                        future = executor.submit(self._run_async_operator, context, task)
+                    else:
+                        future = executor.submit(self._run_operator, context, task)
+                    futures[future] = task
+
+                while futures:
+                    futures_count = len(futures)
+
+                    if futures_count != prev_futures_count:
+                        self.log.info("Number of remaining futures: %s", futures_count)
+                        prev_futures_count = futures_count
+
+                    ready_futures = False
+
+                    for future in collect_futures(loop, futures.keys()):
+                        task = futures.pop(future)
+                        ready_futures = True
+
+                        try:
+                            if isinstance(future, asyncio.futures.Future):
+                                result = future.result()
+                            else:
+                                result = future.result(timeout=self.timeout)
+
+                            self.log.debug("result: %s", result)
+
+                            if result and task.task.do_xcom_push:
+                                self._xcom_push(
+                                    context=context,
+                                    task=task,
+                                    value=result,
+                                )
+                        except TaskDeferred as task_deferred:
+                            operator = DecoratedDeferredAsyncOperator(
+                                operator=task.task, task_deferred=task_deferred
+                            )
+                            failed_tasks.append(
+                                self._create_mapped_task(task.run_id, task.map_index, operator)
+                            )
+                        except asyncio.TimeoutError as e:
+                            self.log.warning("A timeout occurred for task_id %s", task.task_id)
+                            if task.next_try_number > self.retries:
+                                exception = AirflowTaskTimeout(e)
+                            else:
+                                reschedule_date = min(reschedule_date, task.next_retry_datetime())
+                                failed_tasks.append(task)
+                        except AirflowRescheduleTaskInstanceException as e:
+                            reschedule_date = min(reschedule_date, e.reschedule_date)
+                            self.log.warning(
+                                "An exception occurred for task_id %s with map_index %s; it has been rescheduled at %s",
+                                task.task_id,
+                                task.map_index,
+                                reschedule_date,
+                            )
+                            failed_tasks.append(e.task)
+                        except Exception as e:
+                            self.log.error(
+                                "An exception occurred for task_id %s with map_index %s",
+                                task.task_id,
+                                task.map_index,
+                            )

Review Comment:
   When multiple tasks fail, `exception = e` overwrites the previous exception 
each time. Only the last failure gets raised — all prior failures are silently 
lost. Makes debugging really hard when multiple sub-tasks fail.
   
   Consider collecting all exceptions and raising an `ExceptionGroup` or at 
least logging each one.


