This is an automated email from the ASF dual-hosted git repository.

ash pushed a commit to branch task-sdk-first-code
in repository https://gitbox.apache.org/repos/asf/airflow.git


The following commit(s) were added to refs/heads/task-sdk-first-code by this 
push:
     new 44efb0b7c1 Move over more of BaseOperator and DAG, along with their 
tests
44efb0b7c1 is described below

commit 44efb0b7c1d383ed8b71d2e6ec25573f542abe56
Author: Ash Berlin-Taylor <a...@apache.org>
AuthorDate: Mon Oct 14 22:16:00 2024 +0100

    Move over more of BaseOperator and DAG, along with their tests
---
 airflow/task/priority_strategy.py                  |   4 +-
 dev/tests_common/test_utils/mock_operators.py      |  10 -
 task_sdk/pyproject.toml                            |  19 ++
 .../airflow/sdk/definitions/abstractoperator.py    |   5 +-
 .../src/airflow/sdk/definitions/baseoperator.py    | 352 ++++++++++++++-------
 task_sdk/src/airflow/sdk/definitions/dag.py        | 114 +++++--
 task_sdk/src/airflow/sdk/definitions/node.py       |   2 +-
 task_sdk/src/airflow/sdk/types.py                  |  19 +-
 task_sdk/tests/defintions/test_baseoperator.py     | 305 ++++++++++++++++++
 task_sdk/tests/defintions/test_dag.py              |  96 +++++-
 tests/models/test_baseoperator.py                  | 253 +--------------
 tests/models/test_dag.py                           | 143 +--------
 uv.lock                                            |  11 +
 13 files changed, 768 insertions(+), 565 deletions(-)

diff --git a/airflow/task/priority_strategy.py 
b/airflow/task/priority_strategy.py
index c22bdfa994..dcef1c865b 100644
--- a/airflow/task/priority_strategy.py
+++ b/airflow/task/priority_strategy.py
@@ -22,8 +22,6 @@ from __future__ import annotations
 from abc import ABC, abstractmethod
 from typing import TYPE_CHECKING, Any
 
-from airflow.exceptions import AirflowException
-
 if TYPE_CHECKING:
     from airflow.models.taskinstance import TaskInstance
 
@@ -150,5 +148,5 @@ def validate_and_load_priority_weight_strategy(
         priority_weight_strategy_class = qualname(priority_weight_strategy)
     loaded_priority_weight_strategy = 
_get_registered_priority_weight_strategy(priority_weight_strategy_class)
     if loaded_priority_weight_strategy is None:
-        raise AirflowException(f"Unknown priority strategy 
{priority_weight_strategy_class}")
+        raise ValueError(f"Unknown priority strategy 
{priority_weight_strategy_class}")
     return loaded_priority_weight_strategy()
diff --git a/dev/tests_common/test_utils/mock_operators.py 
b/dev/tests_common/test_utils/mock_operators.py
index 0df0afec82..ecf8989f4b 100644
--- a/dev/tests_common/test_utils/mock_operators.py
+++ b/dev/tests_common/test_utils/mock_operators.py
@@ -16,7 +16,6 @@
 # under the License.
 from __future__ import annotations
 
-import warnings
 from typing import TYPE_CHECKING, Any, Sequence
 
 import attr
@@ -200,12 +199,3 @@ class GithubLink(BaseOperatorLink):
 
     def get_link(self, operator, *, ti_key):
         return "https://github.com/apache/airflow"
-
-
-class DeprecatedOperator(BaseOperator):
-    def __init__(self, **kwargs):
-        warnings.warn("This operator is deprecated.", DeprecationWarning, 
stacklevel=2)
-        super().__init__(**kwargs)
-
-    def execute(self, context: Context):
-        pass
diff --git a/task_sdk/pyproject.toml b/task_sdk/pyproject.toml
index f83a4b7ec2..f371c3744a 100644
--- a/task_sdk/pyproject.toml
+++ b/task_sdk/pyproject.toml
@@ -46,7 +46,26 @@ namespace-packages = ["src/airflow"]
 
 [tool.uv]
 dev-dependencies = [
+    "kgb>=7.1.1",
     "pytest-asyncio>=0.24.0",
     "pytest-mock>=3.14.0",
     "pytest>=8.3.3",
 ]
+
+[tool.coverage.run]
+branch = true
+relative_files = true
+source = ["src/airflow"]
+include_namespace_packages = true
+
+[tool.coverage.report]
+skip_empty = true
+exclude_also = [
+    "def __repr__",
+    "raise AssertionError",
+    "raise NotImplementedError",
+    "if __name__ == .__main__.:",
+    "@(abc\\.)?abstractmethod",
+    "@(typing(_extensions)?\\.)?overload",
+    "if (typing(_extensions)?\\.)?TYPE_CHECKING:",
+]
diff --git a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py 
b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
index d37a478df3..2dd0d48804 100644
--- a/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/abstractoperator.py
@@ -50,12 +50,9 @@ DEFAULT_RETRIES: int = 0
 DEFAULT_RETRY_DELAY: datetime.timedelta = datetime.timedelta(300)
 MAX_RETRY_DELAY: int = 24 * 60 * 60
 
-# DEFAULT_WEIGHT_RULE: WeightRule = WeightRule(
-#     conf.get("core", "default_task_weight_rule", 
fallback=WeightRule.DOWNSTREAM)
-# )
 # DEFAULT_TRIGGER_RULE: TriggerRule = TriggerRule.ALL_SUCCESS
-DEFAULT_WEIGHT_RULE = 0
 DEFAULT_TRIGGER_RULE = "ALL_SUCCESS"  # TriggerRule.ALL_SUCCESS
+DEFAULT_WEIGHT_RULE = "downstream"
 DEFAULT_TASK_EXECUTION_TIMEOUT: datetime.timedelta | None = None
 
 
diff --git a/task_sdk/src/airflow/sdk/definitions/baseoperator.py 
b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
index d52802db21..6480a89780 100644
--- a/task_sdk/src/airflow/sdk/definitions/baseoperator.py
+++ b/task_sdk/src/airflow/sdk/definitions/baseoperator.py
@@ -22,8 +22,9 @@ import collections.abc
 import contextlib
 import copy
 import inspect
-from collections.abc import Sequence
+from collections.abc import Iterable, Sequence
 from dataclasses import dataclass
+from datetime import datetime, timedelta
 from functools import total_ordering, wraps
 from types import FunctionType
 from typing import TYPE_CHECKING, Any, ClassVar, TypeVar, cast
@@ -44,6 +45,9 @@ from airflow.sdk.definitions.abstractoperator import (
 )
 from airflow.sdk.definitions.decorators import fixup_decorator_warning_stack
 from airflow.sdk.definitions.node import validate_key
+from airflow.sdk.types import NOTSET, validate_instance_args
+from airflow.task.priority_strategy import PriorityWeightStrategy, 
validate_and_load_priority_weight_strategy
+from airflow.utils.types import AttributeRemoved
 
 T = TypeVar("T", bound=FunctionType)
 
@@ -59,7 +63,7 @@ if TYPE_CHECKING:
 
 # TODO: Task-SDK
 AirflowException = RuntimeError
-ParamsDict = object
+ParamsDict = dict
 
 
 def _get_parent_defaults(dag: DAG | None, task_group: TaskGroup | None) -> 
tuple[dict, ParamsDict]:
@@ -83,11 +87,11 @@ def get_merged_defaults(
     args, params = _get_parent_defaults(dag, task_group)
     if task_params:
         if not isinstance(task_params, collections.abc.Mapping):
-            raise TypeError("params must be a mapping")
+            raise TypeError(f"params must be a mapping, got 
{type(task_params)}")
         params.update(task_params)
     if task_default_args:
         if not isinstance(task_default_args, collections.abc.Mapping):
-            raise TypeError("default_args must be a mapping")
+            raise TypeError(f"default_args must be a mapping, got 
{type(task_params)}")
         args.update(task_default_args)
         with contextlib.suppress(KeyError):
             params.update(task_default_args["params"] or {})
@@ -130,7 +134,7 @@ class BaseOperatorMeta(abc.ABCMeta):
             from airflow.sdk.definitions.contextmanager import DagContext, 
TaskGroupContext
 
             if args:
-                raise AirflowException("Use keyword arguments when 
initializing operators")
+                raise TypeError("Use keyword arguments when initializing 
operators")
 
             instantiated_from_mapped = kwargs.pop(
                 "_airflow_from_mapped",
@@ -155,10 +159,10 @@ class BaseOperatorMeta(abc.ABCMeta):
 
             missing_args = non_optional_args.difference(kwargs)
             if len(missing_args) == 1:
-                raise AirflowException(f"missing keyword argument 
{missing_args.pop()!r}")
+                raise TypeError(f"missing keyword argument 
{missing_args.pop()!r}")
             elif missing_args:
                 display = ", ".join(repr(a) for a in sorted(missing_args))
-                raise AirflowException(f"missing keyword arguments {display}")
+                raise TypeError(f"missing keyword arguments {display}")
 
             if merged_params:
                 kwargs["params"] = merged_params
@@ -169,8 +173,8 @@ class BaseOperatorMeta(abc.ABCMeta):
                 default_args = kwargs.pop("default_args", {})
 
             if not hasattr(self, "_BaseOperator__init_kwargs"):
-                self._BaseOperator__init_kwargs = {}
-            self._BaseOperator__from_mapped = instantiated_from_mapped
+                object.__setattr__(self, "_BaseOperator__init_kwargs", {})
+            object.__setattr__(self, "_BaseOperator__from_mapped", 
instantiated_from_mapped)
 
             result = func(self, **kwargs, default_args=default_args)
 
@@ -180,9 +184,9 @@ class BaseOperatorMeta(abc.ABCMeta):
             # Set upstream task defined by XComArgs passed to template fields 
of the operator.
             # BUT: only do this _ONCE_, not once for each class in the 
hierarchy
             if not instantiated_from_mapped and func == 
self.__init__.__wrapped__:  # type: ignore[misc]
-                self.set_xcomargs_dependencies()
-                # Mark instance as instantiated.
-                self._BaseOperator__instantiated = True
+                self._set_xcomargs_dependencies()
+                # Mark instance as instantiated so that future attr setting 
updates xcomarg-based deps.
+                object.__setattr__(self, "_BaseOperator__instantiated", True)
 
             return result
 
@@ -213,11 +217,60 @@ class BaseOperatorMeta(abc.ABCMeta):
         return new_cls
 
 
+# TODO: The following mapping is used to validate that the arguments passed to 
the BaseOperator are of the
+#  correct type. This is a temporary solution until we find a more 
sophisticated method for argument
+#  validation. One potential method is to use `get_type_hints` from the typing 
module. However, this is not
+#  fully compatible with future annotations for Python versions below 3.10. 
Once we require a minimum Python
+#  version that supports `get_type_hints` effectively or find a better 
approach, we can replace this
+#  manual type-checking method.
+BASEOPERATOR_ARGS_EXPECTED_TYPES = {
+    "task_id": str,
+    "email": (str, Sequence),
+    "email_on_retry": bool,
+    "email_on_failure": bool,
+    "retries": int,
+    "retry_exponential_backoff": bool,
+    "depends_on_past": bool,
+    "ignore_first_depends_on_past": bool,
+    "wait_for_past_depends_before_skipping": bool,
+    "wait_for_downstream": bool,
+    "priority_weight": int,
+    "queue": str,
+    "pool": str,
+    "pool_slots": int,
+    "trigger_rule": str,
+    "run_as_user": str,
+    "task_concurrency": int,
+    "map_index_template": str,
+    "max_active_tis_per_dag": int,
+    "max_active_tis_per_dagrun": int,
+    "executor": str,
+    "do_xcom_push": bool,
+    "multiple_outputs": bool,
+    "doc": str,
+    "doc_md": str,
+    "doc_json": str,
+    "doc_yaml": str,
+    "doc_rst": str,
+    "task_display_name": str,
+    "logger_name": str,
+    "allow_nested_operators": bool,
+    "start_date": datetime,
+    "end_date": datetime,
+}
+
+
+# Note: BaseOperator is defined as a dataclass, and not an attrs class as we 
do too much metaprogramming in
+# here (metaclass, custom `__setattr__` behaviour) and this fights with attrs 
too much to make it worth it.
+#
+# To future reader: if you want to try and make this a "normal" attrs class, 
go ahead and attempt it. If you
+# get nowhere, leave your record here for the next poor soul, along with the problems 
you ran into.
+#
+# @ashb, 2024/10/14
+# - "Can't combine custom __setattr__ with on_setattr hooks"
+# - Setting class-wide `define(on_setattr=...)` isn't called for non-attrs 
subclasses
 @total_ordering
-@dataclass(
-    init=False,
-    repr=False,
-)
+@dataclass(repr=False, kw_only=True)
 class BaseOperator(AbstractOperator, metaclass=BaseOperatorMeta):
     r"""
     Abstract base class for all operators.
@@ -433,7 +486,59 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
                 hello_world_task.execute(context)
     """
 
-    # Implementing Operator.
+    task_id: str
+    owner: str = DEFAULT_OWNER
+    email: str | Sequence[str] | None = None
+    retries: int | None = DEFAULT_RETRIES
+    retry_delay: timedelta | float = DEFAULT_RETRY_DELAY
+    retry_exponential_backoff: bool = False
+    max_retry_delay: timedelta | float | None = None
+    start_date: datetime | None = None
+    end_date: datetime | None = None
+    depends_on_past: bool = False
+    ignore_first_depends_on_past: bool = DEFAULT_IGNORE_FIRST_DEPENDS_ON_PAST
+    wait_for_past_depends_before_skipping: bool = 
DEFAULT_WAIT_FOR_PAST_DEPENDS_BEFORE_SKIPPING
+    wait_for_downstream: bool = False
+    dag: DAG | None = None
+    params: MutableMapping | None = None
+    default_args: dict | None = None
+    priority_weight: int = DEFAULT_PRIORITY_WEIGHT
+    # TODO:
+    weight_rule: PriorityWeightStrategy | str = DEFAULT_WEIGHT_RULE
+    queue: str = DEFAULT_QUEUE
+    pool: str = "default"
+    pool_slots: int = DEFAULT_POOL_SLOTS
+    execution_timeout: timedelta | None = DEFAULT_TASK_EXECUTION_TIMEOUT
+    # on_execute_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    # on_failure_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    # on_success_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    # on_retry_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    # on_skipped_callback: None | TaskStateChangeCallback | 
list[TaskStateChangeCallback] = None
+    # pre_execute: TaskPreExecuteHook | None = None
+    # post_execute: TaskPostExecuteHook | None = None
+    trigger_rule: str = DEFAULT_TRIGGER_RULE
+    resources: dict[str, Any] | None = None
+    run_as_user: str | None = None
+    task_concurrency: int | None = None
+    map_index_template: str | None = None
+    max_active_tis_per_dag: int | None = None
+    max_active_tis_per_dagrun: int | None = None
+    executor: str | None = None
+    executor_config: dict | None = None
+    do_xcom_push: bool = True
+    multiple_outputs: bool = False
+    inlets: Any | None = None
+    outlets: Any | None = None
+    task_group: TaskGroup | None = None
+    doc: str | None = None
+    doc_md: str | None = None
+    doc_json: str | None = None
+    doc_yaml: str | None = None
+    doc_rst: str | None = None
+    _task_display_name: str
+    logger_name: str | None = None
+    allow_nested_operators: bool = True
+
     template_fields: ClassVar[Sequence[str]] = ()
     template_ext: ClassVar[Sequence[str]] = ()
 
@@ -443,12 +548,10 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     ui_color: str = "#fff"
     ui_fgcolor: str = "#000"
 
-    pool: str = ""
-
-    # TODO: Mapping
+    # TODO: Task-SDK Mapping
     # partial: Callable[..., OperatorPartial] = _PartialDescriptor()  # type: 
ignore
 
-    _comps = {
+    _comps: ClassVar[set[str]] = {
         "task_id",
         "dag_id",
         "owner",
@@ -476,29 +579,45 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     }
 
     # Defines if the operator supports lineage without manual definitions
-    supports_lineage = False
+    supports_lineage: bool = False
 
     # If True then the class constructor was called
-    __instantiated = False
+    __instantiated: bool = False
     # List of args as passed to `init()`, after apply_defaults() has been 
updated. Used to "recreate" the task
     # when mapping
     __init_kwargs: dict[str, Any]
 
     # Set to True before calling execute method
-    _lock_for_execution = False
+    _lock_for_execution: bool = False
 
     # Set to True for an operator instantiated by a mapped operator.
-    __from_mapped = False
+    __from_mapped: bool = False
 
     # TODO:
     # start_trigger_args: StartTriggerArgs | None = None
     # start_from_trigger: bool = False
 
+    def __setattr__(self: BaseOperator, key: str, value: Any):
+        if converter := getattr(self, f"_convert_{key}", None):
+            value = converter(value)
+        super().__setattr__(key, value)
+        if self.__from_mapped or self._lock_for_execution:
+            return  # Skip any custom behavior for validation and during 
execute.
+        if key in self.__init_kwargs:
+            self.__init_kwargs[key] = value
+        if self.__instantiated and key in self.template_fields:
+            # Resolve upstreams set by assigning an XComArg after initializing
+            # an operator, example:
+            #   op = BashOperator()
+            #   op.bash_command = "sleep 1"
+            self._set_xcomargs_dependency(key, value)
+
     def __init__(
         self,
+        *,
         task_id: str,
         owner: str = DEFAULT_OWNER,
-        email: str | Iterable[str] | None = None,
+        email: str | Sequence[str] | None = None,
         retries: int | None = DEFAULT_RETRIES,
         retry_delay: timedelta | float = DEFAULT_RETRY_DELAY,
         retry_exponential_backoff: bool = False,
@@ -513,9 +632,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         params: MutableMapping | None = None,
         default_args: dict | None = None,
         priority_weight: int = DEFAULT_PRIORITY_WEIGHT,
-        # TODO:
-        # weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
-        weight_rule: str = DEFAULT_WEIGHT_RULE,
+        weight_rule: str | PriorityWeightStrategy = DEFAULT_WEIGHT_RULE,
         queue: str = DEFAULT_QUEUE,
         pool: str | None = None,
         pool_slots: int = DEFAULT_POOL_SLOTS,
@@ -531,7 +648,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         trigger_rule: str = DEFAULT_TRIGGER_RULE,
         resources: dict[str, Any] | None = None,
         run_as_user: str | None = None,
-        task_concurrency: int | None = None,
         map_index_template: str | None = None,
         max_active_tis_per_dag: int | None = None,
         max_active_tis_per_dagrun: int | None = None,
@@ -554,25 +670,23 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
     ):
         from airflow.sdk.definitions.contextmanager import DagContext, 
TaskGroupContext
 
-        self.__init_kwargs = {}
+        self.task_id = task_group.child_id(task_id) if task_group else task_id
+        if not self.__from_mapped and task_group:
+            task_group.add(self)
+
+        dag = dag or DagContext.get_current()
+        task_group = task_group or TaskGroupContext.get_current(dag)
 
-        super().__init__()
+        super().__init__(dag=dag, task_group=task_group)
 
         kwargs.pop("_airflow_mapped_validation_only", None)
         if kwargs:
-            raise RuntimeError(
+            raise TypeError(
                 f"Invalid arguments were passed to {self.__class__.__name__} 
(task_id: {task_id}). "
                 + f"Invalid arguments were:\n**kwargs: {kwargs}",
             )
         validate_key(task_id)
 
-        dag = dag or DagContext.get_current()
-        task_group = task_group or TaskGroupContext.get_current(dag)
-
-        self.task_id = task_group.child_id(task_id) if task_group else task_id
-        if not self.__from_mapped and task_group:
-            task_group.add(self)
-
         self.owner = owner
         self.email = email
 
@@ -591,9 +705,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         # self._pre_execute_hook = pre_execute
         # self._post_execute_hook = post_execute
 
-        if start_date and not isinstance(start_date, datetime):
-            self.log.warning("start_date for %s isn't datetime.datetime", self)
-        elif start_date:
+        if start_date:
             self.start_date = timezone.convert_to_utc(start_date)
 
         if end_date:
@@ -610,7 +722,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         self.run_as_user = run_as_user
         # TODO:
         # self.retries = parse_retries(retries)
-        self.retries = int(retries)
+        self.retries = retries
         self.queue = queue
         self.pool = "default" if pool is None else pool
         self.pool_slots = pool_slots
@@ -620,15 +732,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         self.sla = sla
 
         """
-        if trigger_rule == "none_failed_or_skipped":
-            warnings.warn(
-                "none_failed_or_skipped Trigger Rule is deprecated. "
-                "Please use `none_failed_min_one_success`.",
-                RemovedInAirflow3Warning,
-                stacklevel=2,
-            )
-            trigger_rule = TriggerRule.NONE_FAILED_MIN_ONE_SUCCESS
-
         # if not TriggerRule.is_valid(trigger_rule):
         #     raise AirflowException(
         #         f"The trigger_rule must be one of 
{TriggerRule.all_triggers()},"
@@ -637,6 +740,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         #
         # self.trigger_rule: TriggerRule = TriggerRule(trigger_rule)
         # FailStopDagInvalidTriggerRule.check(dag=dag, 
trigger_rule=self.trigger_rule)
+        """
 
         self.depends_on_past: bool = depends_on_past
         self.ignore_first_depends_on_past: bool = ignore_first_depends_on_past
@@ -645,34 +749,21 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         if wait_for_downstream:
             self.depends_on_past = True
 
-        self.retry_delay = coerce_timedelta(retry_delay, key="retry_delay")
+        self.retry_delay = retry_delay
         self.retry_exponential_backoff = retry_exponential_backoff
-        self.max_retry_delay = (
-            max_retry_delay
-            if max_retry_delay is None
-            else coerce_timedelta(max_retry_delay, key="max_retry_delay")
-        )
+        if max_retry_delay is not None:
+            self.max_retry_delay = max_retry_delay
+
+        self.resources = resources
 
+        """
         # At execution_time this becomes a normal dict
         self.params: ParamsDict | dict = ParamsDict(params)
-        if priority_weight is not None and not isinstance(priority_weight, 
int):
-            raise AirflowException(
-                f"`priority_weight` for task '{self.task_id}' only accepts 
integers, "
-                f"received '{type(priority_weight)}'."
-            )
+        """
+
         self.priority_weight = priority_weight
         self.weight_rule = 
validate_and_load_priority_weight_strategy(weight_rule)
-        self.resources = coerce_resources(resources)
-        if task_concurrency and not max_active_tis_per_dag:
-            # TODO: Remove in Airflow 3.0
-            warnings.warn(
-                "The 'task_concurrency' parameter is deprecated. Please use 
'max_active_tis_per_dag'.",
-                RemovedInAirflow3Warning,
-                stacklevel=2,
-            )
-            max_active_tis_per_dag = task_concurrency
 
-        """
         self.max_active_tis_per_dag: int | None = max_active_tis_per_dag
         self.max_active_tis_per_dagrun: int | None = max_active_tis_per_dagrun
         self.do_xcom_push: bool = do_xcom_push
@@ -684,22 +775,18 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         self.doc_yaml = doc_yaml
         self.doc_rst = doc_rst
         self.doc = doc
-        # Populate the display field only if provided and different from task 
id
-        self._task_display_property_value = (
-            task_display_name if task_display_name and task_display_name != 
task_id else None
-        )
 
-        if dag:
-            self.dag = dag
+        self._task_display_name = task_display_name
+
+        self.allow_nested_operators = allow_nested_operators
+        self.inlets: list = []
+        self.outlets: list = []
 
         """
         self._log_config_logger_name = "airflow.task.operators"
         self._logger_name = logger_name
-        self.allow_nested_operators: bool = allow_nested_operators
 
         # Lineage
-        self.inlets: list = []
-        self.outlets: list = []
 
         if inlets:
             self.inlets = (
@@ -718,6 +805,7 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
                     outlets,
                 ]
             )
+        """
 
         if isinstance(self.template_fields, str):
             warnings.warn(
@@ -731,11 +819,11 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
 
         self._is_setup = False
         self._is_teardown = False
-        if SetupTeardownContext.active:
-            SetupTeardownContext.update_context_map(self)
+        # TODO: Task-SDK
+        # if SetupTeardownContext.active:
+        #     SetupTeardownContext.update_context_map(self)
 
         validate_instance_args(self, BASEOPERATOR_ARGS_EXPECTED_TYPES)
-        """
 
     def __eq__(self, other):
         if type(self) is type(other):
@@ -810,19 +898,6 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
 
         return self
 
-    def __setattr__(self, key, value):
-        super().__setattr__(key, value)
-        if self.__from_mapped or self._lock_for_execution:
-            return  # Skip any custom behavior for validation and during 
execute.
-        if key in self.__init_kwargs:
-            self.__init_kwargs[key] = value
-        if self.__instantiated and key in self.template_fields:
-            # Resolve upstreams set by assigning an XComArg after initializing
-            # an operator, example:
-            #   op = BashOperator()
-            #   op.bash_command = "sleep 1"
-            self.set_xcomargs_dependencies()
-
     def add_inlets(self, inlets: Iterable[Any]):
         """Set inlets to this operator."""
         self.inlets.extend(inlets)
@@ -832,45 +907,77 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         self.outlets.extend(outlets)
 
     def get_dag(self) -> DAG | None:
-        return self._dag
+        return self.dag
 
-    @property  # type: ignore[override]
-    def dag(self) -> DAG:  # type: ignore[override]
-        """Returns the Operator's DAG if set, otherwise raises an error."""
-        if self._dag:
-            return self._dag
-        else:
-            raise RuntimeError(f"Operator {self} has not been assigned to a 
DAG yet")
-
-    @dag.setter
-    def dag(self, dag: DAG | None):
+    def _convert_dag(self, dag: DAG | None | AttributeRemoved) -> DAG | None | 
AttributeRemoved:
         """Operators can be assigned to one DAG, one time. Repeat assignments 
to that same DAG are ok."""
-        from .dag import DAG
+        from airflow.sdk.definitions.dag import DAG
 
         if dag is None:
-            self._dag = None
-            return
+            return dag
+
+        # if set to removed, then just set and exit
+        if self.dag.__class__ is AttributeRemoved:
+            return dag
+        # if setting to removed, then just set and exit
+        if dag.__class__ is AttributeRemoved:
+            return AttributeRemoved("_dag")  # type: ignore[assignment]
+
         if not isinstance(dag, DAG):
             raise TypeError(f"Expected DAG; received {dag.__class__.__name__}")
-        elif self.has_dag() and self.dag is not dag:
+        elif self.dag is not None and self.dag is not dag:
             raise ValueError(f"The DAG assigned to {self} can not be changed.")
 
         if self.__from_mapped:
             pass  # Don't add to DAG -- the mapped task takes the place.
         elif dag.task_dict.get(self.task_id) is not self:
             dag.add_task(self)
-
-        self._dag = dag
+        return dag
+
+    @staticmethod
+    def _convert_retries(retries: Any) -> int | None:
+        if retries is None:
+            return 0
+        elif type(retries) == int:  # noqa: E721
+            return retries
+        try:
+            parsed_retries = int(retries)
+        except (TypeError, ValueError):
+            raise TypeError(f"'retries' type must be int, not 
{type(retries).__name__}")
+        return parsed_retries
+
+    @staticmethod
+    def _convert_timedelta(value: float | timedelta) -> timedelta:
+        if isinstance(value, timedelta):
+            return value
+        return timedelta(seconds=value)
+
+    _convert_retry_delay = _convert_timedelta
+    _convert_max_retry_delay = _convert_timedelta
+
+    @staticmethod
+    def _convert_resources(resources: dict[str, Any] | None) -> Resources | 
None:
+        if resources is None:
+            return None
+        return Resources(**resources)
 
     @property
     def task_display_name(self) -> str:
-        return self._task_display_property_value or self.task_id
+        return self._task_display_name or self.task_id
 
     def has_dag(self):
         """Return True if the Operator has been assigned to a DAG."""
-        return self._dag is not None
+        return self.dag is not None
+
+    def _set_xcomargs_dependencies(self) -> None:
+        from airflow.models.xcom_arg import XComArg
+
+        for field in self.template_fields:
+            arg = getattr(self, field, NOTSET)
+            if arg is not NOTSET:
+                XComArg.apply_upstream_relationship(self, arg)
 
-    def set_xcomargs_dependencies(self) -> None:
+    def _set_xcomargs_dependency(self, field: str, newvalue: Any) -> None:
         """
         Resolve upstream dependencies of a task.
 
@@ -892,10 +999,9 @@ class BaseOperator(AbstractOperator, 
metaclass=BaseOperatorMeta):
         """
         from airflow.models.xcom_arg import XComArg
 
-        for field in self.template_fields:
-            if hasattr(self, field):
-                arg = getattr(self, field)
-                XComArg.apply_upstream_relationship(self, arg)
+        if field not in self.template_fields:
+            return
+        XComArg.apply_upstream_relationship(self, newvalue)
 
     def on_kill(self) -> None:
         """
diff --git a/task_sdk/src/airflow/sdk/definitions/dag.py 
b/task_sdk/src/airflow/sdk/definitions/dag.py
index f80d8f7f71..5cf76da488 100644
--- a/task_sdk/src/airflow/sdk/definitions/dag.py
+++ b/task_sdk/src/airflow/sdk/definitions/dag.py
@@ -27,7 +27,7 @@ import sys
 import weakref
 from collections import abc
 from collections.abc import Collection, Iterable, Iterator
-from datetime import datetime, timedelta
+from datetime import datetime, timedelta, timezone
 from inspect import signature
 from re import Pattern
 from typing import (
@@ -51,7 +51,6 @@ from airflow import settings
 from airflow.assets import Asset, AssetAlias, BaseAsset
 from airflow.configuration import conf as airflow_conf
 from airflow.exceptions import (
-    AirflowException,
     DuplicateTaskIdFound,
     FailStopDagInvalidTriggerRule,
     ParamValidationError,
@@ -62,6 +61,13 @@ from airflow.sdk.definitions.abstractoperator import 
AbstractOperator
 from airflow.sdk.definitions.baseoperator import BaseOperator
 from airflow.stats import Stats
 from airflow.timetables.base import Timetable
+from airflow.timetables.interval import CronDataIntervalTimetable, 
DeltaDataIntervalTimetable
+from airflow.timetables.simple import (
+    AssetTriggeredTimetable,
+    ContinuousTimetable,
+    NullTimetable,
+    OnceTimetable,
+)
 from airflow.utils.dag_cycle_tester import check_cycle
 from airflow.utils.decorators import fixup_decorator_warning_stack
 from airflow.utils.trigger_rule import TriggerRule
@@ -70,15 +76,17 @@ from airflow.utils.types import NOTSET, EdgeInfoType
 if TYPE_CHECKING:
     from airflow.decorators import TaskDecoratorCollection
     from airflow.models.operator import Operator
-    from airflow.utils.taskgroup import TaskGroup
+    from airflow.sdk.definitions.taskgroup import TaskGroup
 
 log = logging.getLogger(__name__)
 
-DEFAULT_VIEW_PRESETS = ["grid", "graph", "duration", "gantt", "landing_times"]
-ORIENTATION_PRESETS = ["LR", "TB", "RL", "BT"]
-
 TAG_MAX_LEN = 100
 
+__all__ = [
+    "DAG",
+    "dag",
+]
+
 
 # TODO: Task-SDK
 class Context: ...
@@ -135,7 +143,25 @@ DAG_ARGS_EXPECTED_TYPES = {
 }
 
 
-@attrs.define
+def _create_timetable(interval: ScheduleInterval, timezone: Timezone | 
FixedTimezone) -> Timetable:
+    """Create a Timetable instance from a plain ``schedule`` value."""
+    if interval is None:
+        return NullTimetable()
+    if interval == "@once":
+        return OnceTimetable()
+    if interval == "@continuous":
+        return ContinuousTimetable()
+    if isinstance(interval, (timedelta, relativedelta)):
+        return DeltaDataIntervalTimetable(interval)
+    if isinstance(interval, str):
+        if airflow_conf.getboolean("scheduler", "create_cron_data_intervals"):
+            return CronDataIntervalTimetable(interval, timezone)
+        else:
+            return CronTriggerTimetable(interval, timezone=timezone)
+    raise ValueError(f"{interval!r} is not a valid schedule.")
+
+
+@attrs.define(kw_only=True)
 class DAG:
     """
     A dag (directed acyclic graph) is a collection of tasks with directional 
dependencies.
@@ -266,11 +292,11 @@ class DAG:
     # below in sync. (Search for 'def dag(' in this file.)
     dag_id: str
     description: str | None = None
-    schedule: ScheduleArg = NOTSET
-    schedule_interval: ScheduleIntervalArg = NOTSET
-    timetable: Timetable | None = None
     start_date: datetime | None = None
     end_date: datetime | None = None
+    timezone: timezone = timezone.utc
+    schedule: ScheduleArg = attrs.field(default=None, 
on_setattr=attrs.setters.NO_OP)
+    timetable: Timetable = attrs.field(init=False)
     full_filepath: str | None = None
     template_searchpath: str | Iterable[str] | None = None
     # template_undefined: type[jinja2.StrictUndefined] = jinja2.StrictUndefined
@@ -285,7 +311,7 @@ class DAG:
     # on_success_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
     # on_failure_callback: None | DagStateChangeCallback | 
list[DagStateChangeCallback] = None
     doc_md: str | None = None
-    params: abc.MutableMapping | None = None
+    params: abc.MutableMapping | None = attrs.field(default=None)
     access_control: dict | None = None
     is_paused_upon_creation: bool | None = None
     jinja_environment_kwargs: dict | None = None
@@ -310,6 +336,55 @@ class DAG:
 
         return TaskGroup.create_root(dag=self)
 
+    @timetable.default
+    def _set_schedule(self):
+        schedule = self.schedule
+        delattr(self, "schedule")
+        if isinstance(schedule, Timetable):
+            return schedule
+        elif isinstance(schedule, BaseAsset):
+            return AssetTriggeredTimetable(schedule)
+        elif isinstance(schedule, Collection) and not isinstance(schedule, 
str):
+            if not all(isinstance(x, (Asset, AssetAlias)) for x in schedule):
+                raise ValueError("All elements in 'schedule' should be assets 
or asset aliases")
+            return AssetTriggeredTimetable(AssetAll(*schedule))
+        else:
+            return _create_timetable(schedule, self.timezone)
+
+    @params.validator
+    def _validate_params(self, attr, val: abc.MutableMapping | None):
+        """
+        Validate Param values when the DAG has a schedule defined.
+
+        Raise ``ValueError`` if there are any Params which cannot be resolved by 
their schema definition.
+
+        This will also merge in params from ``default_args``.
+        """
+        # TODO: Task-SDK
+        from airflow.models.param import ParamsDict
+
+        val = val or {}
+
+        # merging potentially conflicting default_args['params'] into params
+        if "params" in self.default_args:
+            val.update(self.default_args["params"])
+            del self.default_args["params"]
+
+        params = ParamsDict(val)
+        object.__setattr__(self, "params", params)
+        if not self.timetable or not self.timetable.can_be_scheduled:
+            return
+
+        try:
+            params.validate()
+        except ParamValidationError as pverr:
+            raise ValueError(
+                f"DAG {self.dag_id!r} is not allowed to define a Schedule, "
+                "as there are required params without default values, or the 
default values are not valid."
+            ) from pverr
+
+        # check self.params and convert them into ParamsDict
+
     def __repr__(self):
         return f"<DAG: {self.dag_id}>"
 
@@ -832,23 +907,6 @@ class DAG:
         """
         self.edge_info.setdefault(upstream_task_id, {})[downstream_task_id] = 
info
 
-    def validate_schedule_and_params(self):
-        """
-        Validate Param values when the DAG has schedule defined.
-
-        Raise exception if there are any Params which can not be resolved by 
their schema definition.
-        """
-        if not self.timetable.can_be_scheduled:
-            return
-
-        try:
-            self.params.validate()
-        except ParamValidationError as pverr:
-            raise AirflowException(
-                "DAG is not allowed to define a Schedule, "
-                "if there are any required params without default values or 
default values are not valid."
-            ) from pverr
-
     def iter_invalid_owner_links(self) -> Iterator[tuple[str, str]]:
         """
         Parse a given link, and verifies if it's a valid URL, or a 'mailto' 
link.
diff --git a/task_sdk/src/airflow/sdk/definitions/node.py 
b/task_sdk/src/airflow/sdk/definitions/node.py
index 1d98028467..40236cac8b 100644
--- a/task_sdk/src/airflow/sdk/definitions/node.py
+++ b/task_sdk/src/airflow/sdk/definitions/node.py
@@ -108,8 +108,8 @@ class DAGNode(DependencyMixin, metaclass=ABCMeta):
         edge_modifier: EdgeModifier | None = None,
     ) -> None:
         """Set relatives for the task or task list."""
-        from airflow.models.baseoperator import BaseOperator
         from airflow.models.mappedoperator import MappedOperator
+        from airflow.sdk.definitions.baseoperator import BaseOperator
 
         if not isinstance(task_or_task_list, Sequence):
             task_or_task_list = [task_or_task_list]
diff --git a/task_sdk/src/airflow/sdk/types.py 
b/task_sdk/src/airflow/sdk/types.py
index 505ee4cb19..a412509ba5 100644
--- a/task_sdk/src/airflow/sdk/types.py
+++ b/task_sdk/src/airflow/sdk/types.py
@@ -17,7 +17,7 @@
 
 from __future__ import annotations
 
-from typing import TYPE_CHECKING
+from typing import TYPE_CHECKING, Any
 
 
 class ArgNotSet:
@@ -43,7 +43,24 @@ NOTSET = ArgNotSet()
 if TYPE_CHECKING:
     import logging
 
+    from airflow.sdk.definitions.node import DAGNode
+
     Logger = logging.Logger
 else:
 
     class Logger: ...
+
+
+def validate_instance_args(instance: DAGNode, expected_arg_types: dict[str, 
Any]) -> None:
+    """Validate that the instance has the expected types for the arguments."""
+    from airflow.sdk.definitions.taskgroup import TaskGroup
+
+    typ = "task group" if isinstance(instance, TaskGroup) else "task"
+
+    for arg_name, expected_arg_type in expected_arg_types.items():
+        instance_arg_value = getattr(instance, arg_name, None)
+        if instance_arg_value is not None and not 
isinstance(instance_arg_value, expected_arg_type):
+            raise TypeError(
+                f"{arg_name!r} for {typ} {instance.node_id!r} expects 
{expected_arg_type}, got {type(instance_arg_value)} with value "
+                f"{instance_arg_value!r}"
+            )
diff --git a/task_sdk/tests/defintions/test_baseoperator.py 
b/task_sdk/tests/defintions/test_baseoperator.py
new file mode 100644
index 0000000000..0222de3ee3
--- /dev/null
+++ b/task_sdk/tests/defintions/test_baseoperator.py
@@ -0,0 +1,305 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+
+from __future__ import annotations
+
+import warnings
+from datetime import UTC, datetime, timedelta
+
+import pytest
+
+from airflow.sdk.definitions.baseoperator import BaseOperator, BaseOperatorMeta
+from airflow.sdk.definitions.dag import DAG
+from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, 
_UpstreamPriorityWeightStrategy
+
+DEFAULT_DATE = datetime(2016, 1, 1, tzinfo=UTC)
+
+
+# Essentially similar to airflow.models.baseoperator.BaseOperator
+class FakeOperator(metaclass=BaseOperatorMeta):
+    def __init__(self, test_param, params=None, default_args=None):
+        self.test_param = test_param
+
+    def _set_xcomargs_dependencies(self): ...
+
+
+class FakeSubClass(FakeOperator):
+    def __init__(self, test_sub_param, test_param, **kwargs):
+        super().__init__(test_param=test_param, **kwargs)
+        self.test_sub_param = test_sub_param
+
+
+class DeprecatedOperator(BaseOperator):
+    def __init__(self, **kwargs):
+        warnings.warn("This operator is deprecated.", DeprecationWarning, 
stacklevel=2)
+        super().__init__(**kwargs)
+
+    def execute(self, context: Context):
+        pass
+
+
+class MockOperator(BaseOperator):
+    """Operator for testing purposes."""
+
+    template_fields: Sequence[str] = ("arg1", "arg2")
+
+    def __init__(self, arg1: str = "", arg2: str = "", **kwargs):
+        super().__init__(**kwargs)
+        self.arg1 = arg1
+        self.arg2 = arg2
+
+    def execute(self, context: Context):
+        pass
+
+
+class TestBaseOperator:
+    # Since we have a custom metaclass, let's double-check the behaviour of 
passing args in the wrong way (args
+    # etc.)
+    def test_kwargs_only(self):
+        with pytest.raises(TypeError, match="keyword arguments"):
+            BaseOperator("task_id")
+
+    def test_missing_kwarg(self):
+        with pytest.raises(TypeError, match="missing keyword argument"):
+            FakeOperator(task_id="task_id")
+
+    def test_missing_kwargs(self):
+        with pytest.raises(TypeError, match="missing keyword arguments"):
+            FakeSubClass(task_id="task_id")
+
+    def test_hash(self):
+        """Two operators created equally should hash equally"""
+        # Include a "non-hashable" type too
+        assert hash(MockOperator(task_id="one", retries=1024 * 1024, 
arg1="abcef", params={"a": 1})) == hash(
+            MockOperator(task_id="one", retries=1024 * 1024, arg1="abcef", 
params={"a": 2})
+        )
+
+    def test_expand(self):
+        op = FakeOperator(test_param=True)
+        assert op.test_param
+
+        with pytest.raises(TypeError, match="missing keyword argument 
'test_param'"):
+            FakeSubClass(test_sub_param=True)
+
+    def test_default_args(self):
+        default_args = {"test_param": True}
+        op = FakeOperator(default_args=default_args)
+        assert op.test_param
+
+        default_args = {"test_param": True, "test_sub_param": True}
+        op = FakeSubClass(default_args=default_args)
+        assert op.test_param
+        assert op.test_sub_param
+
+        default_args = {"test_param": True}
+        op = FakeSubClass(default_args=default_args, test_sub_param=True)
+        assert op.test_param
+        assert op.test_sub_param
+
+        with pytest.raises(TypeError, match="missing keyword argument 
'test_sub_param'"):
+            FakeSubClass(default_args=default_args)
+
+    def test_execution_timeout_type(self):
+        with pytest.raises(
+            ValueError, match="execution_timeout must be timedelta object but 
passed as type: <class 'str'>"
+        ):
+            BaseOperator(task_id="test", execution_timeout="1")
+
+        with pytest.raises(
+            ValueError, match="execution_timeout must be timedelta object but 
passed as type: <class 'int'>"
+        ):
+            BaseOperator(task_id="test", execution_timeout=1)
+
+    def test_incorrect_default_args(self):
+        default_args = {"test_param": True, "extra_param": True}
+        op = FakeOperator(default_args=default_args)
+        assert op.test_param
+
+        default_args = {"random_params": True}
+        with pytest.raises(TypeError, match="missing keyword argument 
'test_param'"):
+            FakeOperator(default_args=default_args)
+
+    def test_incorrect_priority_weight(self):
+        error_msg = "'priority_weight' for task 'test_op' expects <class 
'int'>, got <class 'str'>"
+        with pytest.raises(TypeError, match=error_msg):
+            BaseOperator(task_id="test_op", priority_weight="2")
+
+    def test_illegal_args_forbidden(self):
+        """
+        Tests that operators raise exceptions on illegal arguments when
+        illegal arguments are not allowed.
+        """
+        msg = r"Invalid arguments were passed to BaseOperator \(task_id: 
test_illegal_args\)"
+        with pytest.raises(TypeError, match=msg):
+            BaseOperator(
+                task_id="test_illegal_args",
+                illegal_argument_1234="hello?",
+            )
+
+    def test_invalid_type_for_default_arg(self):
+        error_msg = "'max_active_tis_per_dag' for task 'test' expects <class 
'int'>, got <class 'str'> with value 'not_an_int'"
+        with pytest.raises(TypeError, match=error_msg):
+            BaseOperator(task_id="test", 
default_args={"max_active_tis_per_dag": "not_an_int"})
+
+    def test_invalid_type_for_operator_arg(self):
+        error_msg = "'max_active_tis_per_dag' for task 'test' expects <class 
'int'>, got <class 'str'> with value 'not_an_int'"
+        with pytest.raises(TypeError, match=error_msg):
+            BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int")
+
+    def test_weight_rule_default(self):
+        op = BaseOperator(task_id="test_task")
+        assert _DownstreamPriorityWeightStrategy() == op.weight_rule
+
+    def test_weight_rule_override(self):
+        op = BaseOperator(task_id="test_task", weight_rule="upstream")
+        assert _UpstreamPriorityWeightStrategy() == op.weight_rule
+
+    def test_warnings_are_properly_propagated(self):
+        with pytest.warns(DeprecationWarning) as warnings:
+            DeprecatedOperator(task_id="test")
+            assert len(warnings) == 1
+            warning = warnings[0]
+            # Here we check that the trace points to the place
+            # where the deprecated class was used
+            assert warning.filename == __file__
+
+    def test_setattr_performs_no_custom_action_at_execute_time(self, 
spy_agency):
+        from airflow.models.xcom_arg import XComArg
+
+        op = MockOperator(task_id="test_task")
+        # TODO: Task-SDK
+        # op_copy = op.prepare_for_execution()
+        op_copy = op
+
+        spy_agency.spy_on(XComArg.apply_upstream_relationship, 
call_original=False)
+        op_copy.execute({})
+        assert XComArg.apply_upstream_relationship.called == False
+
+    def test_upstream_is_set_when_template_field_is_xcomarg(self):
+        with DAG("xcomargs_test", schedule=None):
+            op1 = BaseOperator(task_id="op1")
+            op2 = MockOperator(task_id="op2", arg1=op1.output)
+
+        assert op1.task_id in op2.upstream_task_ids
+        assert op2.task_id in op1.downstream_task_ids
+
+    def test_set_xcomargs_dependencies_works_recursively(self):
+        with DAG("xcomargs_test", schedule=None):
+            op1 = BaseOperator(task_id="op1")
+            op2 = BaseOperator(task_id="op2")
+            op3 = MockOperator(task_id="op3", arg1=[op1.output, op2.output])
+            op4 = MockOperator(task_id="op4", arg1={"op1": op1.output, "op2": 
op2.output})
+
+        assert op1.task_id in op3.upstream_task_ids
+        assert op2.task_id in op3.upstream_task_ids
+        assert op1.task_id in op4.upstream_task_ids
+        assert op2.task_id in op4.upstream_task_ids
+
+    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
+        with DAG(dag_id="xcomargs_test", schedule=None):
+            op1 = BaseOperator(task_id="op1")
+            op2 = MockOperator(task_id="op2")
+            op2.arg1 = op1.output  # value is set after init
+
+        assert op1.task_id in op2.upstream_task_ids
+
+    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
+        op1 = BaseOperator(task_id="op1")
+        with pytest.raises(ValueError):
+            MockOperator(task_id="op2", arg1=op1.output)
+
+    def test_cannot_change_dag(self):
+        with DAG(dag_id="dag1", schedule=None):
+            op1 = BaseOperator(task_id="op1")
+        with pytest.raises(ValueError, match="can not be changed"):
+            op1.dag = DAG(dag_id="dag2")
+
+
+def test_init_subclass_args():
+    class InitSubclassOp(BaseOperator):
+        _class_arg: Any
+
+        def __init_subclass__(cls, class_arg=None, **kwargs) -> None:
+            cls._class_arg = class_arg
+            super().__init_subclass__()
+
+        def execute(self, context: Context):
+            self.context_arg = context
+
+    class_arg = "foo"
+    context = {"key": "value"}
+
+    class ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg):
+        pass
+
+    task = ConcreteSubclassOp(task_id="op1")
+    # TODO: Task-SDK
+    # task_copy = task.prepare_for_execution()
+    task_copy = task
+
+    task_copy.execute(context)
+
+    assert task_copy._class_arg == class_arg
+    assert task_copy.context_arg == context
+
+
+class CustomInt(int):
+    def __int__(self):
+        raise ValueError("Cannot cast to int")
+
+
+@pytest.mark.parametrize(
+    ("retries", "expected"),
+    [
+        pytest.param("foo", "'retries' type must be int, not str", 
id="string"),
+        pytest.param(CustomInt(10), "'retries' type must be int, not 
CustomInt", id="custom int"),
+    ],
+)
+def test_operator_retries_invalid(dag_maker, retries, expected):
+    with pytest.raises(TypeError) as ctx:
+        BaseOperator(task_id="test_illegal_args", retries=retries)
+    assert str(ctx.value) == expected
+
+
+@pytest.mark.parametrize(
+    ("retries", "expected"),
+    [
+        pytest.param(None, 0, id="None"),
+        pytest.param("5", 5, id="str"),
+        pytest.param(1, 1, id="int"),
+    ],
+)
+def test_operator_retries_conversion(retries, expected):
+    op = BaseOperator(
+        task_id="test_illegal_args",
+        retries=retries,
+    )
+    assert op.retries == expected
+
+
+def test_dag_level_retry_delay(dag_maker):
+    with DAG(dag_id="test_dag_level_retry_delay", default_args={"retry_delay": 
timedelta(seconds=100)}):
+        task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
+
+        assert task1.retry_delay == timedelta(seconds=100)
+
+
+def test_task_level_retry_delay():
+    with DAG(dag_id="test_task_level_retry_delay", 
default_args={"retry_delay": timedelta(seconds=100)}):
+        task1 = BaseOperator(task_id="test_no_explicit_retry_delay", 
retry_delay=200)
+
+        assert task1.retry_delay == timedelta(seconds=200)
diff --git a/task_sdk/tests/defintions/test_dag.py 
b/task_sdk/tests/defintions/test_dag.py
index e07e3a8bfa..29249745ff 100644
--- a/task_sdk/tests/defintions/test_dag.py
+++ b/task_sdk/tests/defintions/test_dag.py
@@ -14,11 +14,14 @@
 # KIND, either express or implied.  See the License for the
 # specific language governing permissions and limitations
 # under the License.
-
 from __future__ import annotations
 
-from datetime import UTC, datetime
+from datetime import UTC, datetime, timedelta
+
+import pytest
 
+from airflow.exceptions import DuplicateTaskIdFound
+from airflow.models.param import Param, ParamsDict
 from airflow.sdk.definitions.baseoperator import BaseOperator
 from airflow.sdk.definitions.dag import DAG
 
@@ -76,3 +79,92 @@ class TestDag:
         assert op7.dag == dag
         assert op8.dag == dag
         assert op9.dag == dag2
+
+    def test_params_not_passed_is_empty_dict(self):
+        """
+        Test that when 'params' is _not_ passed to a new Dag, that the params
+        attribute is set to an empty dictionary.
+        """
+        dag = DAG("test-dag", schedule=None)
+
+        assert isinstance(dag.params, ParamsDict)
+        assert 0 == len(dag.params)
+
+    def test_params_passed_and_params_in_default_args_no_override(self):
+        """
+        Test that when 'params' exists as a key passed to the default_args dict
+        in addition to params being passed explicitly as an argument to the
+        dag, that the 'params' key of the default_args dict is merged with the
+        dict of the params argument.
+        """
+        params1 = {"parameter1": 1}
+        params2 = {"parameter2": 2}
+
+        dag = DAG("test-dag", schedule=None, default_args={"params": params1}, 
params=params2)
+
+        assert params1["parameter1"] == dag.params["parameter1"]
+        assert params2["parameter2"] == dag.params["parameter2"]
+
+    def test_not_none_schedule_with_non_default_params(self):
+        """
+        Test that a DAG with a schedule and some params that don't 
have a default value raises an
+        error while DAG parsing. (Because we can't schedule them if we 
don't know what value to use)
+        """
+        params = {"param1": Param(type="string")}
+
+        with pytest.raises(ValueError):
+            DAG("my-dag", schedule=timedelta(days=1), start_date=DEFAULT_DATE, 
params=params)
+
+    def test_roots(self):
+        """Verify if dag.roots returns the root tasks of a DAG."""
+        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+            op1 = BaseOperator(task_id="t1")
+            op2 = BaseOperator(task_id="t2")
+            op3 = BaseOperator(task_id="t3")
+            op4 = BaseOperator(task_id="t4")
+            op5 = BaseOperator(task_id="t5")
+            [op1, op2] >> op3 >> [op4, op5]
+
+            assert set(dag.roots) == {op1, op2}
+
+    def test_leaves(self):
+        """Verify if dag.leaves returns the leaf tasks of a DAG."""
+        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+            op1 = BaseOperator(task_id="t1")
+            op2 = BaseOperator(task_id="t2")
+            op3 = BaseOperator(task_id="t3")
+            op4 = BaseOperator(task_id="t4")
+            op5 = BaseOperator(task_id="t5")
+            [op1, op2] >> op3 >> [op4, op5]
+
+            assert set(dag.leaves) == {op4, op5}
+
+    def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
+        """Verify that tasks with a duplicate task_id raise an error"""
+        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+            op1 = BaseOperator(task_id="t1")
+            with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has 
already been added to the DAG"):
+                BaseOperator(task_id="t1")
+
+        assert dag.task_dict == {op1.task_id: op1}
+
+    def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
+        """Verify that tasks with a duplicate task_id raise an error"""
+        dag = DAG("test_dag", schedule=None, start_date=DEFAULT_DATE)
+        op1 = BaseOperator(task_id="t1", dag=dag)
+        with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has 
already been added to the DAG"):
+            BaseOperator(task_id="t1", dag=dag)
+
+        assert dag.task_dict == {op1.task_id: op1}
+
+    def test_duplicate_task_ids_for_same_task_is_allowed(self):
+        """Verify that the same task used twice under one task_id does not raise an error"""
+        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
+            op1 = op2 = BaseOperator(task_id="t1")
+            op3 = BaseOperator(task_id="t3")
+            op1 >> op3
+            op2 >> op3
+
+        assert op1 == op2
+        assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
+        assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}
diff --git a/tests/models/test_baseoperator.py 
b/tests/models/test_baseoperator.py
index 999529e14a..eea5dae269 100644
--- a/tests/models/test_baseoperator.py
+++ b/tests/models/test_baseoperator.py
@@ -22,7 +22,7 @@ import logging
 import uuid
 from collections import defaultdict
 from datetime import date, datetime, timedelta
-from typing import TYPE_CHECKING, Any, NamedTuple
+from typing import NamedTuple
 from unittest import mock
 
 import jinja2
@@ -32,9 +32,7 @@ from airflow.decorators import task as task_decorator
 from airflow.exceptions import AirflowException, FailStopDagInvalidTriggerRule
 from airflow.lineage.entities import File
 from airflow.models.baseoperator import (
-    BASEOPERATOR_ARGS_EXPECTED_TYPES,
     BaseOperator,
-    BaseOperatorMeta,
     chain,
     chain_linear,
     cross_downstream,
@@ -43,7 +41,6 @@ from airflow.models.dag import DAG
 from airflow.models.dagrun import DagRun
 from airflow.models.taskinstance import TaskInstance
 from airflow.providers.common.sql.operators import sql
-from airflow.task.priority_strategy import _DownstreamPriorityWeightStrategy, 
_UpstreamPriorityWeightStrategy
 from airflow.utils.edgemodifier import Label
 from airflow.utils.task_group import TaskGroup
 from airflow.utils.template import literal
@@ -51,10 +48,7 @@ from airflow.utils.trigger_rule import TriggerRule
 from airflow.utils.types import DagRunType
 from tests.models import DEFAULT_DATE
 
-from dev.tests_common.test_utils.mock_operators import DeprecatedOperator, 
MockOperator
-
-if TYPE_CHECKING:
-    from airflow.utils.context import Context
+from dev.tests_common.test_utils.mock_operators import MockOperator
 
 
 class ClassWithCustomAttributes:
@@ -83,93 +77,12 @@ object2 = ClassWithCustomAttributes(attr="{{ foo }}_2", 
ref=object1, template_fi
 setattr(object1, "ref", object2)
 
 
-# Essentially similar to airflow.models.baseoperator.BaseOperator
-class DummyClass(metaclass=BaseOperatorMeta):
-    def __init__(self, test_param, params=None, default_args=None):
-        self.test_param = test_param
-
-    def set_xcomargs_dependencies(self): ...
-
-
-class DummySubClass(DummyClass):
-    def __init__(self, test_sub_param, **kwargs):
-        super().__init__(**kwargs)
-        self.test_sub_param = test_sub_param
-
-
 class MockNamedTuple(NamedTuple):
     var1: str
     var2: str
 
 
-class CustomInt(int):
-    def __int__(self):
-        raise ValueError("Cannot cast to int")
-
-
 class TestBaseOperator:
-    def test_expand(self):
-        dummy = DummyClass(test_param=True)
-        assert dummy.test_param
-
-        with pytest.raises(AirflowException, match="missing keyword argument 
'test_param'"):
-            DummySubClass(test_sub_param=True)
-
-    def test_default_args(self):
-        default_args = {"test_param": True}
-        dummy_class = DummyClass(default_args=default_args)
-        assert dummy_class.test_param
-
-        default_args = {"test_param": True, "test_sub_param": True}
-        dummy_subclass = DummySubClass(default_args=default_args)
-        assert dummy_class.test_param
-        assert dummy_subclass.test_sub_param
-
-        default_args = {"test_param": True}
-        dummy_subclass = DummySubClass(default_args=default_args, 
test_sub_param=True)
-        assert dummy_class.test_param
-        assert dummy_subclass.test_sub_param
-
-        with pytest.raises(AirflowException, match="missing keyword argument 
'test_sub_param'"):
-            DummySubClass(default_args=default_args)
-
-    def test_execution_timeout_type(self):
-        with pytest.raises(
-            ValueError, match="execution_timeout must be timedelta object but 
passed as type: <class 'str'>"
-        ):
-            BaseOperator(task_id="test", execution_timeout="1")
-
-        with pytest.raises(
-            ValueError, match="execution_timeout must be timedelta object but 
passed as type: <class 'int'>"
-        ):
-            BaseOperator(task_id="test", execution_timeout=1)
-
-    def test_incorrect_default_args(self):
-        default_args = {"test_param": True, "extra_param": True}
-        dummy_class = DummyClass(default_args=default_args)
-        assert dummy_class.test_param
-
-        default_args = {"random_params": True}
-        with pytest.raises(AirflowException, match="missing keyword argument 
'test_param'"):
-            DummyClass(default_args=default_args)
-
-    def test_incorrect_priority_weight(self):
-        error_msg = "`priority_weight` for task 'test_op' only accepts 
integers, received '<class 'str'>'."
-        with pytest.raises(AirflowException, match=error_msg):
-            BaseOperator(task_id="test_op", priority_weight="2")
-
-    def test_illegal_args_forbidden(self):
-        """
-        Tests that operators raise exceptions on illegal arguments when
-        illegal arguments are not allowed.
-        """
-        msg = r"Invalid arguments were passed to BaseOperator \(task_id: 
test_illegal_args\)"
-        with pytest.raises(AirflowException, match=msg):
-            BaseOperator(
-                task_id="test_illegal_args",
-                illegal_argument_1234="hello?",
-            )
-
     def test_trigger_rule_validation(self):
         from airflow.models.abstractoperator import DEFAULT_TRIGGER_RULE
 
@@ -659,15 +572,6 @@ class TestBaseOperator:
         task4 > [inlet, outlet, extra]
         assert task4.get_outlet_defs() == [inlet, outlet, extra]
 
-    def test_warnings_are_properly_propagated(self):
-        with pytest.warns(DeprecationWarning) as warnings:
-            DeprecatedOperator(task_id="test")
-            assert len(warnings) == 1
-            warning = warnings[0]
-            # Here we check that the trace points to the place
-            # where the deprecated class was used
-            assert warning.filename == __file__
-
     def test_pre_execute_hook(self):
         hook = mock.MagicMock()
 
@@ -694,47 +598,6 @@ class TestBaseOperator:
         assert op_no_dag.start_date.tzinfo
         assert op_no_dag.end_date.tzinfo
 
-    def test_setattr_performs_no_custom_action_at_execute_time(self):
-        op = MockOperator(task_id="test_task")
-        op_copy = op.prepare_for_execution()
-
-        with 
mock.patch("airflow.models.baseoperator.BaseOperator.set_xcomargs_dependencies")
 as method_mock:
-            op_copy.execute({})
-        assert method_mock.call_count == 0
-
-    def test_upstream_is_set_when_template_field_is_xcomarg(self):
-        with DAG("xcomargs_test", schedule=None, default_args={"start_date": 
datetime.today()}):
-            op1 = BaseOperator(task_id="op1")
-            op2 = MockOperator(task_id="op2", arg1=op1.output)
-
-        assert op1 in op2.upstream_list
-        assert op2 in op1.downstream_list
-
-    def test_set_xcomargs_dependencies_works_recursively(self):
-        with DAG("xcomargs_test", schedule=None, default_args={"start_date": 
datetime.today()}):
-            op1 = BaseOperator(task_id="op1")
-            op2 = BaseOperator(task_id="op2")
-            op3 = MockOperator(task_id="op3", arg1=[op1.output, op2.output])
-            op4 = MockOperator(task_id="op4", arg1={"op1": op1.output, "op2": 
op2.output})
-
-        assert op1 in op3.upstream_list
-        assert op2 in op3.upstream_list
-        assert op1 in op4.upstream_list
-        assert op2 in op4.upstream_list
-
-    def test_set_xcomargs_dependencies_works_when_set_after_init(self):
-        with DAG(dag_id="xcomargs_test", schedule=None, 
default_args={"start_date": datetime.today()}):
-            op1 = BaseOperator(task_id="op1")
-            op2 = MockOperator(task_id="op2")
-            op2.arg1 = op1.output  # value is set after init
-
-        assert op1 in op2.upstream_list
-
-    def test_set_xcomargs_dependencies_error_when_outside_dag(self):
-        op1 = BaseOperator(task_id="op1")
-        with pytest.raises(AirflowException):
-            MockOperator(task_id="op2", arg1=op1.output)
-
     def test_invalid_trigger_rule(self):
         with pytest.raises(
             AirflowException,
@@ -745,14 +608,6 @@ class TestBaseOperator:
         ):
             BaseOperator(task_id="op1", trigger_rule="some_rule")
 
-    def test_weight_rule_default(self):
-        op = BaseOperator(task_id="test_task")
-        assert _DownstreamPriorityWeightStrategy() == op.weight_rule
-
-    def test_weight_rule_override(self):
-        op = BaseOperator(task_id="test_task", weight_rule="upstream")
-        assert _UpstreamPriorityWeightStrategy() == op.weight_rule
-
     # ensure the default logging config is used for this test, no matter what ran before
     @pytest.mark.usefixtures("reset_logging_config")
     def test_logging_propogated_by_default(self, caplog):
@@ -763,92 +618,6 @@ class TestBaseOperator:
         # leaking a lot of state)
         assert caplog.messages == ["test"]
 
-    def test_invalid_type_for_default_arg(self):
-        error_msg = "'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int, expected type is <class 'int'>"
-        with pytest.raises(TypeError, match=error_msg):
-            BaseOperator(task_id="test", 
default_args={"max_active_tis_per_dag": "not_an_int"})
-
-    def test_invalid_type_for_operator_arg(self):
-        error_msg = "'max_active_tis_per_dag' has an invalid type <class 'str'> with value not_an_int, expected type is <class 'int'>"
-        with pytest.raises(TypeError, match=error_msg):
-            BaseOperator(task_id="test", max_active_tis_per_dag="not_an_int")
-
-    @mock.patch("airflow.models.baseoperator.validate_instance_args")
-    def test_baseoperator_init_validates_arg_types(self, 
mock_validate_instance_args):
-        operator = BaseOperator(task_id="test")
-
-        mock_validate_instance_args.assert_called_once_with(operator, 
BASEOPERATOR_ARGS_EXPECTED_TYPES)
-
-
-def test_init_subclass_args():
-    class InitSubclassOp(BaseOperator):
-        _class_arg: Any
-
-        def __init_subclass__(cls, class_arg=None, **kwargs) -> None:
-            cls._class_arg = class_arg
-            super().__init_subclass__()
-
-        def execute(self, context: Context):
-            self.context_arg = context
-
-    class_arg = "foo"
-    context = {"key": "value"}
-
-    class ConcreteSubclassOp(InitSubclassOp, class_arg=class_arg):
-        pass
-
-    task = ConcreteSubclassOp(task_id="op1")
-    task_copy = task.prepare_for_execution()
-
-    task_copy.execute(context)
-
-    assert task_copy._class_arg == class_arg
-    assert task_copy.context_arg == context
-
-
-@pytest.mark.db_test
-@pytest.mark.parametrize(
-    ("retries", "expected"),
-    [
-        pytest.param("foo", "'retries' type must be int, not str", 
id="string"),
-        pytest.param(CustomInt(10), "'retries' type must be int, not 
CustomInt", id="custom int"),
-    ],
-)
-def test_operator_retries_invalid(dag_maker, retries, expected):
-    with pytest.raises(AirflowException) as ctx:
-        with dag_maker():
-            BaseOperator(task_id="test_illegal_args", retries=retries)
-    assert str(ctx.value) == expected
-
-
-@pytest.mark.db_test
-@pytest.mark.parametrize(
-    ("retries", "expected"),
-    [
-        pytest.param(None, [], id="None"),
-        pytest.param(5, [], id="5"),
-        pytest.param(
-            "1",
-            [
-                (
-                    "airflow.models.baseoperator.BaseOperator",
-                    logging.WARNING,
-                    "Implicitly converting 'retries' from '1' to int",
-                ),
-            ],
-            id="str",
-        ),
-    ],
-)
-def test_operator_retries(caplog, dag_maker, retries, expected):
-    with caplog.at_level(logging.WARNING):
-        with dag_maker():
-            BaseOperator(
-                task_id="test_illegal_args",
-                retries=retries,
-            )
-    assert caplog.record_tuples == expected
-
 
 @pytest.mark.db_test
 def test_default_retry_delay(dag_maker):
@@ -858,24 +627,6 @@ def test_default_retry_delay(dag_maker):
         assert task1.retry_delay == timedelta(seconds=300)
 
 
-@pytest.mark.db_test
-def test_dag_level_retry_delay(dag_maker):
-    with dag_maker(dag_id="test_dag_level_retry_delay", 
default_args={"retry_delay": timedelta(seconds=100)}):
-        task1 = BaseOperator(task_id="test_no_explicit_retry_delay")
-
-        assert task1.retry_delay == timedelta(seconds=100)
-
-
-@pytest.mark.db_test
-def test_task_level_retry_delay(dag_maker):
-    with dag_maker(
-        dag_id="test_task_level_retry_delay", default_args={"retry_delay": 
timedelta(seconds=100)}
-    ):
-        task1 = BaseOperator(task_id="test_no_explicit_retry_delay", 
retry_delay=timedelta(seconds=200))
-
-        assert task1.retry_delay == timedelta(seconds=200)
-
-
 def test_deepcopy():
     # Test bug when copying an operator attached to a DAG
     with DAG("dag0", schedule=None, start_date=DEFAULT_DATE) as dag:
diff --git a/tests/models/test_dag.py b/tests/models/test_dag.py
index 67dc699fc3..c0b705d641 100644
--- a/tests/models/test_dag.py
+++ b/tests/models/test_dag.py
@@ -42,7 +42,6 @@ from airflow.configuration import conf
 from airflow.decorators import setup, task as task_decorator, teardown
 from airflow.exceptions import (
     AirflowException,
-    DuplicateTaskIdFound,
     ParamValidationError,
     UnknownExecutorException,
 )
@@ -65,7 +64,7 @@ from airflow.models.dag import (
     get_asset_triggered_next_run_info,
 )
 from airflow.models.dagrun import DagRun
-from airflow.models.param import DagParam, Param, ParamsDict
+from airflow.models.param import DagParam, Param
 from airflow.models.serialized_dag import SerializedDagModel
 from airflow.models.taskfail import TaskFail
 from airflow.models.taskinstance import TaskInstance as TI
@@ -175,41 +174,6 @@ class TestDag:
                 b_index = i
         return 0 <= a_index < b_index
 
-    def test_params_not_passed_is_empty_dict(self):
-        """
-        Test that when 'params' is _not_ passed to a new Dag, that the params
-        attribute is set to an empty dictionary.
-        """
-        dag = DAG("test-dag", schedule=None)
-
-        assert isinstance(dag.params, ParamsDict)
-        assert 0 == len(dag.params)
-
-    def test_params_passed_and_params_in_default_args_no_override(self):
-        """
-        Test that when 'params' exists as a key passed to the default_args dict
-        in addition to params being passed explicitly as an argument to the
-        dag, that the 'params' key of the default_args dict is merged with the
-        dict of the params argument.
-        """
-        params1 = {"parameter1": 1}
-        params2 = {"parameter2": 2}
-
-        dag = DAG("test-dag", schedule=None, default_args={"params": params1}, 
params=params2)
-
-        assert params1["parameter1"] == dag.params["parameter1"]
-        assert params2["parameter2"] == dag.params["parameter2"]
-
-    def test_not_none_schedule_with_non_default_params(self):
-        """
-        Test if there is a DAG with not None schedule and have some params that
-        don't have a default value raise a error while DAG parsing
-        """
-        params = {"param1": Param(type="string")}
-
-        with pytest.raises(AirflowException):
-            DAG("dummy-dag", schedule=timedelta(days=1), 
start_date=DEFAULT_DATE, params=params)
-
     def test_dag_invalid_default_view(self):
         """
         Test invalid `default_view` of DAG initialization
@@ -238,57 +202,6 @@ class TestDag:
         dag = DAG(dag_id="test-default_orientation", schedule=None)
         assert conf.get("webserver", "dag_orientation") == dag.orientation
 
-    def test_dag_as_context_manager(self):
-        """
-        Test DAG as a context manager.
-        When used as a context manager, Operators are automatically added to
-        the DAG (unless they specify a different DAG)
-        """
-        dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, 
default_args={"owner": "owner1"})
-        dag2 = DAG("dag2", schedule=None, start_date=DEFAULT_DATE, 
default_args={"owner": "owner2"})
-
-        with dag:
-            op1 = EmptyOperator(task_id="op1")
-            op2 = EmptyOperator(task_id="op2", dag=dag2)
-
-        assert op1.dag is dag
-        assert op1.owner == "owner1"
-        assert op2.dag is dag2
-        assert op2.owner == "owner2"
-
-        with dag2:
-            op3 = EmptyOperator(task_id="op3")
-
-        assert op3.dag is dag2
-        assert op3.owner == "owner2"
-
-        with dag:
-            with dag2:
-                op4 = EmptyOperator(task_id="op4")
-            op5 = EmptyOperator(task_id="op5")
-
-        assert op4.dag is dag2
-        assert op5.dag is dag
-        assert op4.owner == "owner2"
-        assert op5.owner == "owner1"
-
-        with DAG("creating_dag_in_cm", schedule=None, start_date=DEFAULT_DATE) 
as dag:
-            EmptyOperator(task_id="op6")
-
-        assert dag.dag_id == "creating_dag_in_cm"
-        assert dag.tasks[0].task_id == "op6"
-
-        with dag:
-            with dag:
-                op7 = EmptyOperator(task_id="op7")
-            op8 = EmptyOperator(task_id="op8")
-        op9 = EmptyOperator(task_id="op8")
-        op9.dag = dag2
-
-        assert op7.dag == dag
-        assert op8.dag == dag
-        assert op9.dag == dag2
-
     def test_dag_topological_sort_dag_without_tasks(self):
         dag = DAG("dag", schedule=None, start_date=DEFAULT_DATE, 
default_args={"owner": "owner1"})
 
@@ -1287,60 +1200,6 @@ class TestDag:
         dag = DAG("DAG", schedule=None, default_args=default_args)
         assert dag.timezone.name == local_tz.name
 
-    def test_roots(self):
-        """Verify if dag.roots returns the root tasks of a DAG."""
-        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
-            op1 = EmptyOperator(task_id="t1")
-            op2 = EmptyOperator(task_id="t2")
-            op3 = EmptyOperator(task_id="t3")
-            op4 = EmptyOperator(task_id="t4")
-            op5 = EmptyOperator(task_id="t5")
-            [op1, op2] >> op3 >> [op4, op5]
-
-            assert set(dag.roots) == {op1, op2}
-
-    def test_leaves(self):
-        """Verify if dag.leaves returns the leaf tasks of a DAG."""
-        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
-            op1 = EmptyOperator(task_id="t1")
-            op2 = EmptyOperator(task_id="t2")
-            op3 = EmptyOperator(task_id="t3")
-            op4 = EmptyOperator(task_id="t4")
-            op5 = EmptyOperator(task_id="t5")
-            [op1, op2] >> op3 >> [op4, op5]
-
-            assert set(dag.leaves) == {op4, op5}
-
-    def test_duplicate_task_ids_not_allowed_with_dag_context_manager(self):
-        """Verify tasks with Duplicate task_id raises error"""
-        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
-            op1 = EmptyOperator(task_id="t1")
-            with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
-                BashOperator(task_id="t1", bash_command="sleep 1")
-
-        assert dag.task_dict == {op1.task_id: op1}
-
-    def test_duplicate_task_ids_not_allowed_without_dag_context_manager(self):
-        """Verify tasks with Duplicate task_id raises error"""
-        dag = DAG("test_dag", schedule=None, start_date=DEFAULT_DATE)
-        op1 = EmptyOperator(task_id="t1", dag=dag)
-        with pytest.raises(DuplicateTaskIdFound, match="Task id 't1' has already been added to the DAG"):
-            EmptyOperator(task_id="t1", dag=dag)
-
-        assert dag.task_dict == {op1.task_id: op1}
-
-    def test_duplicate_task_ids_for_same_task_is_allowed(self):
-        """Verify that same tasks with Duplicate task_id do not raise error"""
-        with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
-            op1 = op2 = EmptyOperator(task_id="t1")
-            op3 = EmptyOperator(task_id="t3")
-            op1 >> op3
-            op2 >> op3
-
-        assert op1 == op2
-        assert dag.task_dict == {op1.task_id: op1, op3.task_id: op3}
-        assert dag.task_dict == {op2.task_id: op2, op3.task_id: op3}
-
     def test_partial_subset_updates_all_references_while_deepcopy(self):
         with DAG("test_dag", schedule=None, start_date=DEFAULT_DATE) as dag:
             op1 = EmptyOperator(task_id="t1")
diff --git a/uv.lock b/uv.lock
index 3a3c2db04b..1fecc75854 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1373,6 +1373,7 @@ dependencies = [
 
 [package.dev-dependencies]
 dev = [
+    { name = "kgb" },
     { name = "pytest" },
     { name = "pytest-asyncio" },
     { name = "pytest-mock" },
@@ -1387,6 +1388,7 @@ requires-dist = [
 
 [package.metadata.requires-dev]
 dev = [
+    { name = "kgb", specifier = ">=7.1.1" },
     { name = "pytest", specifier = ">=8.3.3" },
     { name = "pytest-asyncio", specifier = ">=0.24.0" },
     { name = "pytest-mock", specifier = ">=3.14.0" },
@@ -2452,6 +2454,15 @@ wheels = [
     { url = "https://files.pythonhosted.org/packages/d1/0f/8910b19ac0670a0f80ce1008e5e751c4a57e14d2c4c13a482aa6079fa9d6/jsonschema_specifications-2024.10.1-py3-none-any.whl", hash = "sha256:a09a0680616357d9a0ecf05c12ad234479f549239d0f5b55f3deea67475da9bf", size = 18459 },
 ]
 
+[[package]]
+name = "kgb"
+version = "7.1.1"
+source = { registry = "https://pypi.org/simple" }
+sdist = { url = "https://files.pythonhosted.org/packages/0c/2e/2b608fa158cd87d7372b1d1c94d70b9b90e4ab5316c77f26feb1e4b6549f/kgb-7.1.1.tar.gz", hash = "sha256:74912c8761651f2063151c6c2a36ebe023393de491ec86744771a2888ab9845b", size = 61504 }
+wheels = [
+    { url = "https://files.pythonhosted.org/packages/80/45/ae8db25f019419b17359ca98f129c0a0d9fa40cadeaac3525b02b690e705/kgb-7.1.1-py2.py3-none-any.whl", hash = "sha256:ed535b25caa5d8151bb8700c653a73475a6d3937c75cd2b8ce93c84c97a86a6f", size = 58003 },
+]
+
 [[package]]
 name = "lazy-object-proxy"
 version = "1.10.0"

Reply via email to