This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/main by this push:
     new e2c73fd35ca AIP-67 - Multi-team: Per team executor config (env var only) (#55003)
e2c73fd35ca is described below

commit e2c73fd35cab3274307b0f4a08615fdbc905f53c
Author: Niko Oliveira <oniko...@amazon.com>
AuthorDate: Sun Sep 7 17:08:51 2025 -0700

    AIP-67 - Multi-team: Per team executor config (env var only) (#55003)
    
    * Multi-team: per team executor config (env var only)
    
    Configuration for teams can now be specified via environment variables
    using the triple underscore syntax outlined in AIP-67. This applies to
    any configuration option, but is specifically required for executor-based
    configuration.
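
    As a minimal sketch of the naming scheme (illustrative values; the env
    var names mirror the unit test added in this change):

        import os
        from airflow.configuration import conf

        os.environ["AIRFLOW__CELERY__RESULT_BACKEND"] = "FOO"
        os.environ["AIRFLOW__UNIT_TEST_TEAM___CELERY__RESULT_BACKEND"] = "BAR"

        assert conf.get("celery", "result_backend") == "FOO"
        assert conf.get("celery", "result_backend", team_name="unit_test_team") == "BAR"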
    
    A small shim has been added to BaseExecutor to allow easier access to
    team-based config.
    
    The ECS executor is converted to this new shim as a proof of concept
    for the mechanism.
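
    As a sketch of what this looks like in a concrete executor (the section
    and key below are placeholders, not real Airflow options):

        from airflow.executors.base_executor import BaseExecutor

        class MyExecutor(BaseExecutor):
            def start(self):
                # self.conf is an ExecutorConf bound to this executor's
                # team_name, so team-prefixed env vars take precedence
                # without passing team_name at every call site.
                check = self.conf.getboolean("my_section", "check_health", fallback=False)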
    
    * PR Feedback: comment fixup
---
 airflow-core/src/airflow/configuration.py          |  21 ++--
 .../src/airflow/executors/base_executor.py         |  23 ++++
 airflow-core/tests/unit/core/test_configuration.py |  11 ++
 .../amazon/aws/executors/ecs/ecs_executor.py       |  45 ++++----
 .../aws/executors/ecs/ecs_executor_config.py       |  17 +--
 .../amazon/aws/executors/ecs/test_ecs_executor.py  | 117 ++++++++++++++++-----
 6 files changed, 177 insertions(+), 57 deletions(-)

diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py
index 8c384e64266..24a439e848e 100644
--- a/airflow-core/src/airflow/configuration.py
+++ b/airflow-core/src/airflow/configuration.py
@@ -877,12 +877,14 @@ class AirflowConfigParser(ConfigParser):
             mask_secret_core(value)
             mask_secret_sdk(value)
 
-    def _env_var_name(self, section: str, key: str) -> str:
-        return f"{ENV_VAR_PREFIX}{section.replace('.', 
'_').upper()}__{key.upper()}"
-
-    def _get_env_var_option(self, section: str, key: str):
-        # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
-        env_var = self._env_var_name(section, key)
+    def _env_var_name(self, section: str, key: str, team_name: str | None = None) -> str:
+        team_component: str = f"{team_name.upper()}___" if team_name else ""
+        return f"{ENV_VAR_PREFIX}{team_component}{section.replace('.', '_').upper()}__{key.upper()}"
+
+    def _get_env_var_option(self, section: str, key: str, team_name: str | None = None):
+        # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore) OR for team based
+        # configuration must have the format AIRFLOW__{TEAM_NAME}___{SECTION}__{KEY}
+        env_var: str = self._env_var_name(section, key, team_name=team_name)
         if env_var in os.environ:
             return expand_env_var(os.environ[env_var])
         # alternatively AIRFLOW__{SECTION}__{KEY}_CMD (for a command)
@@ -982,6 +984,7 @@ class AirflowConfigParser(ConfigParser):
         suppress_warnings: bool = False,
         lookup_from_deprecated: bool = True,
         _extra_stacklevel: int = 0,
+        team_name: str | None = None,
         **kwargs,
     ) -> str | None:
         section = section.lower()
@@ -1044,6 +1047,7 @@ class AirflowConfigParser(ConfigParser):
             section,
             issue_warning=not warning_emitted,
             extra_stacklevel=_extra_stacklevel,
+            team_name=team_name,
         )
         if option is not None:
             return option
@@ -1170,13 +1174,14 @@ class AirflowConfigParser(ConfigParser):
         section: str,
         issue_warning: bool = True,
         extra_stacklevel: int = 0,
+        team_name: str | None = None,
     ) -> str | None:
-        option = self._get_env_var_option(section, key)
+        option = self._get_env_var_option(section, key, team_name=team_name)
         if option is not None:
             return option
         if deprecated_section and deprecated_key:
             with self.suppress_future_warnings():
-                option = self._get_env_var_option(deprecated_section, deprecated_key)
+                option = self._get_env_var_option(deprecated_section, deprecated_key, team_name=team_name)
             if option is not None:
                 if issue_warning:
                     self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py
index 1723d94708b..70b9ba9ef08 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -97,6 +97,28 @@ class RunningRetryAttemptType:
         return True
 
 
+class ExecutorConf:
+    """
+    This class is used to fetch configuration for an executor for a particular team_name.
+
+    It wraps the implementation of the configuration.get() to look for the particular section and key
+    prefixed with the team_name. This makes it easy for child classes (i.e. concrete executors) to fetch
+    configuration values for a particular team_name without having to worry about passing through the
+    team_name for every call to get configuration.
+
+    Currently config only supports environment variables for team specific configuration.
+    """
+
+    def __init__(self, team_name: str | None = None) -> None:
+        self.team_name: str | None = team_name
+
+    def get(self, *args, **kwargs) -> str | None:
+        return conf.get(*args, **kwargs, team_name=self.team_name)
+
+    def getboolean(self, *args, **kwargs) -> bool:
+        return conf.getboolean(*args, **kwargs, team_name=self.team_name)
+
+
 class BaseExecutor(LoggingMixin):
     """
    Base class to inherit for concrete executors such as Celery, Kubernetes, Local, etc.
@@ -150,6 +172,7 @@ class BaseExecutor(LoggingMixin):
         self.running: set[TaskInstanceKey] = set()
         self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
         self._task_event_logs: deque[Log] = deque()
+        self.conf = ExecutorConf(team_name)
 
         if self.parallelism <= 0:
             raise ValueError("parallelism is set to 0 or lower")
diff --git a/airflow-core/tests/unit/core/test_configuration.py b/airflow-core/tests/unit/core/test_configuration.py
index 3e518294a43..757804428a4 100644
--- a/airflow-core/tests/unit/core/test_configuration.py
+++ b/airflow-core/tests/unit/core/test_configuration.py
@@ -158,6 +158,17 @@ class TestConf:
 
         assert conf.has_option("testsection", "testkey")
 
+    def test_env_team(self):
+        with patch(
+            "os.environ",
+            {
+                "AIRFLOW__CELERY__RESULT_BACKEND": "FOO",
+                "AIRFLOW__UNIT_TEST_TEAM___CELERY__RESULT_BACKEND": "BAR",
+            },
+        ):
+            assert conf.get("celery", "result_backend") == "FOO"
+            assert conf.get("celery", "result_backend", 
team_name="unit_test_team") == "BAR"
+
     @conf_vars({("core", "percent"): "with%%inside"})
     def test_conf_as_dict(self):
         cfg_dict = conf.as_dict()
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index a7fe158f14c..396e998af61 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -32,7 +32,6 @@ from typing import TYPE_CHECKING
 
 from botocore.exceptions import ClientError, NoCredentialsError
 
-from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
 from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema
@@ -98,13 +97,6 @@ class AwsEcsExecutor(BaseExecutor):
      Airflow TaskInstance's executor_config.
     """
 
-    # Maximum number of retries to run an ECS task.
-    MAX_RUN_TASK_ATTEMPTS = conf.get(
-        CONFIG_GROUP_NAME,
-        AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
-        fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
-    )
-
     # AWS limits the maximum number of ARNs in the describe_tasks function.
     DESCRIBE_TASKS_BATCH_SIZE = 99
 
@@ -118,8 +110,18 @@ class AwsEcsExecutor(BaseExecutor):
         self.active_workers: EcsTaskCollection = EcsTaskCollection()
         self.pending_tasks: deque = deque()
 
-        self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
-        self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
+        # Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global
+        # configuration object. This allows the changes to be backwards compatible with older versions of
+        # Airflow.
+        # Can be removed when minimum supported provider version is equal to the version of core airflow
+        # which introduces multi-team configuration.
+        if not hasattr(self, "conf"):
+            from airflow.configuration import conf
+
+            self.conf = conf
+
+        self.cluster = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
+        self.container_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
         self.attempts_since_last_successful_connection = 0
 
         self.load_ecs_connection(check_connection=False)
@@ -127,6 +129,13 @@ class AwsEcsExecutor(BaseExecutor):
 
         self.run_task_kwargs = self._load_run_kwargs()
 
+        # Maximum number of retries to run an ECS task.
+        self.max_run_task_attempts = self.conf.get(
+            CONFIG_GROUP_NAME,
+            AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
+            fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
+        )
+
     def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
         from airflow.executors import workloads
 
@@ -154,7 +163,7 @@ class AwsEcsExecutor(BaseExecutor):
 
     def start(self):
         """Call this when the Executor is run for the first time by the 
scheduler."""
-        check_health = conf.getboolean(
+        check_health = self.conf.getboolean(
             CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
         )
 
@@ -218,12 +227,12 @@ class AwsEcsExecutor(BaseExecutor):
 
     def load_ecs_connection(self, check_connection: bool = True):
         self.log.info("Loading Connection information")
-        aws_conn_id = conf.get(
+        aws_conn_id = self.conf.get(
             CONFIG_GROUP_NAME,
             AllEcsConfigKeys.AWS_CONN_ID,
             fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
         )
-        region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
+        region_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
         self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
         self.attempts_since_last_successful_connection += 1
         self.last_connection_reload = timezone.utcnow()
@@ -340,13 +349,13 @@ class AwsEcsExecutor(BaseExecutor):
         queue = task_info.queue
         exec_info = task_info.config
         failure_count = self.active_workers.failure_count_by_key(task_key)
-        if int(failure_count) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
+        if int(failure_count) < int(self.max_run_task_attempts):
             self.log.warning(
                 "Airflow task %s failed due to %s. Failure %s out of %s 
occurred on %s. Rescheduling.",
                 task_key,
                 reason,
                 failure_count,
-                self.__class__.MAX_RUN_TASK_ATTEMPTS,
+                self.max_run_task_attempts,
                 task_arn,
             )
             self.pending_tasks.append(
@@ -416,8 +425,8 @@ class AwsEcsExecutor(BaseExecutor):
                     failure_reasons.extend([f["reason"] for f in 
run_task_response["failures"]])
 
             if failure_reasons:
-                # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
-                if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
+                # Make sure the number of attempts does not exceed max_run_task_attempts
+                if int(attempt_number) < int(self.max_run_task_attempts):
                     ecs_task.attempt_number += 1
                     ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
                         attempt_number
@@ -545,7 +554,7 @@ class AwsEcsExecutor(BaseExecutor):
     def _load_run_kwargs(self) -> dict:
         from airflow.providers.amazon.aws.executors.ecs.ecs_executor_config import build_task_kwargs
 
-        ecs_executor_run_task_kwargs = build_task_kwargs()
+        ecs_executor_run_task_kwargs = build_task_kwargs(self.conf)
 
         try:
             self.get_container(ecs_executor_run_task_kwargs["overrides"]["containerOverrides"])["command"]
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
index bcbbbfc9e8c..f7753903a2f 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py
@@ -32,7 +32,6 @@ from __future__ import annotations
 import json
 from json import JSONDecodeError
 
-from airflow.configuration import conf
 from airflow.providers.amazon.aws.executors.ecs.utils import (
     CONFIG_GROUP_NAME,
     ECS_LAUNCH_TYPE_EC2,
@@ -46,23 +45,27 @@ from airflow.providers.amazon.aws.hooks.ecs import EcsHook
 from airflow.utils.helpers import prune_dict
 
 
-def _fetch_templated_kwargs() -> dict[str, str]:
-    run_task_kwargs_value = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.RUN_TASK_KWARGS, fallback=dict())
+def _fetch_templated_kwargs(conf) -> dict[str, str]:
+    run_task_kwargs_value = conf.get(
+        CONFIG_GROUP_NAME,
+        AllEcsConfigKeys.RUN_TASK_KWARGS,
+        fallback=dict(),
+    )
     return json.loads(str(run_task_kwargs_value))
 
 
-def _fetch_config_values() -> dict[str, str]:
+def _fetch_config_values(conf) -> dict[str, str]:
     return prune_dict(
         {key: conf.get(CONFIG_GROUP_NAME, key, fallback=None) for key in RunTaskKwargsConfigKeys()}
     )
 
 
-def build_task_kwargs() -> dict:
+def build_task_kwargs(conf) -> dict:
     all_config_keys = AllEcsConfigKeys()
     # This will put some kwargs at the root of the dictionary that do NOT belong there. However,
     # the code below expects them to be there and will rearrange them as necessary.
-    task_kwargs = _fetch_config_values()
-    task_kwargs.update(_fetch_templated_kwargs())
+    task_kwargs = _fetch_config_values(conf)
+    task_kwargs.update(_fetch_templated_kwargs(conf))
 
     has_launch_type: bool = all_config_keys.LAUNCH_TYPE in task_kwargs
     has_capacity_provider: bool = all_config_keys.CAPACITY_PROVIDER_STRATEGY in task_kwargs
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index 6d0b0055906..8f5fe85bf10 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -25,7 +25,7 @@ import time
 from collections.abc import Callable
 from functools import partial
 from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 import yaml
@@ -33,7 +33,9 @@ from botocore.exceptions import ClientError
 from inflection import camelize
 from semver import VersionInfo
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
+from airflow.executors import base_executor
 from airflow.executors.base_executor import BaseExecutor
 from airflow.models import TaskInstance
 from airflow.models.taskinstancekey import TaskInstanceKey
@@ -538,7 +540,7 @@ class TestAwsEcsExecutor:
         mock_executor.execute_async(mock_airflow_key, mock_cmd)
 
         # No matter what, don't schedule until run_task becomes successful.
-        for _ in range(int(mock_executor.MAX_RUN_TASK_ATTEMPTS) * 2):
+        for _ in range(int(mock_executor.max_run_task_attempts) * 2):
             mock_executor.attempt_task_runs()
             # Task is not stored in active workers.
             assert len(mock_executor.active_workers) == 0
@@ -555,7 +557,7 @@ class TestAwsEcsExecutor:
         mock_executor.execute_async(mock_airflow_key, mock_cmd)
 
         # No matter what, don't schedule until run_task becomes successful.
-        for _ in range(int(mock_executor.MAX_RUN_TASK_ATTEMPTS) * 2):
+        for _ in range(int(mock_executor.max_run_task_attempts) * 2):
             mock_executor.attempt_task_runs()
             # Task is not stored in active workers.
             assert len(mock_executor.active_workers) == 0
@@ -567,7 +569,7 @@ class TestAwsEcsExecutor:
 
         The executor should attempt each task exactly once per sync() iteration.
         It should preserve the order of tasks, and attempt each task up to
-        `MAX_RUN_TASK_ATTEMPTS` times before dropping the task.
+        `max_run_task_attempts` times before dropping the task.
         """
         airflow_keys = [
             TaskInstanceKey("a", "task_a", "c", 1, -1),
@@ -627,7 +629,7 @@ class TestAwsEcsExecutor:
 
         The executor should attempt each task exactly once per sync() iteration.
         It should preserve the order of tasks, and attempt each task up to
-        `MAX_RUN_TASK_ATTEMPTS` times before dropping the task. If a task succeeds, the task
+        `max_run_task_attempts` times before dropping the task. If a task succeeds, the task
         should be removed from pending_jobs and into active_workers.
         """
         airflow_keys = [
@@ -705,7 +707,7 @@ class TestAwsEcsExecutor:
         """
         Test API failure retries.
         """
-        AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "2"
+        mock_executor.max_run_task_attempts = "2"
         airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
         airflow_commands = [_generate_mock_cmd(), _generate_mock_cmd()]
 
@@ -834,7 +836,7 @@ class TestAwsEcsExecutor:
     @mock.patch.object(BaseExecutor, "success")
     def test_failed_sync(self, success_mock, fail_mock, mock_executor):
         """Test success and failure states."""
-        AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1"
+        mock_executor.max_run_task_attempts = "1"
         self._mock_sync(mock_executor, State.FAILED)
 
         mock_executor.sync()
@@ -850,7 +852,7 @@ class TestAwsEcsExecutor:
     @mock.patch.object(BaseExecutor, "fail")
     def test_removed_sync(self, fail_mock, success_mock, mock_executor):
         """A removed task will be treated as a failed task."""
-        AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1"
+        mock_executor.max_run_task_attempts = "1"
         self._mock_sync(mock_executor, expected_state=State.REMOVED, set_task_state=State.REMOVED)
 
         mock_executor.sync_running_tasks()
@@ -868,7 +870,7 @@ class TestAwsEcsExecutor:
         self, _, success_mock, fail_mock, mock_airflow_key, mock_executor, mock_cmd
     ):
         """Test that failure_count/attempt_number is cumulative for pending 
tasks and active workers."""
-        AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "5"
+        mock_executor.max_run_task_attempts = "5"
         mock_executor.ecs.run_task.return_value = {
             "tasks": [],
             "failures": [
@@ -980,8 +982,8 @@ class TestAwsEcsExecutor:
         assert len(mock_executor.active_workers.get_all_arns()) == 1
         task_key = mock_executor.active_workers.arn_to_key[ARN1]
 
-        # Call Sync 2 times with failures. The task can only fail MAX_RUN_TASK_ATTEMPTS times.
-        for check_count in range(1, int(AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS)):
+        # Call Sync 2 times with failures. The task can only fail max_run_task_attempts times.
+        for check_count in range(1, int(mock_executor.max_run_task_attempts)):
             mock_executor.sync_running_tasks()
             assert mock_executor.ecs.describe_tasks.call_count == check_count
 
@@ -1215,7 +1217,7 @@ class TestAwsEcsExecutor:
         mock_success_function.assert_called_once()
 
     def test_update_running_tasks_failed(self, mock_executor, caplog):
-        AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1"
+        mock_executor.max_run_task_attempts = "1"
         caplog.set_level(logging.WARNING)
         self._add_mock_task(mock_executor, ARN1)
         test_response_task_json = {
@@ -1343,14 +1345,81 @@ class TestEcsExecutorConfig:
         }
         with conf_vars(conf_overrides):
             with pytest.raises(ValueError) as raised:
-                ecs_executor_config.build_task_kwargs()
+                ecs_executor_config.build_task_kwargs(conf)
         assert raised.match("At least one subnet is required to run a task.")
 
+    # TODO: When merged this needs updating to the actually supported version
+    @pytest.mark.skipif(
+        not hasattr(base_executor, "ExecutorConf"),
+        reason="Test requires a version of airflow which includes updates to 
support multi team",
+    )
+    def test_team_config(self):
+        # Team name to be used throughout
+        team_name = "team_a"
+        # Patch environment to include two sets of configs for the ECS executor. One that is related to a
+        # team and one that is not. Then we will create two executors (one with a team and one without) and
+        # ensure the correct configs are used.
+        config_overrides = [
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}", 
"some_cluster"),
+            
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}", 
"container_name"),
+            
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.TASK_DEFINITION}", 
"some_task_def"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}", 
"FARGATE"),
+            
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.PLATFORM_VERSION}", 
"LATEST"),
+            
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.ASSIGN_PUBLIC_IP}", "False"),
+            
(f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}", 
"sg1,sg2"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}", 
"sub1,sub2"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}", 
"us-west-1"),
+            # team Config
+            
(f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}", 
"team_a_cluster"),
+            (
+                
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}",
+                "team_a_container",
+            ),
+            (
+                
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.TASK_DEFINITION}",
+                "team_a_task_def",
+            ),
+            
(f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}", 
"EC2"),
+            (
+                
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}",
+                "team_a_sg1,team_a_sg2",
+            ),
+            (
+                
f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}",
+                "team_a_sub1,team_a_sub2",
+            ),
+            
(f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}", 
"us-west-2"),
+        ]
+        with patch("os.environ", {key.upper(): value for key, value in 
config_overrides}):
+            team_executor = AwsEcsExecutor(team_name=team_name)
+            task_kwargs = ecs_executor_config.build_task_kwargs(team_executor.conf)
+
+            assert task_kwargs["cluster"] == "team_a_cluster"
+            assert task_kwargs["overrides"]["containerOverrides"][0]["name"] 
== "team_a_container"
+            assert task_kwargs["networkConfiguration"]["awsvpcConfiguration"] 
== {
+                "subnets": ["team_a_sub1", "team_a_sub2"],
+                "securityGroups": ["team_a_sg1", "team_a_sg2"],
+            }
+            assert task_kwargs["launchType"] == "EC2"
+            assert task_kwargs["taskDefinition"] == "team_a_task_def"
+            # Now create an executor without a team and ensure the non-team configs are used.
+            non_team_executor = AwsEcsExecutor()
+            task_kwargs = ecs_executor_config.build_task_kwargs(non_team_executor.conf)
+            assert task_kwargs["cluster"] == "some_cluster"
+            assert task_kwargs["overrides"]["containerOverrides"][0]["name"] 
== "container_name"
+            assert task_kwargs["networkConfiguration"]["awsvpcConfiguration"] 
== {
+                "subnets": ["sub1", "sub2"],
+                "securityGroups": ["sg1", "sg2"],
+                "assignPublicIp": "DISABLED",
+            }
+            assert task_kwargs["launchType"] == "FARGATE"
+            assert task_kwargs["taskDefinition"] == "some_task_def"
+
     @conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME): "container-name"})
     def test_config_defaults_are_applied(self, assign_subnets):
         from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-        task_kwargs = _recursive_flatten_dict(ecs_executor_config.build_task_kwargs())
+        task_kwargs = _recursive_flatten_dict(ecs_executor_config.build_task_kwargs(conf))
         found_keys = {convert_camel_to_snake(key): key for key in task_kwargs.keys()}
 
         for expected_key, expected_value in CONFIG_DEFAULTS.items():
@@ -1388,12 +1457,12 @@ class TestEcsExecutorConfig:
         monkeypatch.delenv(run_task_kwargs_env_key, raising=False)
         from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == default_version
 
         # Provide a new value explicitly and assert that it is applied over the default.
         monkeypatch.setenv(platform_version_env_key, first_explicit_version)
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == first_explicit_version
 
         # Provide a value via template and assert that it is applied over the explicit value.
@@ -1401,12 +1470,12 @@ class TestEcsExecutorConfig:
             run_task_kwargs_env_key,
             json.dumps({AllEcsConfigKeys.PLATFORM_VERSION: templated_version}),
         )
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == templated_version
 
         # Provide a new value explicitly and assert it is not applied over the templated values.
         monkeypatch.setenv(platform_version_env_key, second_explicit_version)
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == templated_version
 
     @mock.patch.object(EcsHook, "conn")
@@ -1428,7 +1497,7 @@ class TestEcsExecutorConfig:
             f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper(),
             json.dumps(provided_run_task_kwargs),
         )
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == templated_version
         assert task_kwargs["cluster"] == templated_cluster
 
@@ -1445,7 +1514,7 @@ class TestEcsExecutorConfig:
 
         run_task_kwargs_env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper()
         monkeypatch.setenv(run_task_kwargs_env_key, json.dumps(provided_run_task_kwargs))
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
 
         # Verify that tag names are exempt from the camel-case conversion.
         assert task_kwargs["tags"] == templated_tags
@@ -1465,7 +1534,7 @@ class TestEcsExecutorConfig:
         for key, value in kwargs_to_test.items():
             monkeypatch.setenv(f"AIRFLOW__{CONFIG_GROUP_NAME}__{key}".upper(), 
value)
 
-        run_task_kwargs = ecs_executor_config.build_task_kwargs()
+        run_task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         run_task_kwargs_network_config = run_task_kwargs["networkConfiguration"]["awsvpcConfiguration"]
         for key, value in kwargs_to_test.items():
             # Assert that the values are not at the root of the kwargs
@@ -1569,7 +1638,7 @@ class TestEcsExecutorConfig:
         with conf_vars(conf_overrides):
             from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-            task_kwargs = ecs_executor_config.build_task_kwargs()
+            task_kwargs = ecs_executor_config.build_task_kwargs(conf)
             assert "launchType" not in task_kwargs
             assert task_kwargs["capacityProviderStrategy"] == 
valid_capacity_provider
 
@@ -1583,7 +1652,7 @@ class TestEcsExecutorConfig:
         with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None}):
             from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-            task_kwargs = ecs_executor_config.build_task_kwargs()
+            task_kwargs = ecs_executor_config.build_task_kwargs(conf)
             assert "launchType" not in task_kwargs
             assert "capacityProviderStrategy" not in task_kwargs
             mock_conn.describe_clusters.assert_called_once()
@@ -1596,7 +1665,7 @@ class TestEcsExecutorConfig:
         with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None}):
             from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-            task_kwargs = ecs_executor_config.build_task_kwargs()
+            task_kwargs = ecs_executor_config.build_task_kwargs(conf)
             assert task_kwargs["launchType"] == "FARGATE"
 
     @pytest.mark.parametrize(
