This is an automated email from the ASF dual-hosted git repository.

dstandish pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new 309d658204b Refactor task runner for spans (#62589)
309d658204b is described below

commit 309d658204b418913c6ed23341370e222240c6c7
Author: Daniel Standish <[email protected]>
AuthorDate: Mon Mar 2 19:12:52 2026 -0800

    Refactor task runner for spans (#62589)
    
    We need the TI details (the StartupDetails message) to obtain the context
    carrier, which is required to open the span.
    
    This small adjustment fetches the StartupDetails message separately from
    the rest of startup, which makes it straightforward to enclose all of
    startup in a span.
---
 task-sdk/src/airflow/sdk/execution_time/task_runner.py  |  9 +++++++--
 .../tests/task_sdk/execution_time/test_task_runner.py   | 17 ++++++++++-------
 2 files changed, 17 insertions(+), 9 deletions(-)

diff --git a/task-sdk/src/airflow/sdk/execution_time/task_runner.py 
b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
index 7355f886ee8..0ef7a74a24b 100644
--- a/task-sdk/src/airflow/sdk/execution_time/task_runner.py
+++ b/task-sdk/src/airflow/sdk/execution_time/task_runner.py
@@ -843,7 +843,7 @@ def _verify_bundle_access(bundle_instance: BaseDagBundle, 
log: Logger) -> None:
         )
 
 
-def startup() -> tuple[RuntimeTaskInstance, Context, Logger]:
+def get_startup_details() -> StartupDetails:
     # The parent sends us a StartupDetails message un-prompted. After this, 
every single message is only sent
     # in response to us sending a request.
     log = structlog.get_logger(logger_name="task")
@@ -867,7 +867,11 @@ def startup() -> tuple[RuntimeTaskInstance, Context, 
Logger]:
 
         if not isinstance(msg, StartupDetails):
             raise RuntimeError(f"Unhandled startup message {type(msg)} {msg}")
+    return msg
 
+
+def startup(msg: StartupDetails) -> tuple[RuntimeTaskInstance, Context, 
Logger]:
+    log = structlog.get_logger("task")
     # setproctitle causes issue on Mac OS: 
https://github.com/benoitc/gunicorn/issues/3021
     os_type = sys.platform
     if os_type == "darwin":
@@ -1803,7 +1807,8 @@ def main():
 
     try:
         try:
-            ti, context, log = startup()
+            startup_details = get_startup_details()
+            ti, context, log = startup(msg=startup_details)
         except AirflowRescheduleException as reschedule:
             log.warning("Rescheduling task during startup, marking task as 
UP_FOR_RESCHEDULE")
             SUPERVISOR_COMMS.send(
diff --git a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py 
b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
index f40b4399cff..2a495e557f7 100644
--- a/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
+++ b/task-sdk/tests/task_sdk/execution_time/test_task_runner.py
@@ -130,6 +130,7 @@ from airflow.sdk.execution_time.task_runner import (
     _push_xcom_if_needed,
     _xcom_push,
     finalize,
+    get_startup_details,
     parse,
     run,
     startup,
@@ -363,9 +364,10 @@ def 
test_parse_not_found_does_not_reschedule_when_max_attempts_reached(test_dags
 
 @mock.patch("builtins.exit", side_effect=lambda code: (_ for _ in 
()).throw(SystemExit(code)))
 @mock.patch("airflow.sdk.execution_time.task_runner.startup")
+@mock.patch("airflow.sdk.execution_time.task_runner.get_startup_details")
 @mock.patch("airflow.sdk.execution_time.task_runner.CommsDecoder")
 def test_main_sends_reschedule_task_when_startup_reschedules(
-    mock_comms_decoder_cls, mock_startup, mock_exit, time_machine
+    mock_comms_decoder_cls, mock_get_startup_details, mock_startup, mock_exit, 
time_machine
 ):
     """
     If startup raises AirflowRescheduleException, the task runner should 
report a RescheduleTask
@@ -377,6 +379,7 @@ def 
test_main_sends_reschedule_task_when_startup_reschedules(
     mock_comms_instance = mock.Mock()
     mock_comms_instance.socket = None
     mock_comms_decoder_cls.__getitem__.return_value.return_value = 
mock_comms_instance
+    mock_get_startup_details.return_value = mock.Mock()
     mock_startup.side_effect = 
AirflowRescheduleException(reschedule_date=reschedule_date)
 
     # Move time
@@ -927,7 +930,7 @@ def test_startup_and_run_dag_with_rtif(
 
     mock_supervisor_comms._get_response.return_value = what
 
-    run(*startup())
+    run(*startup(get_startup_details()))
     expected_calls = [
         
mock.call.send(SetRenderedFields(rendered_fields=expected_rendered_fields)),
         mock.call.send(
@@ -977,7 +980,7 @@ def test_task_run_with_user_impersonation(
     mock_supervisor_comms.socket.fileno.return_value = 42
 
     with mock.patch.dict(os.environ, {}, clear=True):
-        startup()
+        startup(get_startup_details())
 
         assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1"
         assert "_AIRFLOW__STARTUP_MSG" in os.environ
@@ -1026,7 +1029,7 @@ def test_task_run_with_user_impersonation_default_user(
     mock_get_user.return_value = "default_user"
 
     with mock.patch.dict(os.environ, {}, clear=True):
-        startup()
+        startup(get_startup_details())
 
         assert "_AIRFLOW__REEXECUTED_PROCESS" not in os.environ
         assert "_AIRFLOW__STARTUP_MSG" not in os.environ
@@ -1069,7 +1072,7 @@ def 
test_task_run_with_user_impersonation_remove_krb5ccname_on_reexecuted_proces
         "_AIRFLOW__STARTUP_MSG": what.model_dump_json(),
     }
     with mock.patch.dict("os.environ", mock_os_env, clear=True):
-        startup()
+        startup(get_startup_details())
 
         assert os.environ["_AIRFLOW__REEXECUTED_PROCESS"] == "1"
         assert "_AIRFLOW__STARTUP_MSG" in os.environ
@@ -1241,7 +1244,7 @@ def test_dag_parsing_context(make_ti_context, 
mock_supervisor_comms, monkeypatch
     )
 
     monkeypatch.setenv("AIRFLOW__DAG_PROCESSOR__DAG_BUNDLE_CONFIG_LIST", 
dag_bundle_val)
-    ti, _, _ = startup()
+    ti, _, _ = startup(get_startup_details())
 
     # Presence of `conditional_task` below means Dag ID is properly set in the 
parsing context!
     # Check the dag file for the actual logic!
@@ -3596,7 +3599,7 @@ class TestTaskRunnerCallsListeners:
         mock_supervisor_comms._get_response.return_value = what
         mocked_parse(what, "basic_dag", task)
 
-        runtime_ti, context, log = startup()
+        runtime_ti, context, log = startup(get_startup_details())
         assert runtime_ti is not None
         assert isinstance(listener.component, TaskRunnerMarker)
         del listener.component

Reply via email to