This is an automated email from the ASF dual-hosted git repository. ash pushed a commit to branch task-sdk-first-code in repository https://gitbox.apache.org/repos/asf/airflow.git
The following commit(s) were added to refs/heads/task-sdk-first-code by this push: new 44efb0b7c1 Move over more of BaseOperator and DAG, along with their tests 44efb0b7c1 is described below commit 44efb0b7c1d383ed8b71d2e6ec25573f542abe56 Author: Ash Berlin-Taylor <a...@apache.org> AuthorDate: Mon Oct 14 22:16:00 2024 +0100 Move over more of BaseOperator and DAG, along with their tests --- airflow/task/priority_strategy.py | 4 +- dev/tests_common/test_utils/mock_operators.py | 10 - task_sdk/pyproject.toml | 19 ++ .../airflow/sdk/definitions/abstractoperator.py | 5 +- .../src/airflow/sdk/definitions/baseoperator.py | 352 ++++++++++++++------- task_sdk/src/airflow/sdk/definitions/dag.py | 114 +++++-- task_sdk/src/airflow/sdk/definitions/node.py | 2 +- task_sdk/src/airflow/sdk/types.py | 19 +- task_sdk/tests/defintions/test_baseoperator.py | 305 ++++++++++++++++++ task_sdk/tests/defintions/test_dag.py | 96 +++++- tests/models/test_baseoperator.py | 253 +-------------- tests/models/test_dag.py | 143 +-------- uv.lock | 11 + 13 files changed, 768 insertions(+), 565 deletions(-) diff --git a/airflow/task/priority_strategy.py b/airflow/task/priority_strategy.py index c22bdfa994..dcef1c865b 100644 --- a/airflow/task/priority_strategy.py +++ b/airflow/task/priority_strategy.py @@ -22,8 +22,6 @@ from __future__ import annotations from abc import ABC, abstractmethod from typing import TYPE_CHECKING, Any -from airflow.exceptions import AirflowException - if TYPE_CHECKING: from airflow.models.taskinstance import TaskInstance @@ -150,5 +148,5 @@ def validate_and_load_priority_weight_strategy( priority_weight_strategy_class = qualname(priority_weight_strategy) loaded_priority_weight_strategy = _get_registered_priority_weight_strategy(priority_weight_strategy_class) if loaded_priority_weight_strategy is None: - raise AirflowException(f"Unknown priority strategy {priority_weight_strategy_class}") + raise ValueError(f"Unknown priority strategy {priority_weight_strategy_class}") 
return loaded_priority_weight_strategy() diff --git a/dev/tests_common/test_utils/mock_operators.py b/dev/tests_common/test_utils/mock_operators.py index 0df0afec82..ecf8989f4b 100644 --- a/dev/tests_common/test_utils/mock_operators.py +++ b/dev/tests_common/test_utils/mock_operators.py @@ -16,7 +16,6 @@ # under the License. from __future__ import annotations -import warnings from typing import TYPE_CHECKING, Any, Sequence import attr @@ -200,12 +199,3 @@ class GithubLink(BaseOperatorLink): def get_link(self, operator, *, ti_key): return "https://github.com/apache/airflow" - - -class DeprecatedOperator(BaseOperator): - def __init__(self, **kwargs): - warnings.warn("This operator is deprecated.", DeprecationWarning, stacklevel=2) - super().__init__(**kwargs) - - def execute(self, context: Context): - pass diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml index f83a4b7ec2..f371c3744a 100644 --- a/task_sdk/pyproject.toml +++ b/task_sdk/pyproject.toml @@ -46,7 +46,26 @@ namespace-packages = ["src/airflow"] [tool.uv] dev-dependencies = [ + "kgb>=7.1.1", "pytest-asyncio>=0.24.0", "pytest-mock>=3.14.0", "pytest>=8.3.3", ] + +[tool.coverage.run] +branch = true +relative_files = true +source = ["src/airflow"] +include_namespace_packages = true + +[tool.coverage.report] +skip_empty = true +exclude_also = [ + "def __repr__", + "raise AssertionError", + "raise NotImplementedError", + "if __name__ == .__main__.:", + "@(abc\\.)?abstractmethod", + "@(typing(_extensions)?\\.)?overload", + "if (typing(_extensions)?\\.)?TYPE_CHECKING:", +] diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py index d37a478df3..2dd0d48804 100644 --- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py @@ -50,12 +50,9 @@ DEFAULT_RETRIES: int = 0 DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(300) MAX_RETRY_DELAY: int = 24 * 60 * 60 
-# DEFAULT_WEIGHT_RULE: WeightRule = WeightRule( -# conf.get("core", "default_task_weight_rule", fallback=WeightRule.DOWNSTREAM) -# ) # DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS -DEFAULT_WEIGHT_RULE = 0 DEFAULT_TRIGGER_RULE = "ALL_SUCCESS" # TriggerRule.ALL_SUCCESS +DEFAULT_WEIGHT_RULE = "downstream" DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py b/task_sdk/src/airflow/sdk/definitions/baseoperator.py index d52802db21..6480a89780 100644 --- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py +++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py @@ -22,8 +22,9 @@ import collections.abc import contextlib import copy import inspect -from collections.abc import Sequence +from collections.abc import Iterable, Sequence from dataclasses import dataclass +from datetime import datetime, timedelta from functools import total_ordering, wraps from types import FunctionType from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast @@ -44,6 +45,9 @@ from airflow.sdk.definitions.abstractoperator import ( ) from airflow.sdk.definitions.decorators import fixup_decorator_warning_stack from airflow.sdk.definitions.node import validate_key +from airflow.sdk.types import NOTSET, validate_instance_args +from airflow.task.priority_strategy import PriorityWeightStrategy, validate_and_load_priority_weight_strategy +from airflow.utils.types import AttributeRemoved T = TypeVar("T", bound=FunctionType) @@ -59,7 +63,7 @@ if TYPE_CHECKING: # TODO: Task-SDK AirflowException = RuntimeError -ParamsDict = object +ParamsDict = dict def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> tuple[dict, ParamsDict]: @@ -83,11 +87,11 @@ def get_merged_defaults( args, params = _get_parent_defaults(dag, task_group) if task_params: if not isinstance(task_params, collections.abc.Mapping): - raise TypeError("params must be a mapping") + raise TypeError(f"params must be a mapping, 
got {type(task_params)}") params.update(task_params) if task_default_args: if not isinstance(task_default_args, collections.abc.Mapping): - raise TypeError("default_args must be a mapping") + raise TypeError(f"default_args must be a mapping, got {type(task_params)}") args.update(task_default_args) with contextlib.suppress(KeyError): params.update(task_default_args["params"] or {}) @@ -130,7 +134,7 @@ class BaseOperatorMeta(abc.ABCMeta): from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext if args: - raise AirflowException("Use keyword arguments when initializing operators") + raise TypeError("Use keyword arguments when initializing operators") instantiated_from_mapped = kwargs.pop( "_airflow_from_mapped", @@ -155,10 +159,10 @@ class BaseOperatorMeta(abc.ABCMeta): missing_args = non_optional_args.difference(kwargs) if len(missing_args) == 1: - raise AirflowException(f"missing keyword argument {missing_args.pop()!r}") + raise TypeError(f"missing keyword argument {missing_args.pop()!r}") elif missing_args: display = ", ".join(repr(a) for a in sorted(missing_args)) - raise AirflowException(f"missing keyword arguments {display}") + raise TypeError(f"missing keyword arguments {display}") if merged_params: kwargs["params"] = merged_params @@ -169,8 +173,8 @@ class BaseOperatorMeta(abc.ABCMeta): default_args = kwargs.pop("default_args", {}) if not hasattr(self, "_BaseOperator__init_kwargs"): - self._BaseOperator__init_kwargs = {} - self._BaseOperator__from_mapped = instantiated_from_mapped + object.__setattr__(self, "_BaseOperator__init_kwargs", {}) + object.__setattr__(self, "_BaseOperator__from_mapped", instantiated_from_mapped) result = func(self, **kwargs, default_args=default_args) @@ -180,9 +184,9 @@ class BaseOperatorMeta(abc.ABCMeta): # Set upstream task defined by XComArgs passed to template fields of the operator. 
# BUT: only do this _ONCE_, not once for each class in the hierarchy if not instantiated_from_mapped and func == self.__init__.__wrapped__: # type: ignore[misc] - self.set_xcomargs_dependencies() - # Mark instance as instantiated. - self._BaseOperator__instantiated = True + self._set_xcomargs_dependencies() + # Mark instance as instantiated so that future attribute assignments update xcomarg-based deps. + object.__setattr__(self, "_BaseOperator__instantiated", True) return result @@ -213,11 +217,60 @@ class BaseOperatorMeta(abc.ABCMeta): return new_cls +# TODO: The following mapping is used to validate that the arguments passed to the BaseOperator are of the +# correct type. This is a temporary solution until we find a more sophisticated method for argument +# validation. One potential method is to use `get_type_hints` from the typing module. However, this is not +# fully compatible with future annotations for Python versions below 3.10. Once we require a minimum Python +# version that supports `get_type_hints` effectively or find a better approach, we can replace this +# manual type-checking method.
+BASEOPERATOR_ARGS_EXPECTED_TYPES = { + "task_id": str, + "email": (str, Sequence), + "email_on_retry": bool, + "email_on_failure": bool, + "retries": int, + "retry_exponential_backoff": bool, + "depends_on_past": bool, + "ignore_first_depends_on_past": bool, + "wait_for_past_depends_before_skipping": bool, + "wait_for_downstream": bool, + "priority_weight": int, + "queue": str, + "pool": str, + "pool_slots": int, + "trigger_rule": str, + "run_as_user": str, + "task_concurrency": int, + "map_index_template": str, + "max_active_tis_per_dag": int, + "max_active_tis_per_dagrun": int, + "executor": str, + "do_xcom_push": bool, + "multiple_outputs": bool, + "doc": str, + "doc_md": str, + "doc_json": str, + "doc_yaml": str, + "doc_rst": str, + "task_display_name": str, + "logger_name": str, + "allow_nested_operators": bool, + "start_date": datetime, + "end_date": datetime, +} + + +# Note: BaseOperator is defined as a dataclass, and not an attrs class as we do too much metaprogramming in +# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs too much to make it worth it. +# +# To the future reader: if you want to try and make this a "normal" attrs class, go ahead and attempt it. If you +# get nowhere, leave a record here for the next poor soul of what problems you ran into. +# +# @ashb, 2024/10/14 +# - "Can't combine custom __setattr__ with on_setattr hooks" +# - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs subclasses @total_ordering -@dataclass( - init=False, - repr=False, -) +@dataclass(repr=False, kw_only=True) class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): r""" Abstract base class for all operators. @@ -433,7 +486,59 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): hello_world_task.execute(context) """ - # Implementing Operator.
+ task_id: str + owner: str = DEFAULT_OWNER + email: str | Sequence[str] | None = None + retries: int | None = DEFAULT_RETRIES + retry_delay: timedelta | float = DEFAULT_RETRY_DELAY + retry_exponential_backoff: bool = False + max_retry_delay: timedelta | float | None = None + start_date: datetime | None = None + end_date: datetime | None = None + depends_on_past: bool = False + ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST + wait_for_past_depends_before_skipping: bool = DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING + wait_for_downstream: bool = False + dag: DAG | None = None + params: MutableMapping | None = None + default_args: dict | None = None + priority_weight: int = DEFAULT_PRIORITY_WEIGHT + # TODO: + weight_rule: PriorityWeightStrategy | str = DEFAULT_WEIGHT_RULE + queue: str = DEFAULT_QUEUE + pool: str = "default" + pool_slots: int = DEFAULT_POOL_SLOTS + execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT + # on_execute_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # on_failure_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # on_success_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # on_retry_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # on_skipped_callback: None | TaskStateChangeCallback | list[TaskStateChangeCallback] = None + # pre_execute: TaskPreExecuteHook | None = None + # post_execute: TaskPostExecuteHook | None = None + trigger_rule: str = DEFAULT_TRIGGER_RULE + resources: dict[str, Any] | None = None + run_as_user: str | None = None + task_concurrency: int | None = None + map_index_template: str | None = None + max_active_tis_per_dag: int | None = None + max_active_tis_per_dagrun: int | None = None + executor: str | None = None + executor_config: dict | None = None + do_xcom_push: bool = True + multiple_outputs: bool = False + inlets: Any | None = None + 
outlets: Any | None = None + task_group: TaskGroup | None = None + doc: str | None = None + doc_md: str | None = None + doc_json: str | None = None + doc_yaml: str | None = None + doc_rst: str | None = None + _task_display_name: str + logger_name: str | None = None + allow_nested_operators: bool = True + template_fields: ClassVar[Sequence[str]] = () template_ext: ClassVar[Sequence[str]] = () @@ -443,12 +548,10 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): ui_color: str = "#fff" ui_fgcolor: str = "#000" - pool: str = "" - - # TODO: Mapping + # TODO: Task-SDK Mapping # partial: Callable[..., OperatorPartial] = _PartialDescriptor() # type: ignore - _comps = { + _comps: ClassVar[set[str]] = { "task_id", "dag_id", "owner", @@ -476,29 +579,45 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): } # Defines if the operator supports lineage without manual definitions - supports_lineage = False + supports_lineage: bool = False # If True then the class constructor was called - __instantiated = False + __instantiated: bool = False # List of args as passed to `init()`, after apply_defaults() has been updated. Used to "recreate" the task # when mapping __init_kwargs: dict[str, Any] # Set to True before calling execute method - _lock_for_execution = False + _lock_for_execution: bool = False # Set to True for an operator instantiated by a mapped operator. - __from_mapped = False + __from_mapped: bool = False # TODO: # start_trigger_args: StartTriggerArgs | None = None # start_from_trigger: bool = False + def __setattr__(self: BaseOperator, key: str, value: Any): + if converter := getattr(self, f"_convert_{key}", None): + value = converter(value) + super().__setattr__(key, value) + if self.__from_mapped or self._lock_for_execution: + return # Skip any custom behavior for validation and during execute. 
+ if key in self.__init_kwargs: + self.__init_kwargs[key] = value + if self.__instantiated and key in self.template_fields: + # Resolve upstreams set by assigning an XComArg after initializing + # an operator, example: + # op = BashOperator() + # op.bash_command = "sleep 1" + self._set_xcomargs_dependency(key, value) + def __init__( self, + *, task_id: str, owner: str = DEFAULT_OWNER, - email: str | Iterable[str] | None = None, + email: str | Sequence[str] | None = None, retries: int | None = DEFAULT_RETRIES, retry_delay: timedelta | float = DEFAULT_RETRY_DELAY, retry_exponential_backoff: bool = False, @@ -513,9 +632,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): params: MutableMapping | None = None, default_args: dict | None = None, priority_weight: int = DEFAULT_PRIORITY_WEIGHT, - # TODO: - # weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE, - weight_rule: str = DEFAULT_WEIGHT_RULE, + weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE, queue: str = DEFAULT_QUEUE, pool: str | None = None, pool_slots: int = DEFAULT_POOL_SLOTS, @@ -531,7 +648,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): trigger_rule: str = DEFAULT_TRIGGER_RULE, resources: dict[str, Any] | None = None, run_as_user: str | None = None, - task_concurrency: int | None = None, map_index_template: str | None = None, max_active_tis_per_dag: int | None = None, max_active_tis_per_dagrun: int | None = None, @@ -554,25 +670,23 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): ): from airflow.sdk.definitions.contextmanager import DagContext, TaskGroupContext - self.__init_kwargs = {} + self.task_id = task_group.child_id(task_id) if task_group else task_id + if not self.__from_mapped and task_group: + task_group.add(self) + + dag = dag or DagContext.get_current() + task_group = task_group or TaskGroupContext.get_current(dag) - super().__init__() + super().__init__(dag=dag, task_group=task_group) 
kwargs.pop("_airflow_mapped_validation_only", None) if kwargs: - raise RuntimeError( + raise TypeError( f"Invalid arguments were passed to {self.__class__.__name__} (task_id: {task_id}). " + f"Invalid arguments were:\n**kwargs: {kwargs}", ) validate_key(task_id) - dag = dag or DagContext.get_current() - task_group = task_group or TaskGroupContext.get_current(dag) - - self.task_id = task_group.child_id(task_id) if task_group else task_id - if not self.__from_mapped and task_group: - task_group.add(self) - self.owner = owner self.email = email @@ -591,9 +705,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): # self._pre_execute_hook = pre_execute # self._post_execute_hook = post_execute - if start_date and not isinstance(start_date, datetime): - self.log.warning("start_date for %s isn't datetime.datetime", self) - elif start_date: + if start_date: self.start_date = timezone.convert_to_utc(start_date) if end_date: @@ -610,7 +722,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): self.run_as_user = run_as_user # TODO: # self.retries = parse_retries(retries) - self.retries = int(retries) + self.retries = retries self.queue = queue self.pool = "default" if pool is None else pool self.pool_slots = pool_slots @@ -620,15 +732,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): self.sla = sla """ - if trigger_rule == "none_failed_or_skipped": - warnings.warn( - "none_failed_or_skipped Trigger Rule is deprecated. 
" - "Please use `none_failed_min_one_success`.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - trigger_rule = TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS - # if not TriggerRule.is_valid(trigger_rule): # raise AirflowException( # f"The trigger_rule must be one of {TriggerRule.all_triggers()}," @@ -637,6 +740,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): # # self.trigger_rule: TriggerRule = TriggerRule(trigger_rule) # FailStopDagInvalidTriggerRule.check(dag=dag, trigger_rule=self.trigger_rule) + """ self.depends_on_past: bool = depends_on_past self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past @@ -645,34 +749,21 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): if wait_for_downstream: self.depends_on_past = True - self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay") + self.retry_delay = retry_delay self.retry_exponential_backoff = retry_exponential_backoff - self.max_retry_delay = ( - max_retry_delay - if max_retry_delay is None - else coerce_timedelta(max_retry_delay, key="max_retry_delay") - ) + if max_retry_delay is not None: + self.max_retry_delay = max_retry_delay + + self.resources = resources + """ # At execution_time this becomes a normal dict self.params: ParamsDict | dict = ParamsDict(params) - if priority_weight is not None and not isinstance(priority_weight, int): - raise AirflowException( - f"`priority_weight` for task '{self.task_id}' only accepts integers, " - f"received '{type(priority_weight)}'." - ) + """ + self.priority_weight = priority_weight self.weight_rule = validate_and_load_priority_weight_strategy(weight_rule) - self.resources = coerce_resources(resources) - if task_concurrency and not max_active_tis_per_dag: - # TODO: Remove in Airflow 3.0 - warnings.warn( - "The 'task_concurrency' parameter is deprecated. 
Please use 'max_active_tis_per_dag'.", - RemovedInAirflow3Warning, - stacklevel=2, - ) - max_active_tis_per_dag = task_concurrency - """ self.max_active_tis_per_dag: int | None = max_active_tis_per_dag self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun self.do_xcom_push: bool = do_xcom_push @@ -684,22 +775,18 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): self.doc_yaml = doc_yaml self.doc_rst = doc_rst self.doc = doc - # Populate the display field only if provided and different from task id - self._task_display_property_value = ( - task_display_name if task_display_name and task_display_name != task_id else None - ) - if dag: - self.dag = dag + self._task_display_name = task_display_name + + self.allow_nested_operators = allow_nested_operators + self.inlets: list = [] + self.outlets: list = [] """ self._log_config_logger_name = "airflow.task.operators" self._logger_name = logger_name - self.allow_nested_operators: bool = allow_nested_operators # Lineage - self.inlets: list = [] - self.outlets: list = [] if inlets: self.inlets = ( @@ -718,6 +805,7 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): outlets, ] ) + """ if isinstance(self.template_fields, str): warnings.warn( @@ -731,11 +819,11 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): self._is_setup = False self._is_teardown = False - if SetupTeardownContext.active: - SetupTeardownContext.update_context_map(self) + # TODO: Task-SDK + # if SetupTeardownContext.active: + # SetupTeardownContext.update_context_map(self) validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES) - """ def __eq__(self, other): if type(self) is type(other): @@ -810,19 +898,6 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): return self - def __setattr__(self, key, value): - super().__setattr__(key, value) - if self.__from_mapped or self._lock_for_execution: - return # Skip any custom behavior for validation and during execute. 
- if key in self.__init_kwargs: - self.__init_kwargs[key] = value - if self.__instantiated and key in self.template_fields: - # Resolve upstreams set by assigning an XComArg after initializing - # an operator, example: - # op = BashOperator() - # op.bash_command = "sleep 1" - self.set_xcomargs_dependencies() - def add_inlets(self, inlets: Iterable[Any]): """Set inlets to this operator.""" self.inlets.extend(inlets) @@ -832,45 +907,77 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): self.outlets.extend(outlets) def get_dag(self) -> DAG | None: - return self._dag + return self.dag - @property # type: ignore[override] - def dag(self) -> DAG: # type: ignore[override] - """Returns the Operator's DAG if set, otherwise raises an error.""" - if self._dag: - return self._dag - else: - raise RuntimeError(f"Operator {self} has not been assigned to a DAG yet") - - @dag.setter - def dag(self, dag: DAG | None): + def _convert_dag(self, dag: DAG | None | AttributeRemoved) -> DAG | None | AttributeRemoved: """Operators can be assigned to one DAG, one time. Repeat assignments to that same DAG are ok.""" - from .dag import DAG + from airflow.sdk.definitions.dag import DAG if dag is None: - self._dag = None - return + return dag + + # if set to removed, then just set and exit + if self.dag.__class__ is AttributeRemoved: + return dag + # if setting to removed, then just set and exit + if dag.__class__ is AttributeRemoved: + return AttributeRemoved("_dag") # type: ignore[assignment] + if not isinstance(dag, DAG): raise TypeError(f"Expected DAG; received {dag.__class__.__name__}") - elif self.has_dag() and self.dag is not dag: + elif self.dag is not None and self.dag is not dag: raise ValueError(f"The DAG assigned to {self} can not be changed.") if self.__from_mapped: pass # Don't add to DAG -- the mapped task takes the place. 
elif dag.task_dict.get(self.task_id) is not self: dag.add_task(self) - - self._dag = dag + return dag + + @staticmethod + def _convert_retries(retries: Any) -> int | None: + if retries is None: + return 0 + elif type(retries) == int: # noqa: E721 + return retries + try: + parsed_retries = int(retries) + except (TypeError, ValueError): + raise TypeError(f"'retries' type must be int, not {type(retries).__name__}") + return parsed_retries + + @staticmethod + def _convert_timedelta(value: float | timedelta) -> timedelta: + if isinstance(value, timedelta): + return value + return timedelta(seconds=value) + + _convert_retry_delay = _convert_timedelta + _convert_max_retry_delay = _convert_timedelta + + @staticmethod + def _convert_resources(resources: dict[str, Any] | None) -> Resources | None: + if resources is None: + return None + return Resources(**resources) @property def task_display_name(self) -> str: - return self._task_display_property_value or self.task_id + return self._task_display_name or self.task_id def has_dag(self): """Return True if the Operator has been assigned to a DAG.""" - return self._dag is not None + return self.dag is not None + + def _set_xcomargs_dependencies(self) -> None: + from airflow.models.xcom_arg import XComArg + + for field in self.template_fields: + arg = getattr(self, field, NOTSET) + if arg is not NOTSET: + XComArg.apply_upstream_relationship(self, arg) - def set_xcomargs_dependencies(self) -> None: + def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None: """ Resolve upstream dependencies of a task. 
@@ -892,10 +999,9 @@ class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta): """ from airflow.models.xcom_arg import XComArg - for field in self.template_fields: - if hasattr(self, field): - arg = getattr(self, field) - XComArg.apply_upstream_relationship(self, arg) + if field not in self.template_fields: + return + XComArg.apply_upstream_relationship(self, newvalue) def on_kill(self) -> None: """ diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py b/task_sdk/src/airflow/sdk/definitions/dag.py index f80d8f7f71..5cf76da488 100644 --- a/task_sdk/src/airflow/sdk/definitions/dag.py +++ b/task_sdk/src/airflow/sdk/definitions/dag.py @@ -27,7 +27,7 @@ import sys import weakref from collections import abc from collections.abc import Collection, Iterable, Iterator -from datetime import datetime, timedelta +from datetime import datetime, timedelta, timezone from inspect import signature from re import Pattern from typing import ( @@ -51,7 +51,6 @@ from airflow import settings from airflow.assets import Asset, AssetAlias, BaseAsset from airflow.configuration import conf as airflow_conf from airflow.exceptions import ( - AirflowException, DuplicateTaskIdFound, FailStopDagInvalidTriggerRule, ParamValidationError, @@ -62,6 +61,13 @@ from airflow.sdk.definitions.abstractoperator import AbstractOperator from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.stats import Stats from airflow.timetables.base import Timetable +from airflow.timetables.interval import CronDataIntervalTimetable, DeltaDataIntervalTimetable +from airflow.timetables.simple import ( + AssetTriggeredTimetable, + ContinuousTimetable, + NullTimetable, + OnceTimetable, +) from airflow.utils.dag_cycle_tester import check_cycle from airflow.utils.decorators import fixup_decorator_warning_stack from airflow.utils.trigger_rule import TriggerRule @@ -70,15 +76,17 @@ from airflow.utils.types import NOTSET, EdgeInfoType if TYPE_CHECKING: from airflow.decorators import 
TaskDecoratorCollection from airflow.models.operator import Operator - from airflow.utils.taskgroup import TaskGroup + from airflow.sdk.definitions.taskgroup import TaskGroup log = logging.getLogger(__name__) -DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"] -ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"] - TAG_MAX_LEN = 100 +__all__ = [ + "DAG", + "dag", +] + # TODO: Task-SDK class Context: ... @@ -135,7 +143,25 @@ DAG_ARGS_EXPECTED_TYPES = { } -@attrs.define +def _create_timetable(interval: ScheduleInterval, timezone: Timezone | FixedTimezone) -> Timetable: + """Create a Timetable instance from a plain ``schedule`` value.""" + if interval is None: + return NullTimetable() + if interval == "@once": + return OnceTimetable() + if interval == "@continuous": + return ContinuousTimetable() + if isinstance(interval, (timedelta, relativedelta)): + return DeltaDataIntervalTimetable(interval) + if isinstance(interval, str): + if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"): + return CronDataIntervalTimetable(interval, timezone) + else: + return CronTriggerTimetable(interval, timezone=timezone) + raise ValueError(f"{interval!r} is not a valid schedule.") + + +@attrs.define(kw_only=True) class DAG: """ A dag (directed acyclic graph) is a collection of tasks with directional dependencies. @@ -266,11 +292,11 @@ class DAG: # below in sync. (Search for 'def dag(' in this file.) 
dag_id: str description: str | None = None - schedule: ScheduleArg = NOTSET - schedule_interval: ScheduleIntervalArg = NOTSET - timetable: Timetable | None = None start_date: datetime | None = None end_date: datetime | None = None + timezone: timezone = timezone.utc + schedule: ScheduleArg = attrs.field(default=None, on_setattr=attrs.setters.NO_OP) + timetable: Timetable = attrs.field(init=False) full_filepath: str | None = None template_searchpath: str | Iterable[str] | None = None # template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined @@ -285,7 +311,7 @@ class DAG: # on_success_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None # on_failure_callback: None | DagStateChangeCallback | list[DagStateChangeCallback] = None doc_md: str | None = None - params: abc.MutableMapping | None = None + params: abc.MutableMapping | None = attrs.field(default=None) access_control: dict | None = None is_paused_upon_creation: bool | None = None jinja_environment_kwargs: dict | None = None @@ -310,6 +336,55 @@ class DAG: return TaskGroup.create_root(dag=self) + @timetable.default + def _set_schedule(self): + schedule = self.schedule + delattr(self, "schedule") + if isinstance(schedule, Timetable): + return schedule + elif isinstance(schedule, BaseAsset): + return AssetTriggeredTimetable(schedule) + elif isinstance(schedule, Collection) and not isinstance(schedule, str): + if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule): + raise ValueError("All elements in 'schedule' should be assets or asset aliases") + return AssetTriggeredTimetable(AssetAll(*schedule)) + else: + return _create_timetable(schedule, self.timezone) + + @params.validator + def _validate_params(self, attr, val: abc.MutableMapping | None): + """ + Validate Param values when the DAG has schedule defined. + + Raise exception if there are any Params which can not be resolved by their schema definition. 
+ + This will also merge in params from default_args + """ + # TODO: Task-SDK + from airflow.models.param import ParamsDict + + val = val or {} + + # merging potentially conflicting default_args['params'] into params + if "params" in self.default_args: + val.update(self.default_args["params"]) + del self.default_args["params"] + + params = ParamsDict(val) + object.__setattr__(self, "params", params) + if not self.timetable or not self.timetable.can_be_scheduled: + return + + try: + params.validate() + except ParamValidationError as pverr: + raise ValueError( + f"DAG {self.dag_id!r} is not allowed to define a Schedule, " + "as there are required params without default values, or the default values are not valid." + ) from pverr + + # check self.params and convert them into ParamsDict + def __repr__(self): return f"<DAG: {self.dag_id}>" @@ -832,23 +907,6 @@ class DAG: """ self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = info - def validate_schedule_and_params(self): - """ - Validate Param values when the DAG has schedule defined. - - Raise exception if there are any Params which can not be resolved by their schema definition. - """ - if not self.timetable.can_be_scheduled: - return - - try: - self.params.validate() - except ParamValidationError as pverr: - raise AirflowException( - "DAG is not allowed to define a Schedule, " - "if there are any required params without default values or default values are not valid." - ) from pverr - def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]: """ Parse a given link, and verifies if it's a valid URL, or a 'mailto' link. 
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py b/task_sdk/src/airflow/sdk/definitions/node.py index 1d98028467..40236cac8b 100644 --- a/task_sdk/src/airflow/sdk/definitions/node.py +++ b/task_sdk/src/airflow/sdk/definitions/node.py @@ -108,8 +108,8 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta): edge_modifier: EdgeModifier | None = None, ) -> None: """Set relatives for the task or task list.""" - from airflow.models.baseoperator import BaseOperator from airflow.models.mappedoperator import MappedOperator + from airflow.sdk.definitions.baseoperator import BaseOperator if not isinstance(task_or_task_list, Sequence): task_or_task_list = [task_or_task_list] diff --git a/task_sdk/src/airflow/sdk/types.py b/task_sdk/src/airflow/sdk/types.py index 505ee4cb19..a412509ba5 100644 --- a/task_sdk/src/airflow/sdk/types.py +++ b/task_sdk/src/airflow/sdk/types.py @@ -17,7 +17,7 @@ from __future__ import annotations -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any class ArgNotSet: @@ -43,7 +43,24 @@ NOTSET = ArgNotSet() if TYPE_CHECKING: import logging + from airflow.sdk.definitions.node import DAGNode + Logger = logging.Logger else: class Logger: ... 
+ + +def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, Any]) -> None: + """Validate that the instance has the expected types for the arguments.""" + from airflow.sdk.definitions.taskgroup import TaskGroup + + typ = "task group" if isinstance(instance, TaskGroup) else "task" + + for arg_name, expected_arg_type in expected_arg_types.items(): + instance_arg_value = getattr(instance, arg_name, None) + if instance_arg_value is not None and not isinstance(instance_arg_value, expected_arg_type): + raise TypeError( + f"{arg_name!r} for {typ} {instance.node_id!r} expects {expected_arg_type}, got {type(instance_arg_value)} with value " + f"{instance_arg_value!r}" + ) diff --git a/task_sdk/tests/defintions/test_baseoperator.py b/task_sdk/tests/defintions/test_baseoperator.py new file mode 100644 index 0000000000..0222de3ee3 --- /dev/null +++ b/task_sdk/tests/defintions/test_baseoperator.py @@ -0,0 +1,305 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. 
+ +from __future__ import annotations + +import warnings +from datetime import UTC, datetime, timedelta + +import pytest + +from airflow.sdk.definitions.baseoperator import BaseOperator, BaseOperatorMeta +from airflow.sdk.definitions.dag import DAG +from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy + +DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=UTC) + + +# Essentially similar to airflow.models.baseoperator.BaseOperator +class FakeOperator(metaclass=BaseOperatorMeta): + def __init__(self, test_param, params=None, default_args=None): + self.test_param = test_param + + def _set_xcomargs_dependencies(self): ... + + +class FakeSubClass(FakeOperator): + def __init__(self, test_sub_param, test_param, **kwargs): + super().__init__(test_param=test_param, **kwargs) + self.test_sub_param = test_sub_param + + +class DeprecatedOperator(BaseOperator): + def __init__(self, **kwargs): + warnings.warn("This operator is deprecated.", DeprecationWarning, stacklevel=2) + super().__init__(**kwargs) + + def execute(self, context: Context): + pass + + +class MockOperator(BaseOperator): + """Operator for testing purposes.""" + + template_fields: Sequence[str] = ("arg1", "arg2") + + def __init__(self, arg1: str = "", arg2: str = "", **kwargs): + super().__init__(**kwargs) + self.arg1 = arg1 + self.arg2 = arg2 + + def execute(self, context: Context): + pass + + +class TestBaseOperator: + # Since we have a custom metaclass, lets double check the behaviour of passing args in the wrong way (args + # etc) + def test_kwargs_only(self): + with pytest.raises(TypeError, match="keyword arguments"): + BaseOperator("task_id") + + def test_missing_kwarg(self): + with pytest.raises(TypeError, match="missing keyword argument"): + FakeOperator(task_id="task_id") + + def test_missing_kwargs(self): + with pytest.raises(TypeError, match="missing keyword arguments"): + FakeSubClass(task_id="task_id") + + def test_hash(self): + """Two operators created 
equally should hash equaylly""" + # Include a "non-hashable" type too + assert hash(MockOperator(task_id="one", retries=1024 * 1024, arg1="abcef", params={"a": 1})) == hash( + MockOperator(task_id="one", retries=1024 * 1024, arg1="abcef", params={"a": 2}) + ) + + def test_expand(self): + op = FakeOperator(test_param=True) + assert op.test_param + + with pytest.raises(TypeError, match="missing keyword argument 'test_param'"): + FakeSubClass(test_sub_param=True) + + def test_default_args(self): + default_args = {"test_param": True} + op = FakeOperator(default_args=default_args) + assert op.test_param + + default_args = {"test_param": True, "test_sub_param": True} + op = FakeSubClass(default_args=default_args) + assert op.test_param + assert op.test_sub_param + + default_args = {"test_param": True} + op = FakeSubClass(default_args=default_args, test_sub_param=True) + assert op.test_param + assert op.test_sub_param + + with pytest.raises(TypeError, match="missing keyword argument 'test_sub_param'"): + FakeSubClass(default_args=default_args) + + def test_execution_timeout_type(self): + with pytest.raises( + ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'str'>" + ): + BaseOperator(task_id="test", execution_timeout="1") + + with pytest.raises( + ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'int'>" + ): + BaseOperator(task_id="test", execution_timeout=1) + + def test_incorrect_default_args(self): + default_args = {"test_param": True, "extra_param": True} + op = FakeOperator(default_args=default_args) + assert op.test_param + + default_args = {"random_params": True} + with pytest.raises(TypeError, match="missing keyword argument 'test_param'"): + FakeOperator(default_args=default_args) + + def test_incorrect_priority_weight(self): + error_msg = "'priority_weight' for task 'test_op' expects <class 'int'>, got <class 'str'>" + with pytest.raises(TypeError, match=error_msg): + 
BaseOperator(task_id="test_op", priority_weight="2") + + def test_illegal_args_forbidden(self): + """ + Tests that operators raise exceptions on illegal arguments when + illegal arguments are not allowed. + """ + msg = r"Invalid arguments were passed to BaseOperator \(task_id: test_illegal_args\)" + with pytest.raises(TypeError, match=msg): + BaseOperator( + task_id="test_illegal_args", + illegal_argument_1234="hello?", + ) + + def test_invalid_type_for_default_arg(self): + error_msg = "'max_active_tis_per_dag' for task 'test' expects <class 'int'>, got <class 'str'> with value 'not_an_int'" + with pytest.raises(TypeError, match=error_msg): + BaseOperator(task_id="test", default_args={"max_active_tis_per_dag": "not_an_int"}) + + def test_invalid_type_for_operator_arg(self): + error_msg = "'max_active_tis_per_dag' for task 'test' expects <class 'int'>, got <class 'str'> with value 'not_an_int'" + with pytest.raises(TypeError, match=error_msg): + BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int") + + def test_weight_rule_default(self): + op = BaseOperator(task_id="test_task") + assert _DownstreamPriorityWeightStrategy() == op.weight_rule + + def test_weight_rule_override(self): + op = BaseOperator(task_id="test_task", weight_rule="upstream") + assert _UpstreamPriorityWeightStrategy() == op.weight_rule + + def test_warnings_are_properly_propagated(self): + with pytest.warns(DeprecationWarning) as warnings: + DeprecatedOperator(task_id="test") + assert len(warnings) == 1 + warning = warnings[0] + # Here we check that the trace points to the place + # where the deprecated class was used + assert warning.filename == __file__ + + def test_setattr_performs_no_custom_action_at_execute_time(self, spy_agency): + from airflow.models.xcom_arg import XComArg + + op = MockOperator(task_id="test_task") + # TODO: Task-SDK + # op_copy = op.prepare_for_execution() + op_copy = op + + spy_agency.spy_on(XComArg.apply_upstream_relationship, call_original=False) + 
op_copy.execute({}) + assert XComArg.apply_upstream_relationship.called == False + + def test_upstream_is_set_when_template_field_is_xcomarg(self): + with DAG("xcomargs_test", schedule=None): + op1 = BaseOperator(task_id="op1") + op2 = MockOperator(task_id="op2", arg1=op1.output) + + assert op1.task_id in op2.upstream_task_ids + assert op2.task_id in op1.downstream_task_ids + + def test_set_xcomargs_dependencies_works_recursively(self): + with DAG("xcomargs_test", schedule=None): + op1 = BaseOperator(task_id="op1") + op2 = BaseOperator(task_id="op2") + op3 = MockOperator(task_id="op3", arg1=[op1.output, op2.output]) + op4 = MockOperator(task_id="op4", arg1={"op1": op1.output, "op2": op2.output}) + + assert op1.task_id in op3.upstream_task_ids + assert op2.task_id in op3.upstream_task_ids + assert op1.task_id in op4.upstream_task_ids + assert op2.task_id in op4.upstream_task_ids + + def test_set_xcomargs_dependencies_works_when_set_after_init(self): + with DAG(dag_id="xcomargs_test", schedule=None): + op1 = BaseOperator(task_id="op1") + op2 = MockOperator(task_id="op2") + op2.arg1 = op1.output # value is set after init + + assert op1.task_id in op2.upstream_task_ids + + def test_set_xcomargs_dependencies_error_when_outside_dag(self): + op1 = BaseOperator(task_id="op1") + with pytest.raises(ValueError): + MockOperator(task_id="op2", arg1=op1.output) + + def test_cannot_change_dag(self): + with DAG(dag_id="dag1", schedule=None): + op1 = BaseOperator(task_id="op1") + with pytest.raises(ValueError, match="can not be changed"): + op1.dag = DAG(dag_id="dag2") + + +def test_init_subclass_args(): + class InitSubclassOp(BaseOperator): + _class_arg: Any + + def __init_subclass__(cls, class_arg=None, **kwargs) -> None: + cls._class_arg = class_arg + super().__init_subclass__() + + def execute(self, context: Context): + self.context_arg = context + + class_arg = "foo" + context = {"key": "value"} + + class ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg): + pass + + task 
= ConcreteSubclassOp(task_id="op1") + # TODO: Task-SDK + # task_copy = task.prepare_for_execution() + task_copy = task + + task_copy.execute(context) + + assert task_copy._class_arg == class_arg + assert task_copy.context_arg == context + + +class CustomInt(int): + def __int__(self): + raise ValueError("Cannot cast to int") + + +@pytest.mark.parametrize( + ("retries", "expected"), + [ + pytest.param("foo", "'retries' type must be int, not str", id="string"), + pytest.param(CustomInt(10), "'retries' type must be int, not CustomInt", id="custom int"), + ], +) +def test_operator_retries_invalid(dag_maker, retries, expected): + with pytest.raises(TypeError) as ctx: + BaseOperator(task_id="test_illegal_args", retries=retries) + assert str(ctx.value) == expected + + +@pytest.mark.parametrize( + ("retries", "expected"), + [ + pytest.param(None, 0, id="None"), + pytest.param("5", 5, id="str"), + pytest.param(1, 1, id="int"), + ], +) +def test_operator_retries_conversion(retries, expected): + op = BaseOperator( + task_id="test_illegal_args", + retries=retries, + ) + assert op.retries == expected + + +def test_dag_level_retry_delay(dag_maker): + with DAG(dag_id="test_dag_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}): + task1 = BaseOperator(task_id="test_no_explicit_retry_delay") + + assert task1.retry_delay == timedelta(seconds=100) + + +def test_task_level_retry_delay(): + with DAG(dag_id="test_task_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}): + task1 = BaseOperator(task_id="test_no_explicit_retry_delay", retry_delay=200) + + assert task1.retry_delay == timedelta(seconds=200) diff --git a/task_sdk/tests/defintions/test_dag.py b/task_sdk/tests/defintions/test_dag.py index e07e3a8bfa..29249745ff 100644 --- a/task_sdk/tests/defintions/test_dag.py +++ b/task_sdk/tests/defintions/test_dag.py @@ -14,11 +14,14 @@ # KIND, either express or implied. 
See the License for the # specific language governing permissions and limitations # under the License. - from __future__ import annotations -from datetime import UTC, datetime +from datetime import UTC, datetime, timedelta + +import pytest +from airflow.exceptions import DuplicateTaskIdFound +from airflow.models.param import Param, ParamsDict from airflow.sdk.definitions.baseoperator import BaseOperator from airflow.sdk.definitions.dag import DAG @@ -76,3 +79,92 @@ class TestDag: assert op7.dag == dag assert op8.dag == dag assert op9.dag == dag2 + + def test_params_not_passed_is_empty_dict(self): + """ + Test that when 'params' is _not_ passed to a new Dag, that the params + attribute is set to an empty dictionary. + """ + dag = DAG("test-dag", schedule=None) + + assert isinstance(dag.params, ParamsDict) + assert 0 == len(dag.params) + + def test_params_passed_and_params_in_default_args_no_override(self): + """ + Test that when 'params' exists as a key passed to the default_args dict + in addition to params being passed explicitly as an argument to the + dag, that the 'params' key of the default_args dict is merged with the + dict of the params argument. + """ + params1 = {"parameter1": 1} + params2 = {"parameter2": 2} + + dag = DAG("test-dag", schedule=None, default_args={"params": params1}, params=params2) + + assert params1["parameter1"] == dag.params["parameter1"] + assert params2["parameter2"] == dag.params["parameter2"] + + def test_not_none_schedule_with_non_default_params(self): + """ + Test if there is a DAG with a schedule and have some params that don't have a default value raise a + error while DAG parsing. 
(Because we can't schedule them if there we don't know what value to use) + """ + params = {"param1": Param(type="string")} + + with pytest.raises(ValueError): + DAG("my-dag", schedule=timedelta(days=1), start_date=DEFAULT_DATE, params=params) + + def test_roots(self): + """Verify if dag.roots returns the root tasks of a DAG.""" + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = BaseOperator(task_id="t1") + op2 = BaseOperator(task_id="t2") + op3 = BaseOperator(task_id="t3") + op4 = BaseOperator(task_id="t4") + op5 = BaseOperator(task_id="t5") + [op1, op2] >> op3 >> [op4, op5] + + assert set(dag.roots) == {op1, op2} + + def test_leaves(self): + """Verify if dag.leaves returns the leaf tasks of a DAG.""" + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = BaseOperator(task_id="t1") + op2 = BaseOperator(task_id="t2") + op3 = BaseOperator(task_id="t3") + op4 = BaseOperator(task_id="t4") + op5 = BaseOperator(task_id="t5") + [op1, op2] >> op3 >> [op4, op5] + + assert set(dag.leaves) == {op4, op5} + + def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): + """Verify tasks with Duplicate task_id raises error""" + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = BaseOperator(task_id="t1") + with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): + BaseOperator(task_id="t1") + + assert dag.task_dict == {op1.task_id: op1} + + def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self): + """Verify tasks with Duplicate task_id raises error""" + dag = DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) + op1 = BaseOperator(task_id="t1", dag=dag) + with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): + BaseOperator(task_id="t1", dag=dag) + + assert dag.task_dict == {op1.task_id: op1} + + def test_duplicate_task_ids_for_same_task_is_allowed(self): + """Verify that same tasks 
with Duplicate task_id do not raise error""" + with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: + op1 = op2 = BaseOperator(task_id="t1") + op3 = BaseOperator(task_id="t3") + op1 >> op3 + op2 >> op3 + + assert op1 == op2 + assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3} + assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3} diff --git a/tests/models/test_baseoperator.py b/tests/models/test_baseoperator.py index 999529e14a..eea5dae269 100644 --- a/tests/models/test_baseoperator.py +++ b/tests/models/test_baseoperator.py @@ -22,7 +22,7 @@ import logging import uuid from collections import defaultdict from datetime import date, datetime, timedelta -from typing import TYPE_CHECKING, Any, NamedTuple +from typing import NamedTuple from unittest import mock import jinja2 @@ -32,9 +32,7 @@ from airflow.decorators import task as task_decorator from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule from airflow.lineage.entities import File from airflow.models.baseoperator import ( - BASEOPERATOR_ARGS_EXPECTED_TYPES, BaseOperator, - BaseOperatorMeta, chain, chain_linear, cross_downstream, @@ -43,7 +41,6 @@ from airflow.models.dag import DAG from airflow.models.dagrun import DagRun from airflow.models.taskinstance import TaskInstance from airflow.providers.common.sql.operators import sql -from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, _UpstreamPriorityWeightStrategy from airflow.utils.edgemodifier import Label from airflow.utils.task_group import TaskGroup from airflow.utils.template import literal @@ -51,10 +48,7 @@ from airflow.utils.trigger_rule import TriggerRule from airflow.utils.types import DagRunType from tests.models import DEFAULT_DATE -from dev.tests_common.test_utils.mock_operators import DeprecatedOperator, MockOperator - -if TYPE_CHECKING: - from airflow.utils.context import Context +from dev.tests_common.test_utils.mock_operators import MockOperator class 
ClassWithCustomAttributes: @@ -83,93 +77,12 @@ object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", ref=object1, template_fi setattr(object1, "ref", object2) -# Essentially similar to airflow.models.baseoperator.BaseOperator -class DummyClass(metaclass=BaseOperatorMeta): - def __init__(self, test_param, params=None, default_args=None): - self.test_param = test_param - - def set_xcomargs_dependencies(self): ... - - -class DummySubClass(DummyClass): - def __init__(self, test_sub_param, **kwargs): - super().__init__(**kwargs) - self.test_sub_param = test_sub_param - - class MockNamedTuple(NamedTuple): var1: str var2: str -class CustomInt(int): - def __int__(self): - raise ValueError("Cannot cast to int") - - class TestBaseOperator: - def test_expand(self): - dummy = DummyClass(test_param=True) - assert dummy.test_param - - with pytest.raises(AirflowException, match="missing keyword argument 'test_param'"): - DummySubClass(test_sub_param=True) - - def test_default_args(self): - default_args = {"test_param": True} - dummy_class = DummyClass(default_args=default_args) - assert dummy_class.test_param - - default_args = {"test_param": True, "test_sub_param": True} - dummy_subclass = DummySubClass(default_args=default_args) - assert dummy_class.test_param - assert dummy_subclass.test_sub_param - - default_args = {"test_param": True} - dummy_subclass = DummySubClass(default_args=default_args, test_sub_param=True) - assert dummy_class.test_param - assert dummy_subclass.test_sub_param - - with pytest.raises(AirflowException, match="missing keyword argument 'test_sub_param'"): - DummySubClass(default_args=default_args) - - def test_execution_timeout_type(self): - with pytest.raises( - ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'str'>" - ): - BaseOperator(task_id="test", execution_timeout="1") - - with pytest.raises( - ValueError, match="execution_timeout must be timedelta object but passed as type: <class 'int'>" - ): - 
BaseOperator(task_id="test", execution_timeout=1) - - def test_incorrect_default_args(self): - default_args = {"test_param": True, "extra_param": True} - dummy_class = DummyClass(default_args=default_args) - assert dummy_class.test_param - - default_args = {"random_params": True} - with pytest.raises(AirflowException, match="missing keyword argument 'test_param'"): - DummyClass(default_args=default_args) - - def test_incorrect_priority_weight(self): - error_msg = "`priority_weight` for task 'test_op' only accepts integers, received '<class 'str'>'." - with pytest.raises(AirflowException, match=error_msg): - BaseOperator(task_id="test_op", priority_weight="2") - - def test_illegal_args_forbidden(self): - """ - Tests that operators raise exceptions on illegal arguments when - illegal arguments are not allowed. - """ - msg = r"Invalid arguments were passed to BaseOperator \(task_id: test_illegal_args\)" - with pytest.raises(AirflowException, match=msg): - BaseOperator( - task_id="test_illegal_args", - illegal_argument_1234="hello?", - ) - def test_trigger_rule_validation(self): from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE @@ -659,15 +572,6 @@ class TestBaseOperator: task4 > [inlet, outlet, extra] assert task4.get_outlet_defs() == [inlet, outlet, extra] - def test_warnings_are_properly_propagated(self): - with pytest.warns(DeprecationWarning) as warnings: - DeprecatedOperator(task_id="test") - assert len(warnings) == 1 - warning = warnings[0] - # Here we check that the trace points to the place - # where the deprecated class was used - assert warning.filename == __file__ - def test_pre_execute_hook(self): hook = mock.MagicMock() @@ -694,47 +598,6 @@ class TestBaseOperator: assert op_no_dag.start_date.tzinfo assert op_no_dag.end_date.tzinfo - def test_setattr_performs_no_custom_action_at_execute_time(self): - op = MockOperator(task_id="test_task") - op_copy = op.prepare_for_execution() - - with 
mock.patch("airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies") as method_mock: - op_copy.execute({}) - assert method_mock.call_count == 0 - - def test_upstream_is_set_when_template_field_is_xcomarg(self): - with DAG("xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}): - op1 = BaseOperator(task_id="op1") - op2 = MockOperator(task_id="op2", arg1=op1.output) - - assert op1 in op2.upstream_list - assert op2 in op1.downstream_list - - def test_set_xcomargs_dependencies_works_recursively(self): - with DAG("xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}): - op1 = BaseOperator(task_id="op1") - op2 = BaseOperator(task_id="op2") - op3 = MockOperator(task_id="op3", arg1=[op1.output, op2.output]) - op4 = MockOperator(task_id="op4", arg1={"op1": op1.output, "op2": op2.output}) - - assert op1 in op3.upstream_list - assert op2 in op3.upstream_list - assert op1 in op4.upstream_list - assert op2 in op4.upstream_list - - def test_set_xcomargs_dependencies_works_when_set_after_init(self): - with DAG(dag_id="xcomargs_test", schedule=None, default_args={"start_date": datetime.today()}): - op1 = BaseOperator(task_id="op1") - op2 = MockOperator(task_id="op2") - op2.arg1 = op1.output # value is set after init - - assert op1 in op2.upstream_list - - def test_set_xcomargs_dependencies_error_when_outside_dag(self): - op1 = BaseOperator(task_id="op1") - with pytest.raises(AirflowException): - MockOperator(task_id="op2", arg1=op1.output) - def test_invalid_trigger_rule(self): with pytest.raises( AirflowException, @@ -745,14 +608,6 @@ class TestBaseOperator: ): BaseOperator(task_id="op1", trigger_rule="some_rule") - def test_weight_rule_default(self): - op = BaseOperator(task_id="test_task") - assert _DownstreamPriorityWeightStrategy() == op.weight_rule - - def test_weight_rule_override(self): - op = BaseOperator(task_id="test_task", weight_rule="upstream") - assert _UpstreamPriorityWeightStrategy() == op.weight_rule - # 
ensure the default logging config is used for this test, no matter what ran before @pytest.mark.usefixtures("reset_logging_config") def test_logging_propogated_by_default(self, caplog): @@ -763,92 +618,6 @@ class TestBaseOperator: # leaking a lot of state) assert caplog.messages == ["test"] - def test_invalid_type_for_default_arg(self): - error_msg = "'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int, expected type is <class 'int'>" - with pytest.raises(TypeError, match=error_msg): - BaseOperator(task_id="test", default_args={"max_active_tis_per_dag": "not_an_int"}) - - def test_invalid_type_for_operator_arg(self): - error_msg = "'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int, expected type is <class 'int'>" - with pytest.raises(TypeError, match=error_msg): - BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int") - - @mock.patch("airflow.models.baseoperator.validate_instance_args") - def test_baseoperator_init_validates_arg_types(self, mock_validate_instance_args): - operator = BaseOperator(task_id="test") - - mock_validate_instance_args.assert_called_once_with(operator, BASEOPERATOR_ARGS_EXPECTED_TYPES) - - -def test_init_subclass_args(): - class InitSubclassOp(BaseOperator): - _class_arg: Any - - def __init_subclass__(cls, class_arg=None, **kwargs) -> None: - cls._class_arg = class_arg - super().__init_subclass__() - - def execute(self, context: Context): - self.context_arg = context - - class_arg = "foo" - context = {"key": "value"} - - class ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg): - pass - - task = ConcreteSubclassOp(task_id="op1") - task_copy = task.prepare_for_execution() - - task_copy.execute(context) - - assert task_copy._class_arg == class_arg - assert task_copy.context_arg == context - - -@pytest.mark.db_test -@pytest.mark.parametrize( - ("retries", "expected"), - [ - pytest.param("foo", "'retries' type must be int, not str", id="string"), - 
pytest.param(CustomInt(10), "'retries' type must be int, not CustomInt", id="custom int"), - ], -) -def test_operator_retries_invalid(dag_maker, retries, expected): - with pytest.raises(AirflowException) as ctx: - with dag_maker(): - BaseOperator(task_id="test_illegal_args", retries=retries) - assert str(ctx.value) == expected - - -@pytest.mark.db_test -@pytest.mark.parametrize( - ("retries", "expected"), - [ - pytest.param(None, [], id="None"), - pytest.param(5, [], id="5"), - pytest.param( - "1", - [ - ( - "airflow.models.baseoperator.BaseOperator", - logging.WARNING, - "Implicitly converting 'retries' from '1' to int", - ), - ], - id="str", - ), - ], -) -def test_operator_retries(caplog, dag_maker, retries, expected): - with caplog.at_level(logging.WARNING): - with dag_maker(): - BaseOperator( - task_id="test_illegal_args", - retries=retries, - ) - assert caplog.record_tuples == expected - @pytest.mark.db_test def test_default_retry_delay(dag_maker): @@ -858,24 +627,6 @@ def test_default_retry_delay(dag_maker): assert task1.retry_delay == timedelta(seconds=300) -@pytest.mark.db_test -def test_dag_level_retry_delay(dag_maker): - with dag_maker(dag_id="test_dag_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)}): - task1 = BaseOperator(task_id="test_no_explicit_retry_delay") - - assert task1.retry_delay == timedelta(seconds=100) - - -@pytest.mark.db_test -def test_task_level_retry_delay(dag_maker): - with dag_maker( - dag_id="test_task_level_retry_delay", default_args={"retry_delay": timedelta(seconds=100)} - ): - task1 = BaseOperator(task_id="test_no_explicit_retry_delay", retry_delay=timedelta(seconds=200)) - - assert task1.retry_delay == timedelta(seconds=200) - - def test_deepcopy(): # Test bug when copying an operator attached to a DAG with DAG("dag0", schedule=None, start_date=DEFAULT_DATE) as dag: diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py index 67dc699fc3..c0b705d641 100644 --- a/tests/models/test_dag.py +++ 
b/tests/models/test_dag.py @@ -42,7 +42,6 @@ from airflow.configuration import conf from airflow.decorators import setup, task as task_decorator, teardown from airflow.exceptions import ( AirflowException, - DuplicateTaskIdFound, ParamValidationError, UnknownExecutorException, ) @@ -65,7 +64,7 @@ from airflow.models.dag import ( get_asset_triggered_next_run_info, ) from airflow.models.dagrun import DagRun -from airflow.models.param import DagParam, Param, ParamsDict +from airflow.models.param import DagParam, Param from airflow.models.serialized_dag import SerializedDagModel from airflow.models.taskfail import TaskFail from airflow.models.taskinstance import TaskInstance as TI @@ -175,41 +174,6 @@ class TestDag: b_index = i return 0 <= a_index < b_index - def test_params_not_passed_is_empty_dict(self): - """ - Test that when 'params' is _not_ passed to a new Dag, that the params - attribute is set to an empty dictionary. - """ - dag = DAG("test-dag", schedule=None) - - assert isinstance(dag.params, ParamsDict) - assert 0 == len(dag.params) - - def test_params_passed_and_params_in_default_args_no_override(self): - """ - Test that when 'params' exists as a key passed to the default_args dict - in addition to params being passed explicitly as an argument to the - dag, that the 'params' key of the default_args dict is merged with the - dict of the params argument. 
- """ - params1 = {"parameter1": 1} - params2 = {"parameter2": 2} - - dag = DAG("test-dag", schedule=None, default_args={"params": params1}, params=params2) - - assert params1["parameter1"] == dag.params["parameter1"] - assert params2["parameter2"] == dag.params["parameter2"] - - def test_not_none_schedule_with_non_default_params(self): - """ - Test if there is a DAG with not None schedule and have some params that - don't have a default value raise a error while DAG parsing - """ - params = {"param1": Param(type="string")} - - with pytest.raises(AirflowException): - DAG("dummy-dag", schedule=timedelta(days=1), start_date=DEFAULT_DATE, params=params) - def test_dag_invalid_default_view(self): """ Test invalid `default_view` of DAG initialization @@ -238,57 +202,6 @@ class TestDag: dag = DAG(dag_id="test-default_orientation", schedule=None) assert conf.get("webserver", "dag_orientation") == dag.orientation - def test_dag_as_context_manager(self): - """ - Test DAG as a context manager. - When used as a context manager, Operators are automatically added to - the DAG (unless they specify a different DAG) - """ - dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) - dag2 = DAG("dag2", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner2"}) - - with dag: - op1 = EmptyOperator(task_id="op1") - op2 = EmptyOperator(task_id="op2", dag=dag2) - - assert op1.dag is dag - assert op1.owner == "owner1" - assert op2.dag is dag2 - assert op2.owner == "owner2" - - with dag2: - op3 = EmptyOperator(task_id="op3") - - assert op3.dag is dag2 - assert op3.owner == "owner2" - - with dag: - with dag2: - op4 = EmptyOperator(task_id="op4") - op5 = EmptyOperator(task_id="op5") - - assert op4.dag is dag2 - assert op5.dag is dag - assert op4.owner == "owner2" - assert op5.owner == "owner1" - - with DAG("creating_dag_in_cm", schedule=None, start_date=DEFAULT_DATE) as dag: - EmptyOperator(task_id="op6") - - assert dag.dag_id == 
"creating_dag_in_cm" - assert dag.tasks[0].task_id == "op6" - - with dag: - with dag: - op7 = EmptyOperator(task_id="op7") - op8 = EmptyOperator(task_id="op8") - op9 = EmptyOperator(task_id="op8") - op9.dag = dag2 - - assert op7.dag == dag - assert op8.dag == dag - assert op9.dag == dag2 - def test_dag_topological_sort_dag_without_tasks(self): dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, default_args={"owner": "owner1"}) @@ -1287,60 +1200,6 @@ class TestDag: dag = DAG("DAG", schedule=None, default_args=default_args) assert dag.timezone.name == local_tz.name - def test_roots(self): - """Verify if dag.roots returns the root tasks of a DAG.""" - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") - op2 = EmptyOperator(task_id="t2") - op3 = EmptyOperator(task_id="t3") - op4 = EmptyOperator(task_id="t4") - op5 = EmptyOperator(task_id="t5") - [op1, op2] >> op3 >> [op4, op5] - - assert set(dag.roots) == {op1, op2} - - def test_leaves(self): - """Verify if dag.leaves returns the leaf tasks of a DAG.""" - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") - op2 = EmptyOperator(task_id="t2") - op3 = EmptyOperator(task_id="t3") - op4 = EmptyOperator(task_id="t4") - op5 = EmptyOperator(task_id="t5") - [op1, op2] >> op3 >> [op4, op5] - - assert set(dag.leaves) == {op4, op5} - - def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self): - """Verify tasks with Duplicate task_id raises error""" - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = EmptyOperator(task_id="t1") - with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): - BashOperator(task_id="t1", bash_command="sleep 1") - - assert dag.task_dict == {op1.task_id: op1} - - def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self): - """Verify tasks with Duplicate task_id raises error""" - dag = 
DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) - op1 = EmptyOperator(task_id="t1", dag=dag) - with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"): - EmptyOperator(task_id="t1", dag=dag) - - assert dag.task_dict == {op1.task_id: op1} - - def test_duplicate_task_ids_for_same_task_is_allowed(self): - """Verify that same tasks with Duplicate task_id do not raise error""" - with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: - op1 = op2 = EmptyOperator(task_id="t1") - op3 = EmptyOperator(task_id="t3") - op1 >> op3 - op2 >> op3 - - assert op1 == op2 - assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3} - assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3} - def test_partial_subset_updates_all_references_while_deepcopy(self): with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag: op1 = EmptyOperator(task_id="t1") diff --git a/uv.lock b/uv.lock index 3a3c2db04b..1fecc75854 100644 --- a/uv.lock +++ b/uv.lock @@ -1373,6 +1373,7 @@ dependencies = [ [package.dev-dependencies] dev = [ + { name = "kgb" }, { name = "pytest" }, { name = "pytest-asyncio" }, { name = "pytest-mock" }, @@ -1387,6 +1388,7 @@ requires-dist = [ [package.metadata.requires-dev] dev = [ + { name = "kgb", specifier = ">=7.1.1" }, { name = "pytest", specifier = ">=8.3.3" }, { name = "pytest-asyncio", specifier = ">=0.24.0" }, { name = "pytest-mock", specifier = ">=3.14.0" }, @@ -2452,6 +2454,15 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/d1/0f/8910b19ac0670a0f80ce1008e5e751c4a57e14d2c4c13a482aa6079fa9d6/jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf", size = 18459 }, ] +[[package]] +name = "kgb" +version = "7.1.1" +source = { registry = "https://pypi.org/simple" } +sdist = { url = 
"https://files.pythonhosted.org/packages/0c/2e/2b608fa158cd87d7372b1d1c94d70b9b90e4ab5316c77f26feb1e4b6549f/kgb-7.1.1.tar.gz", hash = "sha256:74912c8761651f2063151c6c2a36ebe023393de491ec86744771a2888ab9845b", size = 61504 } +wheels = [ + { url = "https://files.pythonhosted.org/packages/80/45/ae8db25f019419b17359ca98f129c0a0d9fa40cadeaac3525b02b690e705/kgb-7.1.1-py2.py3-none-any.whl", hash = "sha256:ed535b25caa5d8151bb8700c653a73475a6d3937c75cd2b8ce93c84c97a86a6f", size = 58003 }, +] + [[package]] name = "lazy-object-proxy" version = "1.10.0"