This is an automated email from the ASF dual-hosted git repository.

potiuk pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/main by this push:
     new e2c73fd35ca AIP-67 - Multi-team: Per team executor config (env var only) (#55003)
e2c73fd35ca is described below

commit e2c73fd35cab3274307b0f4a08615fdbc905f53c
Author: Niko Oliveira <oniko...@amazon.com>
AuthorDate: Sun Sep 7 17:08:51 2025 -0700

    AIP-67 - Multi-team: Per team executor config (env var only) (#55003)

    * Multi-team: per team executor config (env var only)

    Configuration for teams can now be specified via environment variable
    using the triple underscore syntax outlined in AIP-67. This applies to
    any configuration, but is specifically required for executor-based
    configuration.

    A small shim has been added to BaseExecutor to allow easier access to
    team based config. The ECS executor is converted to this new shim as a
    proof of concept for the mechanism.

    * PR Feedback: comment fixup
---
 airflow-core/src/airflow/configuration.py          |  21 ++--
 .../src/airflow/executors/base_executor.py         |  23 ++++
 airflow-core/tests/unit/core/test_configuration.py |  11 ++
 .../amazon/aws/executors/ecs/ecs_executor.py       |  45 ++++----
 .../aws/executors/ecs/ecs_executor_config.py       |  17 +--
 .../amazon/aws/executors/ecs/test_ecs_executor.py  | 117 ++++++++++++++++-----
 6 files changed, 177 insertions(+), 57 deletions(-)

diff --git a/airflow-core/src/airflow/configuration.py b/airflow-core/src/airflow/configuration.py
index 8c384e64266..24a439e848e 100644
--- a/airflow-core/src/airflow/configuration.py
+++ b/airflow-core/src/airflow/configuration.py
@@ -877,12 +877,14 @@ class AirflowConfigParser(ConfigParser):
         mask_secret_core(value)
         mask_secret_sdk(value)
 
-    def _env_var_name(self, section: str, key: str) -> str:
-        return f"{ENV_VAR_PREFIX}{section.replace('.', '_').upper()}__{key.upper()}"
-
-    def _get_env_var_option(self, section: str, key: str):
-        # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore)
-        env_var = self._env_var_name(section, key)
+    def _env_var_name(self, section: str, key: str, team_name: str | None = None) -> str:
+        team_component: str = f"{team_name.upper()}___" if team_name else ""
+        return f"{ENV_VAR_PREFIX}{team_component}{section.replace('.', '_').upper()}__{key.upper()}"
+
+    def _get_env_var_option(self, section: str, key: str, team_name: str | None = None):
+        # must have format AIRFLOW__{SECTION}__{KEY} (note double underscore) OR for team based
+        # configuration must have the format AIRFLOW__{TEAM_NAME}___{SECTION}__{KEY}
+        env_var: str = self._env_var_name(section, key, team_name=team_name)
         if env_var in os.environ:
             return expand_env_var(os.environ[env_var])
         # alternatively AIRFLOW__{SECTION}__{KEY}_CMD (for a command)
@@ -982,6 +984,7 @@ class AirflowConfigParser(ConfigParser):
         suppress_warnings: bool = False,
         lookup_from_deprecated: bool = True,
         _extra_stacklevel: int = 0,
+        team_name: str | None = None,
         **kwargs,
     ) -> str | None:
         section = section.lower()
@@ -1044,6 +1047,7 @@ class AirflowConfigParser(ConfigParser):
                 section,
                 issue_warning=not warning_emitted,
                 extra_stacklevel=_extra_stacklevel,
+                team_name=team_name,
             )
             if option is not None:
                 return option
@@ -1170,13 +1174,14 @@ class AirflowConfigParser(ConfigParser):
         section: str,
         issue_warning: bool = True,
         extra_stacklevel: int = 0,
+        team_name: str | None = None,
     ) -> str | None:
-        option = self._get_env_var_option(section, key)
+        option = self._get_env_var_option(section, key, team_name=team_name)
         if option is not None:
             return option
         if deprecated_section and deprecated_key:
             with self.suppress_future_warnings():
-                option = self._get_env_var_option(deprecated_section, deprecated_key)
+                option = self._get_env_var_option(deprecated_section, deprecated_key, team_name=team_name)
             if option is not None:
                 if issue_warning:
                     self._warn_deprecate(section, key, deprecated_section, deprecated_key, extra_stacklevel)
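For reference, the naming rule the first configuration.py hunk implements, restated as a self-contained sketch (illustrative only, not part of the commit). Note that when a team_name is given, _get_env_var_option consults only the team-scoped variable name; this diff does not show a fallback to the un-scoped variable within that lookup.

    from __future__ import annotations

    ENV_VAR_PREFIX = "AIRFLOW__"

    def env_var_name(section: str, key: str, team_name: str | None = None) -> str:
        # Team-scoped variables insert "{TEAM_NAME}___" (note the TRIPLE underscore)
        # between the prefix and the section.
        team_component = f"{team_name.upper()}___" if team_name else ""
        return f"{ENV_VAR_PREFIX}{team_component}{section.replace('.', '_').upper()}__{key.upper()}"

    assert env_var_name("celery", "result_backend") == "AIRFLOW__CELERY__RESULT_BACKEND"
    assert env_var_name("celery", "result_backend", team_name="team_a") == (
        "AIRFLOW__TEAM_A___CELERY__RESULT_BACKEND"
    )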
diff --git a/airflow-core/src/airflow/executors/base_executor.py b/airflow-core/src/airflow/executors/base_executor.py
index 1723d94708b..70b9ba9ef08 100644
--- a/airflow-core/src/airflow/executors/base_executor.py
+++ b/airflow-core/src/airflow/executors/base_executor.py
@@ -97,6 +97,28 @@ class RunningRetryAttemptType:
         return True
 
 
+class ExecutorConf:
+    """
+    This class is used to fetch configuration for an executor for a particular team_name.
+
+    It wraps the implementation of configuration.get() to look for the particular section and key
+    prefixed with the team_name. This makes it easy for child classes (i.e. concrete executors) to fetch
+    configuration values for a particular team_name without having to worry about passing through the
+    team_name for every call to get configuration.
+
+    Currently, config only supports environment variables for team-specific configuration.
+    """
+
+    def __init__(self, team_name: str | None = None) -> None:
+        self.team_name: str | None = team_name
+
+    def get(self, *args, **kwargs) -> str | None:
+        return conf.get(*args, **kwargs, team_name=self.team_name)
+
+    def getboolean(self, *args, **kwargs) -> bool:
+        return conf.getboolean(*args, **kwargs, team_name=self.team_name)
+
+
 class BaseExecutor(LoggingMixin):
     """
     Base class to inherit for concrete executors such as Celery, Kubernetes, Local, etc.
@@ -150,6 +172,7 @@ class BaseExecutor(LoggingMixin):
         self.running: set[TaskInstanceKey] = set()
         self.event_buffer: dict[TaskInstanceKey, EventBufferValueType] = {}
         self._task_event_logs: deque[Log] = deque()
+        self.conf = ExecutorConf(team_name)
 
         if self.parallelism <= 0:
             raise ValueError("parallelism is set to 0 or lower")
diff --git a/airflow-core/tests/unit/core/test_configuration.py b/airflow-core/tests/unit/core/test_configuration.py
index 3e518294a43..757804428a4 100644
--- a/airflow-core/tests/unit/core/test_configuration.py
+++ b/airflow-core/tests/unit/core/test_configuration.py
@@ -158,6 +158,17 @@ class TestConf:
 
         assert conf.has_option("testsection", "testkey")
 
+    def test_env_team(self):
+        with patch(
+            "os.environ",
+            {
+                "AIRFLOW__CELERY__RESULT_BACKEND": "FOO",
+                "AIRFLOW__UNIT_TEST_TEAM___CELERY__RESULT_BACKEND": "BAR",
+            },
+        ):
+            assert conf.get("celery", "result_backend") == "FOO"
+            assert conf.get("celery", "result_backend", team_name="unit_test_team") == "BAR"
+
     @conf_vars({("core", "percent"): "with%%inside"})
     def test_conf_as_dict(self):
         cfg_dict = conf.as_dict()
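The net effect for executor authors: BaseExecutor now exposes self.conf, already bound to the executor's team, so configuration calls keep their existing shape. A minimal sketch of a consumer follows; the executor class and its section/key names are hypothetical, not part of the commit.

    from airflow.executors.base_executor import BaseExecutor

    class DemoExecutor(BaseExecutor):
        """Hypothetical executor, for illustration only."""

        def start(self):
            # self.conf is the ExecutorConf bound in BaseExecutor.__init__, so this
            # single call reads AIRFLOW__{TEAM}___DEMO_EXECUTOR__POLL_INTERVAL for a
            # team-scoped executor and AIRFLOW__DEMO_EXECUTOR__POLL_INTERVAL otherwise.
            self.poll_interval = int(self.conf.get("demo_executor", "poll_interval", fallback="5"))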
diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
index a7fe158f14c..396e998af61 100644
--- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
+++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor.py
@@ -32,7 +32,6 @@ from typing import TYPE_CHECKING
 
 from botocore.exceptions import ClientError, NoCredentialsError
 
-from airflow.configuration import conf
 from airflow.exceptions import AirflowException
 from airflow.executors.base_executor import BaseExecutor
 from airflow.providers.amazon.aws.executors.ecs.boto_schema import BotoDescribeTasksSchema, BotoRunTaskSchema
@@ -98,13 +97,6 @@ class AwsEcsExecutor(BaseExecutor):
     Airflow TaskInstance's executor_config.
     """
 
-    # Maximum number of retries to run an ECS task.
-    MAX_RUN_TASK_ATTEMPTS = conf.get(
-        CONFIG_GROUP_NAME,
-        AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
-        fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
-    )
-
     # AWS limits the maximum number of ARNs in the describe_tasks function.
     DESCRIBE_TASKS_BATCH_SIZE = 99
 
@@ -118,8 +110,18 @@ class AwsEcsExecutor(BaseExecutor):
         self.active_workers: EcsTaskCollection = EcsTaskCollection()
         self.pending_tasks: deque = deque()
 
-        self.cluster = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
-        self.container_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
+        # Check if self has the ExecutorConf set on the self.conf attribute, and if not, set it to the global
+        # configuration object. This allows the changes to be backwards compatible with older versions of
+        # Airflow.
+        # Can be removed when minimum supported provider version is equal to the version of core airflow
+        # which introduces multi-team configuration.
+        if not hasattr(self, "conf"):
+            from airflow.configuration import conf
+
+            self.conf = conf
+
+        self.cluster = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER)
+        self.container_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME)
         self.attempts_since_last_successful_connection = 0
 
         self.load_ecs_connection(check_connection=False)
@@ -127,6 +129,13 @@ class AwsEcsExecutor(BaseExecutor):
 
         self.run_task_kwargs = self._load_run_kwargs()
 
+        # Maximum number of retries to run an ECS task.
+        self.max_run_task_attempts = self.conf.get(
+            CONFIG_GROUP_NAME,
+            AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS,
+            fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.MAX_RUN_TASK_ATTEMPTS],
+        )
+
     def queue_workload(self, workload: workloads.All, session: Session | None) -> None:
         from airflow.executors import workloads
 
@@ -154,7 +163,7 @@ class AwsEcsExecutor(BaseExecutor):
 
     def start(self):
         """Call this when the Executor is run for the first time by the scheduler."""
-        check_health = conf.getboolean(
+        check_health = self.conf.getboolean(
             CONFIG_GROUP_NAME, AllEcsConfigKeys.CHECK_HEALTH_ON_STARTUP, fallback=False
         )
 
@@ -218,12 +227,12 @@ class AwsEcsExecutor(BaseExecutor):
 
     def load_ecs_connection(self, check_connection: bool = True):
         self.log.info("Loading Connection information")
-        aws_conn_id = conf.get(
+        aws_conn_id = self.conf.get(
             CONFIG_GROUP_NAME,
             AllEcsConfigKeys.AWS_CONN_ID,
             fallback=CONFIG_DEFAULTS[AllEcsConfigKeys.AWS_CONN_ID],
         )
-        region_name = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
+        region_name = self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.REGION_NAME, fallback=None)
         self.ecs = EcsHook(aws_conn_id=aws_conn_id, region_name=region_name).conn
         self.attempts_since_last_successful_connection += 1
         self.last_connection_reload = timezone.utcnow()
@@ -340,13 +349,13 @@ class AwsEcsExecutor(BaseExecutor):
             queue = task_info.queue
             exec_info = task_info.config
             failure_count = self.active_workers.failure_count_by_key(task_key)
-            if int(failure_count) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
+            if int(failure_count) < int(self.max_run_task_attempts):
                 self.log.warning(
                     "Airflow task %s failed due to %s. Failure %s out of %s occurred on %s. Rescheduling.",
                     task_key,
                     reason,
                     failure_count,
-                    self.__class__.MAX_RUN_TASK_ATTEMPTS,
+                    self.max_run_task_attempts,
                     task_arn,
                 )
                 self.pending_tasks.append(
@@ -416,8 +425,8 @@ class AwsEcsExecutor(BaseExecutor):
                 failure_reasons.extend([f["reason"] for f in run_task_response["failures"]])
 
             if failure_reasons:
-                # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS
-                if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS):
+                # Make sure the number of attempts does not exceed max_run_task_attempts
+                if int(attempt_number) < int(self.max_run_task_attempts):
                     ecs_task.attempt_number += 1
                     ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay(
                         attempt_number
@@ -545,7 +554,7 @@ class AwsEcsExecutor(BaseExecutor):
     def _load_run_kwargs(self) -> dict:
         from airflow.providers.amazon.aws.executors.ecs.ecs_executor_config import build_task_kwargs
 
-        ecs_executor_run_task_kwargs = build_task_kwargs()
+        ecs_executor_run_task_kwargs = build_task_kwargs(self.conf)
 
         try:
             self.get_container(ecs_executor_run_task_kwargs["overrides"]["containerOverrides"])["command"]
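To make the shim concrete for this executor: assuming the provider's CONFIG_GROUP_NAME is "aws_ecs_executor" (its value in the provider source at the time of writing), the self.conf.get calls above resolve like this against the new core. A sketch with illustrative values, not part of the commit:

    import os

    os.environ["AIRFLOW__AWS_ECS_EXECUTOR__CLUSTER"] = "shared-cluster"
    os.environ["AIRFLOW__TEAM_A___AWS_ECS_EXECUTOR__CLUSTER"] = "team-a-cluster"

    from airflow.configuration import conf

    # What self.conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.CLUSTER) performs under
    # the hood for an executor constructed with team_name="team_a" vs. no team:
    assert conf.get("aws_ecs_executor", "cluster", team_name="team_a") == "team-a-cluster"
    assert conf.get("aws_ecs_executor", "cluster") == "shared-cluster"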
Rescheduling.", task_key, reason, failure_count, - self.__class__.MAX_RUN_TASK_ATTEMPTS, + self.max_run_task_attempts, task_arn, ) self.pending_tasks.append( @@ -416,8 +425,8 @@ class AwsEcsExecutor(BaseExecutor): failure_reasons.extend([f["reason"] for f in run_task_response["failures"]]) if failure_reasons: - # Make sure the number of attempts does not exceed MAX_RUN_TASK_ATTEMPTS - if int(attempt_number) < int(self.__class__.MAX_RUN_TASK_ATTEMPTS): + # Make sure the number of attempts does not exceed max_run_task_attempts + if int(attempt_number) < int(self.max_run_task_attempts): ecs_task.attempt_number += 1 ecs_task.next_attempt_time = timezone.utcnow() + calculate_next_attempt_delay( attempt_number @@ -545,7 +554,7 @@ class AwsEcsExecutor(BaseExecutor): def _load_run_kwargs(self) -> dict: from airflow.providers.amazon.aws.executors.ecs.ecs_executor_config import build_task_kwargs - ecs_executor_run_task_kwargs = build_task_kwargs() + ecs_executor_run_task_kwargs = build_task_kwargs(self.conf) try: self.get_container(ecs_executor_run_task_kwargs["overrides"]["containerOverrides"])["command"] diff --git a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py index bcbbbfc9e8c..f7753903a2f 100644 --- a/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py +++ b/providers/amazon/src/airflow/providers/amazon/aws/executors/ecs/ecs_executor_config.py @@ -32,7 +32,6 @@ from __future__ import annotations import json from json import JSONDecodeError -from airflow.configuration import conf from airflow.providers.amazon.aws.executors.ecs.utils import ( CONFIG_GROUP_NAME, ECS_LAUNCH_TYPE_EC2, @@ -46,23 +45,27 @@ from airflow.providers.amazon.aws.hooks.ecs import EcsHook from airflow.utils.helpers import prune_dict -def _fetch_templated_kwargs() -> dict[str, str]: - run_task_kwargs_value = conf.get(CONFIG_GROUP_NAME, AllEcsConfigKeys.RUN_TASK_KWARGS, fallback=dict()) +def _fetch_templated_kwargs(conf) -> dict[str, str]: + run_task_kwargs_value = conf.get( + CONFIG_GROUP_NAME, + AllEcsConfigKeys.RUN_TASK_KWARGS, + fallback=dict(), + ) return json.loads(str(run_task_kwargs_value)) -def _fetch_config_values() -> dict[str, str]: +def _fetch_config_values(conf) -> dict[str, str]: return prune_dict( {key: conf.get(CONFIG_GROUP_NAME, key, fallback=None) for key in RunTaskKwargsConfigKeys()} ) -def build_task_kwargs() -> dict: +def build_task_kwargs(conf) -> dict: all_config_keys = AllEcsConfigKeys() # This will put some kwargs at the root of the dictionary that do NOT belong there. However, # the code below expects them to be there and will rearrange them as necessary. 
diff --git a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
index 6d0b0055906..8f5fe85bf10 100644
--- a/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
+++ b/providers/amazon/tests/unit/amazon/aws/executors/ecs/test_ecs_executor.py
@@ -25,7 +25,7 @@ import time
 from collections.abc import Callable
 from functools import partial
 from unittest import mock
-from unittest.mock import MagicMock
+from unittest.mock import MagicMock, patch
 
 import pytest
 import yaml
@@ -33,7 +33,9 @@ from botocore.exceptions import ClientError
 from inflection import camelize
 from semver import VersionInfo
 
+from airflow.configuration import conf
 from airflow.exceptions import AirflowException
+from airflow.executors import base_executor
 from airflow.executors.base_executor import BaseExecutor
 from airflow.models import TaskInstance
 from airflow.models.taskinstancekey import TaskInstanceKey
@@ -538,7 +540,7 @@ class TestAwsEcsExecutor:
         mock_executor.execute_async(mock_airflow_key, mock_cmd)
 
         # No matter what, don't schedule until run_task becomes successful.
-        for _ in range(int(mock_executor.MAX_RUN_TASK_ATTEMPTS) * 2):
+        for _ in range(int(mock_executor.max_run_task_attempts) * 2):
             mock_executor.attempt_task_runs()
             # Task is not stored in active workers.
             assert len(mock_executor.active_workers) == 0
@@ -555,7 +557,7 @@ class TestAwsEcsExecutor:
         mock_executor.execute_async(mock_airflow_key, mock_cmd)
 
         # No matter what, don't schedule until run_task becomes successful.
-        for _ in range(int(mock_executor.MAX_RUN_TASK_ATTEMPTS) * 2):
+        for _ in range(int(mock_executor.max_run_task_attempts) * 2):
             mock_executor.attempt_task_runs()
             # Task is not stored in active workers.
             assert len(mock_executor.active_workers) == 0
@@ -567,7 +569,7 @@ class TestAwsEcsExecutor:
 
         The executor should attempt each task exactly once per sync() iteration. It
         should preserve the order of tasks, and attempt each task up to
-        `MAX_RUN_TASK_ATTEMPTS` times before dropping the task.
+        `max_run_task_attempts` times before dropping the task.
         """
         airflow_keys = [
             TaskInstanceKey("a", "task_a", "c", 1, -1),
@@ -627,7 +629,7 @@ class TestAwsEcsExecutor:
 
         The executor should attempt each task exactly once per sync() iteration. It
         should preserve the order of tasks, and attempt each task up to
-        `MAX_RUN_TASK_ATTEMPTS` times before dropping the task. If a task succeeds, the task
+        `max_run_task_attempts` times before dropping the task. If a task succeeds, the task
         should be removed from pending_jobs and into active_workers.
         """
         airflow_keys = [
@@ -705,7 +707,7 @@ class TestAwsEcsExecutor:
         """
         Test API failure retries.
         """
-        AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "2"
+        mock_executor.max_run_task_attempts = "2"
         airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"]
         airflow_commands = [_generate_mock_cmd(), _generate_mock_cmd()]
""" - AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "2" + mock_executor.max_run_task_attempts = "2" airflow_keys = ["TaskInstanceKey1", "TaskInstanceKey2"] airflow_commands = [_generate_mock_cmd(), _generate_mock_cmd()] @@ -834,7 +836,7 @@ class TestAwsEcsExecutor: @mock.patch.object(BaseExecutor, "success") def test_failed_sync(self, success_mock, fail_mock, mock_executor): """Test success and failure states.""" - AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1" + mock_executor.max_run_task_attempts = "1" self._mock_sync(mock_executor, State.FAILED) mock_executor.sync() @@ -850,7 +852,7 @@ class TestAwsEcsExecutor: @mock.patch.object(BaseExecutor, "fail") def test_removed_sync(self, fail_mock, success_mock, mock_executor): """A removed task will be treated as a failed task.""" - AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1" + mock_executor.max_run_task_attempts = "1" self._mock_sync(mock_executor, expected_state=State.REMOVED, set_task_state=State.REMOVED) mock_executor.sync_running_tasks() @@ -868,7 +870,7 @@ class TestAwsEcsExecutor: self, _, success_mock, fail_mock, mock_airflow_key, mock_executor, mock_cmd ): """Test that failure_count/attempt_number is cumulative for pending tasks and active workers.""" - AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "5" + mock_executor.max_run_task_attempts = "5" mock_executor.ecs.run_task.return_value = { "tasks": [], "failures": [ @@ -980,8 +982,8 @@ class TestAwsEcsExecutor: assert len(mock_executor.active_workers.get_all_arns()) == 1 task_key = mock_executor.active_workers.arn_to_key[ARN1] - # Call Sync 2 times with failures. The task can only fail MAX_RUN_TASK_ATTEMPTS times. - for check_count in range(1, int(AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS)): + # Call Sync 2 times with failures. The task can only fail max_run_task_attempts times. + for check_count in range(1, int(mock_executor.max_run_task_attempts)): mock_executor.sync_running_tasks() assert mock_executor.ecs.describe_tasks.call_count == check_count @@ -1215,7 +1217,7 @@ class TestAwsEcsExecutor: mock_success_function.assert_called_once() def test_update_running_tasks_failed(self, mock_executor, caplog): - AwsEcsExecutor.MAX_RUN_TASK_ATTEMPTS = "1" + mock_executor.max_run_task_attempts = "1" caplog.set_level(logging.WARNING) self._add_mock_task(mock_executor, ARN1) test_response_task_json = { @@ -1343,14 +1345,81 @@ class TestEcsExecutorConfig: } with conf_vars(conf_overrides): with pytest.raises(ValueError) as raised: - ecs_executor_config.build_task_kwargs() + ecs_executor_config.build_task_kwargs(conf) assert raised.match("At least one subnet is required to run a task.") + # TODO: When merged this needs updating to the actually supported version + @pytest.mark.skipif( + not hasattr(base_executor, "ExecutorConf"), + reason="Test requires a version of airflow which includes updates to support multi team", + ) + def test_team_config(self): + # Team name to be used throughout + team_name = "team_a" + # Patch environment to include two sets of configs for the ECS executor. One that is related to a + # team and one that is not. The we will create two executors (one with a team and one without) and + # ensure the correct configs are used. 
+    # TODO: When merged this needs updating to the actually supported version
+    @pytest.mark.skipif(
+        not hasattr(base_executor, "ExecutorConf"),
+        reason="Test requires a version of airflow which includes updates to support multi team",
+    )
+    def test_team_config(self):
+        # Team name to be used throughout
+        team_name = "team_a"
+        # Patch environment to include two sets of configs for the ECS executor. One that is related to a
+        # team and one that is not. Then we will create two executors (one with a team and one without) and
+        # ensure the correct configs are used.
+        config_overrides = [
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}", "some_cluster"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}", "container_name"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.TASK_DEFINITION}", "some_task_def"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}", "FARGATE"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.PLATFORM_VERSION}", "LATEST"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.ASSIGN_PUBLIC_IP}", "False"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}", "sg1,sg2"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}", "sub1,sub2"),
+            (f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}", "us-west-1"),
+            # team Config
+            (f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CLUSTER}", "team_a_cluster"),
+            (
+                f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.CONTAINER_NAME}",
+                "team_a_container",
+            ),
+            (
+                f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.TASK_DEFINITION}",
+                "team_a_task_def",
+            ),
+            (f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.LAUNCH_TYPE}", "EC2"),
+            (
+                f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SECURITY_GROUPS}",
+                "team_a_sg1,team_a_sg2",
+            ),
+            (
+                f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.SUBNETS}",
+                "team_a_sub1,team_a_sub2",
+            ),
+            (f"AIRFLOW__{team_name}___{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.REGION_NAME}", "us-west-2"),
+        ]
+        with patch("os.environ", {key.upper(): value for key, value in config_overrides}):
+            team_executor = AwsEcsExecutor(team_name=team_name)
+            task_kwargs = ecs_executor_config.build_task_kwargs(team_executor.conf)
+
+            assert task_kwargs["cluster"] == "team_a_cluster"
+            assert task_kwargs["overrides"]["containerOverrides"][0]["name"] == "team_a_container"
+            assert task_kwargs["networkConfiguration"]["awsvpcConfiguration"] == {
+                "subnets": ["team_a_sub1", "team_a_sub2"],
+                "securityGroups": ["team_a_sg1", "team_a_sg2"],
+            }
+            assert task_kwargs["launchType"] == "EC2"
+            assert task_kwargs["taskDefinition"] == "team_a_task_def"
+
+            # Now create an executor without a team and ensure the non-team configs are used.
+            non_team_executor = AwsEcsExecutor()
+            task_kwargs = ecs_executor_config.build_task_kwargs(non_team_executor.conf)
+
+            assert task_kwargs["cluster"] == "some_cluster"
+            assert task_kwargs["overrides"]["containerOverrides"][0]["name"] == "container_name"
+            assert task_kwargs["networkConfiguration"]["awsvpcConfiguration"] == {
+                "subnets": ["sub1", "sub2"],
+                "securityGroups": ["sg1", "sg2"],
+                "assignPublicIp": "DISABLED",
+            }
+            assert task_kwargs["launchType"] == "FARGATE"
+            assert task_kwargs["taskDefinition"] == "some_task_def"
+
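Translated out of the test harness, the same setup as a plain environment listing would look roughly like this (values illustrative; "aws_ecs_executor" is the provider's config section name at the time of writing):

    # Fleet-wide defaults
    AIRFLOW__AWS_ECS_EXECUTOR__CLUSTER=some_cluster
    AIRFLOW__AWS_ECS_EXECUTOR__LAUNCH_TYPE=FARGATE

    # Overrides that apply only to executors created with team_name="team_a"
    AIRFLOW__TEAM_A___AWS_ECS_EXECUTOR__CLUSTER=team_a_cluster
    AIRFLOW__TEAM_A___AWS_ECS_EXECUTOR__LAUNCH_TYPE=EC2

Note that the assertions above suggest team-scoped lookups do not fall back to the fleet-wide variable when the team-scoped one is unset (the team executor's network config carries no assignPublicIp even though the non-team one does), so a team may need its full set of keys spelled out.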
     @conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.CONTAINER_NAME): "container-name"})
     def test_config_defaults_are_applied(self, assign_subnets):
         from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-        task_kwargs = _recursive_flatten_dict(ecs_executor_config.build_task_kwargs())
+        task_kwargs = _recursive_flatten_dict(ecs_executor_config.build_task_kwargs(conf))
         found_keys = {convert_camel_to_snake(key): key for key in task_kwargs.keys()}
 
         for expected_key, expected_value in CONFIG_DEFAULTS.items():
@@ -1388,12 +1457,12 @@ class TestEcsExecutorConfig:
         monkeypatch.delenv(run_task_kwargs_env_key, raising=False)
         from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == default_version
 
         # Provide a new value explicitly and assert that it is applied over the default.
         monkeypatch.setenv(platform_version_env_key, first_explicit_version)
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == first_explicit_version
 
         # Provide a value via template and assert that it is applied over the explicit value.
         monkeypatch.setenv(
             run_task_kwargs_env_key,
             json.dumps({AllEcsConfigKeys.PLATFORM_VERSION: templated_version}),
         )
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == templated_version
 
         # Provide a new value explicitly and assert it is not applied over the templated values.
         monkeypatch.setenv(platform_version_env_key, second_explicit_version)
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         assert task_kwargs["platformVersion"] == templated_version
 
     @mock.patch.object(EcsHook, "conn")
@@ -1428,7 +1497,7 @@ class TestEcsExecutorConfig:
             f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper(),
             json.dumps(provided_run_task_kwargs),
         )
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
 
         assert task_kwargs["platformVersion"] == templated_version
         assert task_kwargs["cluster"] == templated_cluster
@@ -1445,7 +1514,7 @@ class TestEcsExecutorConfig:
         run_task_kwargs_env_key = f"AIRFLOW__{CONFIG_GROUP_NAME}__{AllEcsConfigKeys.RUN_TASK_KWARGS}".upper()
         monkeypatch.setenv(run_task_kwargs_env_key, json.dumps(provided_run_task_kwargs))
-        task_kwargs = ecs_executor_config.build_task_kwargs()
+        task_kwargs = ecs_executor_config.build_task_kwargs(conf)
 
         # Verify that tag names are exempt from the camel-case conversion.
         assert task_kwargs["tags"] == templated_tags
 
@@ -1465,7 +1534,7 @@ class TestEcsExecutorConfig:
         for key, value in kwargs_to_test.items():
             monkeypatch.setenv(f"AIRFLOW__{CONFIG_GROUP_NAME}__{key}".upper(), value)
 
-        run_task_kwargs = ecs_executor_config.build_task_kwargs()
+        run_task_kwargs = ecs_executor_config.build_task_kwargs(conf)
         run_task_kwargs_network_config = run_task_kwargs["networkConfiguration"]["awsvpcConfiguration"]
         for key, value in kwargs_to_test.items():
             # Assert that the values are not at the root of the kwargs
@@ -1569,7 +1638,7 @@ class TestEcsExecutorConfig:
         with conf_vars(conf_overrides):
             from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-            task_kwargs = ecs_executor_config.build_task_kwargs()
+            task_kwargs = ecs_executor_config.build_task_kwargs(conf)
 
         assert "launchType" not in task_kwargs
         assert task_kwargs["capacityProviderStrategy"] == valid_capacity_provider
@@ -1583,7 +1652,7 @@ class TestEcsExecutorConfig:
         with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None}):
             from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-            task_kwargs = ecs_executor_config.build_task_kwargs()
+            task_kwargs = ecs_executor_config.build_task_kwargs(conf)
 
         assert "launchType" not in task_kwargs
         assert "capacityProviderStrategy" not in task_kwargs
         mock_conn.describe_clusters.assert_called_once()
@@ -1596,7 +1665,7 @@ class TestEcsExecutorConfig:
         with conf_vars({(CONFIG_GROUP_NAME, AllEcsConfigKeys.LAUNCH_TYPE): None}):
             from airflow.providers.amazon.aws.executors.ecs import ecs_executor_config
 
-            task_kwargs = ecs_executor_config.build_task_kwargs()
+            task_kwargs = ecs_executor_config.build_task_kwargs(conf)
 
         assert task_kwargs["launchType"] == "FARGATE"
 
     @pytest.mark.parametrize(
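As a closing illustration, the same feature-detection idea used by both the executor (hasattr(self, "conf")) and the test's skipif guard can gate any provider code path that must run against cores with and without multi-team support. A hypothetical helper, not part of the commit:

    from airflow.executors import base_executor

    # ExecutorConf only exists in cores that include this commit.
    SUPPORTS_TEAM_CONF = hasattr(base_executor, "ExecutorConf")

    def resolve_conf(executor):
        """Hypothetical helper: prefer the team-bound shim when the core provides it."""
        if SUPPORTS_TEAM_CONF and hasattr(executor, "conf"):
            return executor.conf
        from airflow.configuration import conf

        return conf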