This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 61ae85b9d1 [REFACTOR][PYTHON] Consolidate derived_object into
tvm.ir.utils (#19630)
61ae85b9d1 is described below
commit 61ae85b9d105980bf9113af36c76317ca4ef0191
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed May 27 23:15:43 2026 -0400
[REFACTOR][PYTHON] Consolidate derived_object into tvm.ir.utils (#19630)
## Summary
`derived_object` was duplicated nearly verbatim (the two copies differed
only in their type annotations) across
`python/tvm/runtime/support.py` and
`python/tvm/s_tir/meta_schedule/utils.py`. The function is not a runtime
feature and is used outside meta_schedule (tvm.relax, tvm.tirx), so
neither location was the right home.
Move the single canonical definition into a new
`python/tvm/ir/utils.py`. `tvm.ir` loads before both `tvm.tirx` and
`tvm.s_tir`, so eager top-level imports work from every consumer without
load-order workarounds.
Rewrite all 25 caller imports. Keep the better-typed `cls: type[T] ->
type[T]` signature from the runtime-side copy. After the move,
`runtime/support.py` would be empty, so it is removed (git records this
as a rename to `ir/utils.py`); `meta_schedule/__init__.py` drops its
now-dead re-export. No alias shims are left behind — callers update
imports directly.
---
python/tvm/contrib/hexagon/meta_schedule.py | 3 +-
python/tvm/{runtime/support.py => ir/utils.py} | 3 +-
python/tvm/relax/expr_functor.py | 2 +-
python/tvm/s_tir/meta_schedule/__init__.py | 1 -
.../s_tir/meta_schedule/builder/local_builder.py | 3 +-
.../s_tir/meta_schedule/cost_model/mlp_model.py | 3 +-
.../s_tir/meta_schedule/cost_model/random_model.py | 3 +-
.../s_tir/meta_schedule/cost_model/xgb_model.py | 3 +-
.../feature_extractor/random_feature_extractor.py | 2 +-
.../tvm/s_tir/meta_schedule/runner/local_runner.py | 3 +-
.../tvm/s_tir/meta_schedule/runner/rpc_runner.py | 2 +-
.../s_tir/meta_schedule/testing/dummy_object.py | 2 +-
python/tvm/s_tir/meta_schedule/utils.py | 140 ---------------------
python/tvm/tirx/functor.py | 35 +-----
.../meta_schedule/test_meta_schedule_cost_model.py | 2 +-
.../meta_schedule/test_meta_schedule_database.py | 5 +-
.../test_meta_schedule_feature_extractor.py | 2 +-
.../test_meta_schedule_measure_callback.py | 15 +--
.../test_meta_schedule_post_order_apply.py | 2 +-
.../meta_schedule/test_meta_schedule_runner.py | 2 +-
.../test_meta_schedule_search_strategy.py | 2 +-
.../test_meta_schedule_space_generator.py | 2 +-
.../test_meta_schedule_task_scheduler.py | 7 +-
.../meta_schedule/test_meta_schedule_tune_tir.py | 3 +-
24 files changed, 43 insertions(+), 204 deletions(-)
diff --git a/python/tvm/contrib/hexagon/meta_schedule.py
b/python/tvm/contrib/hexagon/meta_schedule.py
index 5582f69746..0084d1da7f 100644
--- a/python/tvm/contrib/hexagon/meta_schedule.py
+++ b/python/tvm/contrib/hexagon/meta_schedule.py
@@ -23,6 +23,7 @@ from collections.abc import Callable
import tvm
from tvm.driver import build as tvm_build
from tvm.ir.module import IRModule
+from tvm.ir.utils import derived_object
from tvm.runtime import Module, Tensor
from tvm.s_tir.meta_schedule.builder import LocalBuilder
from tvm.s_tir.meta_schedule.runner import (
@@ -36,7 +37,7 @@ from tvm.s_tir.meta_schedule.runner.rpc_runner import (
default_alloc_argument,
default_run_evaluator,
)
-from tvm.s_tir.meta_schedule.utils import cpu_count, derived_object
+from tvm.s_tir.meta_schedule.utils import cpu_count
from tvm.s_tir.transform import RemoveWeightLayoutRewriteBlock
from tvm.support.popen_pool import PopenPoolExecutor
from tvm.target import Target
diff --git a/python/tvm/runtime/support.py b/python/tvm/ir/utils.py
similarity index 99%
rename from python/tvm/runtime/support.py
rename to python/tvm/ir/utils.py
index d9762ef571..a2068d4759 100644
--- a/python/tvm/runtime/support.py
+++ b/python/tvm/ir/utils.py
@@ -14,8 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-
-"""Runtime support infra of TVM."""
+"""Utilities shared across TVM IR packages."""
from typing import TypeVar
diff --git a/python/tvm/relax/expr_functor.py b/python/tvm/relax/expr_functor.py
index 5ac77da3c0..c9ea88d111 100644
--- a/python/tvm/relax/expr_functor.py
+++ b/python/tvm/relax/expr_functor.py
@@ -23,8 +23,8 @@ from collections.abc import Callable
import tvm_ffi
from tvm.ir import Op
+from tvm.ir.utils import derived_object
from tvm.runtime import Object
-from tvm.runtime.support import derived_object
from ..ir.module import IRModule
from . import _ffi_api
diff --git a/python/tvm/s_tir/meta_schedule/__init__.py
b/python/tvm/s_tir/meta_schedule/__init__.py
index f3601f6e6d..3fbdd37859 100644
--- a/python/tvm/s_tir/meta_schedule/__init__.py
+++ b/python/tvm/s_tir/meta_schedule/__init__.py
@@ -53,5 +53,4 @@ from .task_scheduler import TaskScheduler
from .tir_integration import tune_tir
from .tune import tune_tasks
from .tune_context import TuneContext
-from .utils import derived_object
from .post_optimization import post_opt
diff --git a/python/tvm/s_tir/meta_schedule/builder/local_builder.py
b/python/tvm/s_tir/meta_schedule/builder/local_builder.py
index 2a88c0167b..aa563294e2 100644
--- a/python/tvm/s_tir/meta_schedule/builder/local_builder.py
+++ b/python/tvm/s_tir/meta_schedule/builder/local_builder.py
@@ -25,12 +25,13 @@ from typing import Optional, Union
from tvm_ffi import register_global_func
from tvm.ir import IRModule
+from tvm.ir.utils import derived_object
from tvm.runtime import Module, Tensor, load_param_dict, save_param_dict
from tvm.support.popen_pool import MapResult, PopenPoolExecutor, StatusKind
from tvm.target import Target
from ..logging import get_logger
-from ..utils import cpu_count, derived_object,
get_global_func_with_default_on_worker
+from ..utils import cpu_count, get_global_func_with_default_on_worker
from .builder import BuilderInput, BuilderResult, PyBuilder
logger = get_logger(__name__) # pylint: disable=invalid-name
diff --git a/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
b/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
index 162110371f..a9bb7c784d 100644
--- a/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
+++ b/python/tvm/s_tir/meta_schedule/cost_model/mlp_model.py
@@ -32,6 +32,7 @@ import numpy as np # type: ignore
import torch # type: ignore
import tvm
+from tvm.ir.utils import derived_object
from tvm.support.tar import tar, untar
from ....runtime import Tensor
@@ -43,7 +44,7 @@ from ..logging import get_logger
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
-from ..utils import derived_object, shash2hex
+from ..utils import shash2hex
logger = get_logger("mlp_model") # pylint: disable=invalid-name
diff --git a/python/tvm/s_tir/meta_schedule/cost_model/random_model.py
b/python/tvm/s_tir/meta_schedule/cost_model/random_model.py
index 292fd4a964..86df91d58d 100644
--- a/python/tvm/s_tir/meta_schedule/cost_model/random_model.py
+++ b/python/tvm/s_tir/meta_schedule/cost_model/random_model.py
@@ -18,11 +18,12 @@
Random cost model
"""
+from tvm.ir.utils import derived_object
+
from ..cost_model import PyCostModel
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
-from ..utils import derived_object # type: ignore
@derived_object
diff --git a/python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py
b/python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py
index 3bc0b4d769..8d6aa49b10 100644
--- a/python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py
+++ b/python/tvm/s_tir/meta_schedule/cost_model/xgb_model.py
@@ -25,6 +25,7 @@ from typing import TYPE_CHECKING, Any, Literal, NamedTuple,
Optional
import numpy as np # type: ignore
+from tvm.ir.utils import derived_object
from tvm.support.tar import tar, untar
from ....runtime import Tensor
@@ -33,7 +34,7 @@ from ..feature_extractor import FeatureExtractor
from ..logging import get_logger
from ..runner import RunnerResult
from ..search_strategy import MeasureCandidate
-from ..utils import cpu_count, derived_object, shash2hex
+from ..utils import cpu_count, shash2hex
from .metric import max_curve
if TYPE_CHECKING:
diff --git
a/python/tvm/s_tir/meta_schedule/feature_extractor/random_feature_extractor.py
b/python/tvm/s_tir/meta_schedule/feature_extractor/random_feature_extractor.py
index 8cf7e2f2bf..fc42b36604 100644
---
a/python/tvm/s_tir/meta_schedule/feature_extractor/random_feature_extractor.py
+++
b/python/tvm/s_tir/meta_schedule/feature_extractor/random_feature_extractor.py
@@ -19,11 +19,11 @@
import numpy as np # type: ignore
import tvm.runtime
+from tvm.ir.utils import derived_object
from ..feature_extractor import PyFeatureExtractor
from ..search_strategy import MeasureCandidate
from ..tune_context import TuneContext
-from ..utils import derived_object
@derived_object
diff --git a/python/tvm/s_tir/meta_schedule/runner/local_runner.py
b/python/tvm/s_tir/meta_schedule/runner/local_runner.py
index c55925fd0b..b56cd6613c 100644
--- a/python/tvm/s_tir/meta_schedule/runner/local_runner.py
+++ b/python/tvm/s_tir/meta_schedule/runner/local_runner.py
@@ -22,12 +22,13 @@ from collections.abc import Callable
from contextlib import contextmanager
import tvm
+from tvm.ir.utils import derived_object
from tvm.support.popen_pool import PopenPoolExecutor
from ....runtime import Device, Module
from ..logging import get_logger
from ..profiler import Profiler
-from ..utils import derived_object, get_global_func_with_default_on_worker
+from ..utils import get_global_func_with_default_on_worker
from .config import EvaluatorConfig
from .runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput,
RunnerResult
from .utils import (
diff --git a/python/tvm/s_tir/meta_schedule/runner/rpc_runner.py
b/python/tvm/s_tir/meta_schedule/runner/rpc_runner.py
index 435cfd8b4d..27ab71e669 100644
--- a/python/tvm/s_tir/meta_schedule/runner/rpc_runner.py
+++ b/python/tvm/s_tir/meta_schedule/runner/rpc_runner.py
@@ -21,6 +21,7 @@ import os.path as osp
from collections.abc import Callable
from contextlib import contextmanager
+from tvm.ir.utils import derived_object
from tvm.rpc import RPCSession
from tvm.runtime import Device, Module
from tvm.support.popen_pool import PopenPoolExecutor
@@ -28,7 +29,6 @@ from tvm.support.popen_pool import PopenPoolExecutor
from ..logging import get_logger
from ..profiler import Profiler
from ..utils import (
- derived_object,
get_global_func_on_rpc_session,
get_global_func_with_default_on_worker,
)
diff --git a/python/tvm/s_tir/meta_schedule/testing/dummy_object.py
b/python/tvm/s_tir/meta_schedule/testing/dummy_object.py
index 007de8a9de..d3e0d55a93 100644
--- a/python/tvm/s_tir/meta_schedule/testing/dummy_object.py
+++ b/python/tvm/s_tir/meta_schedule/testing/dummy_object.py
@@ -18,13 +18,13 @@
import random
+from tvm.ir.utils import derived_object
from tvm.s_tir.schedule import Trace
from ..builder import BuilderInput, BuilderResult, PyBuilder
from ..mutator import PyMutator
from ..runner import PyRunner, PyRunnerFuture, RunnerFuture, RunnerInput,
RunnerResult
from ..tune_context import TuneContext # pylint: disable=unused-import
-from ..utils import derived_object
@derived_object
diff --git a/python/tvm/s_tir/meta_schedule/utils.py
b/python/tvm/s_tir/meta_schedule/utils.py
index 42f52a6c1e..775054c4ce 100644
--- a/python/tvm/s_tir/meta_schedule/utils.py
+++ b/python/tvm/s_tir/meta_schedule/utils.py
@@ -32,146 +32,6 @@ from tvm.rpc import RPCSession
from tvm.tirx import FloatImm, IntImm
-def derived_object(cls: type) -> type:
- """A decorator to register derived subclasses for TVM objects.
-
- Parameters
- ----------
- cls : type
- The derived class to be registered.
-
- Returns
- -------
- cls : type
- The decorated TVM object.
-
- Example
- -------
- .. code-block:: python
-
- @register_object("s_tir.meta_schedule.PyRunner")
- class _PyRunner(meta_schedule.Runner):
- def __init__(self, f_run: Callable = None):
- self.__init_handle_by_constructor__(_ffi_api.RunnerPyRunner,
f_run)
-
- class PyRunner:
- _tvm_metadata = {
- "cls": _PyRunner,
- "methods": ["run"]
- }
- def run(self, runner_inputs):
- raise NotImplementedError
-
- @derived_object
- class LocalRunner(PyRunner):
- def run(self, runner_inputs):
- ...
- """
-
- import functools # pylint: disable=import-outside-toplevel
- import weakref # pylint: disable=import-outside-toplevel
-
- def _extract(inst: type, name: str):
- """Extract function from intrinsic class."""
-
- def method(*args, **kwargs):
- return getattr(inst, name)(*args, **kwargs)
-
- for inherit_cls, base_cls in zip(cls.__mro__, cls.__mro__[1:]):
- # extract functions that differ from the base class
- if not hasattr(base_cls, name):
- continue
- if getattr(base_cls, name) is getattr(inherit_cls, name) and name
!= "__str__":
- continue
- return method
-
- # for task scheduler return None means calling default function
- # otherwise it will trigger a TVMError of method not implemented
- # on the c++ side when you call the method, __str__ not required
- return None
-
- assert isinstance(cls.__base__, type)
- if hasattr(cls, "_type") and cls._type == "TVMDerivedObject": # type:
ignore
- raise TypeError(
- f"Inheritance from a decorated object `{cls.__name__}` is not
allowed. "
- f"Please inherit from `{cls.__name__}._cls`."
- )
- assert hasattr(cls, "_tvm_metadata"), (
- "Please use the user-facing method overriding class, i.e., PyRunner."
- )
-
- base = cls.__base__
- metadata = getattr(base, "_tvm_metadata")
- fields = metadata.get("fields", [])
- methods = metadata.get("methods", [])
-
- base_cls = metadata["cls"]
- slots = []
- if getattr(base_cls, "__dictoffset__", 0) == 0:
- slots.append("__dict__")
- if getattr(base_cls, "__weakrefoffset__", 0) == 0:
- slots.append("__weakref__")
-
- class TVMDerivedObject(base_cls): # type: ignore
- """The derived object to avoid cyclic dependency."""
-
- __slots__ = tuple(slots)
-
- _cls = cls
- _type = "TVMDerivedObject"
-
- def __init__(self, *args, **kwargs):
- """Constructor."""
- self._inst = cls(*args, **kwargs)
-
- super().__init__(
- # the constructor's parameters, builder, runner, etc.
- *[getattr(self._inst, name) for name in fields],
- # the function methods, init_with_tune_context, build, run,
etc.
- *[_extract(self._inst, name) for name in methods],
- )
-
- # for task scheduler hybrid funcs in c++ & python side
- # using weakref to avoid cyclic dependency
- self._inst._outer = weakref.ref(self)
-
- def __getattr__(self, name):
- import inspect # pylint: disable=import-outside-toplevel
-
- try:
- # fall back to instance attribute if there is not any
- # return self._inst.__getattribute__(name)
- result = self._inst.__getattribute__(name)
- except AttributeError:
- result = super().__getattr__(name)
-
- if inspect.ismethod(result):
-
- def method(*args, **kwargs):
- return result(*args, **kwargs)
-
- # set __own__ to aviod implicit deconstruction
- setattr(method, "__own__", self)
- return method
-
- return result
-
- def __setattr__(self, name, value):
- if name not in ["_inst", "key", "handle"]:
- self._inst.__setattr__(name, value)
- else:
- super().__setattr__(name, value)
-
- functools.update_wrapper(TVMDerivedObject.__init__, cls.__init__) # type:
ignore
- TVMDerivedObject.__name__ = cls.__name__
- TVMDerivedObject.__doc__ = cls.__doc__
- TVMDerivedObject.__module__ = cls.__module__
- for key, value in cls.__dict__.items():
- if isinstance(value, classmethod | staticmethod):
- setattr(TVMDerivedObject, key, value)
- return TVMDerivedObject
-
-
@register_global_func("s_tir.meta_schedule.cpu_count")
def _cpu_count_impl(logical: bool = True) -> int:
"""Return the number of logical or physical CPUs in the system
diff --git a/python/tvm/tirx/functor.py b/python/tvm/tirx/functor.py
index 4619c0b51f..ab2af06a79 100644
--- a/python/tvm/tirx/functor.py
+++ b/python/tvm/tirx/functor.py
@@ -23,7 +23,7 @@ from collections.abc import Callable
import tvm_ffi
from tvm.ir import PrimExpr
-from tvm.runtime.support import derived_object
+from tvm.ir.utils import derived_object
from . import _ffi_api
from .expr import (
@@ -78,39 +78,10 @@ from .stmt import (
While,
)
+# visitor and mutator are aliases for derived_object
visitor = derived_object
-"""
-A decorator to wrap user-customized PyStmtExprVisitor as TVM object
_PyStmtExprVisitor.
-
-Parameters
-----------
-visitor_cls : PyStmtExprVisitor
- The user-customized PyStmtExprVisitor.
-
-Returns
--------
-cls : _PyStmtExprVisitor
- The decorated TVM object _PyStmtExprVisitor(StmtExprVisitor on the C++
side).
-
-Example
--------
-.. code-block:: python
-
- @tirx.functor.stmt_expr_visitor
- class MyStmtExprVisitor(PyStmtExprVisitor):
- # customize visit function
- def visit_call_(self, op: Call) -> None:
- # just for demo purposes
- ...
- # myvisitor is now a special visitor that visit every Call with
- # user-customized visit_call_
- myvisitor = MyStmtExprVisitor()
- # apply myvisitor to PrimExpr and Stmt
- myvisitor.visit_expr(expr)
- myvisitor.visit_stmt(stmt)
-"""
-
mutator = derived_object
+
"""
A decorator to wrap user-customized PyStmtExprMutator as TVM object
_PyStmtExprMutator.
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py
index 4ba49ebe24..b2385597ab 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_cost_model.py
@@ -27,13 +27,13 @@ import numpy as np
import tvm
import tvm.testing
+from tvm.ir.utils import derived_object
from tvm.s_tir.meta_schedule.cost_model import PyCostModel, RandomModel,
XGBModel
from tvm.s_tir.meta_schedule.cost_model.xgb_model import PackSum,
_get_custom_call_back
from tvm.s_tir.meta_schedule.feature_extractor import RandomFeatureExtractor
from tvm.s_tir.meta_schedule.runner import RunnerResult
from tvm.s_tir.meta_schedule.search_strategy import MeasureCandidate
from tvm.s_tir.meta_schedule.tune_context import TuneContext
-from tvm.s_tir.meta_schedule.utils import derived_object
from tvm.s_tir.schedule.schedule import Schedule
from tvm.script import tirx as T
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py
index 9314dedf57..ffe4945f68 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_database.py
@@ -29,6 +29,7 @@ import tvm
import tvm.testing
from tvm import tirx
from tvm.ir.module import IRModule
+from tvm.ir.utils import derived_object
from tvm.s_tir import Schedule
from tvm.s_tir import meta_schedule as ms
from tvm.s_tir.meta_schedule.database import TuningRecord, Workload
@@ -113,7 +114,7 @@ def _equal_record(a: ms.database.TuningRecord, b:
ms.database.TuningRecord):
assert str(arg0.as_json()) == str(arg1.as_json())
-@ms.derived_object
+@derived_object
class PyMemoryDatabaseDefault(ms.database.PyDatabase):
def __init__(self):
super().__init__()
@@ -156,7 +157,7 @@ class PyMemoryDatabaseDefault(ms.database.PyDatabase):
return len(self.tuning_records_)
-@ms.derived_object
+@derived_object
class PyMemoryDatabaseOverride(ms.database.PyDatabase):
def __init__(self):
super().__init__()
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor.py
index 1d336d9b5a..91723c539c 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_feature_extractor.py
@@ -20,10 +20,10 @@ import re
import numpy as np
import tvm.runtime
+from tvm.ir.utils import derived_object
from tvm.s_tir.meta_schedule import TuneContext
from tvm.s_tir.meta_schedule.feature_extractor import PyFeatureExtractor
from tvm.s_tir.meta_schedule.search_strategy import MeasureCandidate
-from tvm.s_tir.meta_schedule.utils import derived_object
def test_meta_schedule_feature_extractor():
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py
index 2d61829203..b9f2bcab7a 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_measure_callback.py
@@ -21,6 +21,7 @@ import tempfile
import pytest
import tvm
+from tvm.ir.utils import derived_object
from tvm.s_tir import meta_schedule as ms
from tvm.s_tir.schedule import Schedule
from tvm.script import tirx as T
@@ -48,7 +49,7 @@ class Matmul:
def test_meta_schedule_measure_callback():
- @ms.derived_object
+ @derived_object
class FancyMeasureCallback(ms.measure_callback.PyMeasureCallback):
def apply(
self,
@@ -82,7 +83,7 @@ def test_meta_schedule_measure_callback():
def test_meta_schedule_measure_callback_fail():
- @ms.derived_object
+ @derived_object
class FailingMeasureCallback(ms.measure_callback.PyMeasureCallback):
def apply(
self,
@@ -106,7 +107,7 @@ def test_meta_schedule_measure_callback_fail():
def test_meta_schedule_measure_callback_as_string():
- @ms.derived_object
+ @derived_object
class NotSoFancyMeasureCallback(ms.measure_callback.PyMeasureCallback):
def apply(
self,
@@ -125,7 +126,7 @@ def test_meta_schedule_measure_callback_as_string():
@pytest.mark.skip("Tuning test - launches runner")
def test_meta_schedule_measure_callback_update_cost_model_with_zero():
- @ms.derived_object
+ @derived_object
class AllZeroRunnerFuture(ms.runner.PyRunnerFuture):
def done(self) -> bool:
return True
@@ -133,7 +134,7 @@ def
test_meta_schedule_measure_callback_update_cost_model_with_zero():
def result(self) -> ms.runner.RunnerResult:
return ms.runner.RunnerResult([0.0, 0.0], None)
- @ms.derived_object
+ @derived_object
class AllZeroRunner(ms.runner.PyRunner):
def run(self, runner_inputs: list[ms.runner.RunnerInput]) ->
list[ms.runner.RunnerResult]:
return [AllZeroRunnerFuture() for _ in runner_inputs]
@@ -151,7 +152,7 @@ def
test_meta_schedule_measure_callback_update_cost_model_with_zero():
@pytest.mark.skip("Tuning test - launches runner")
def test_meta_schedule_measure_callback_update_cost_model_with_runtime_error():
- @ms.derived_object
+ @derived_object
class EmptyRunnerFuture(ms.runner.PyRunnerFuture):
def done(self) -> bool:
return True
@@ -159,7 +160,7 @@ def
test_meta_schedule_measure_callback_update_cost_model_with_runtime_error():
def result(self) -> ms.runner.RunnerResult:
return ms.runner.RunnerResult(None, "error")
- @ms.derived_object
+ @derived_object
class EmptyRunner(ms.runner.PyRunner):
def run(self, runner_inputs: list[ms.runner.RunnerInput]) ->
list[ms.runner.RunnerResult]:
return [EmptyRunnerFuture() for _ in runner_inputs]
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
index ee9b74d92d..46d71ca6e7 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_post_order_apply.py
@@ -28,10 +28,10 @@ import tvm.testing
from tvm import te
from tvm.error import TVMError
from tvm.ir.module import IRModule
+from tvm.ir.utils import derived_object
from tvm.s_tir.meta_schedule import TuneContext
from tvm.s_tir.meta_schedule.schedule_rule import PyScheduleRule
from tvm.s_tir.meta_schedule.space_generator import PostOrderApply
-from tvm.s_tir.meta_schedule.utils import derived_object
from tvm.s_tir.schedule import SBlockRV, Schedule
from tvm.script import tirx as T
from tvm.target import Target
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py
index 9c267a69c6..b23c603a4b 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_runner.py
@@ -28,6 +28,7 @@ from tvm_ffi import register_global_func
import tvm
import tvm.testing
+from tvm.ir.utils import derived_object
from tvm.rpc import RPCSession
from tvm.runtime import Device, Module
from tvm.s_tir.meta_schedule.arg_info import TensorInfo
@@ -53,7 +54,6 @@ from tvm.s_tir.meta_schedule.runner.rpc_runner import (
)
from tvm.s_tir.meta_schedule.testing.local_rpc import LocalRPC
from tvm.s_tir.meta_schedule.utils import (
- derived_object,
get_global_func_with_default_on_worker,
)
from tvm.script import tirx as T
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
index 0f393e23ab..002741c6bf 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_search_strategy.py
@@ -23,9 +23,9 @@ import pytest
import tvm
import tvm.testing
+from tvm.ir.utils import derived_object
from tvm.s_tir import meta_schedule as ms
from tvm.s_tir.meta_schedule.testing.dummy_object import DummyMutator
-from tvm.s_tir.meta_schedule.utils import derived_object
from tvm.s_tir.schedule import Schedule, Trace
from tvm.script import tirx as T
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py
index a783cf5872..5515d66f9a 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_generator.py
@@ -24,13 +24,13 @@ import pytest
import tvm
import tvm.testing
from tvm.base import TVMError
+from tvm.ir.utils import derived_object
from tvm.s_tir.meta_schedule.space_generator import (
PySpaceGenerator,
ScheduleFn,
SpaceGeneratorUnion,
)
from tvm.s_tir.meta_schedule.tune_context import TuneContext
-from tvm.s_tir.meta_schedule.utils import derived_object
from tvm.s_tir.schedule import Schedule
from tvm.script import tirx as T
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py
index 1ffedc30ca..61f5583c2a 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_task_scheduler.py
@@ -24,6 +24,7 @@ import pytest
import tvm
import tvm.testing
+from tvm.ir.utils import derived_object
from tvm.s_tir import Schedule
from tvm.s_tir import meta_schedule as ms
from tvm.s_tir.meta_schedule.testing.dummy_object import DummyBuilder,
DummyRunner
@@ -119,7 +120,7 @@ def _schedule_batch_matmul(sch: Schedule):
sch.reorder(i_0, j_0, i_1, j_1, k_0, i_2, j_2, k_1, i_3, j_3, t_0, t_1)
-@ms.derived_object
+@derived_object
class MyTaskScheduler(ms.task_scheduler.PyTaskScheduler):
done: set = set()
@@ -233,7 +234,7 @@ def test_meta_schedule_task_scheduler_multiple():
def test_meta_schedule_task_scheduler_NIE(): # pylint: disable=invalid-name
- @ms.derived_object
+ @derived_object
class NIETaskScheduler(ms.task_scheduler.PyTaskScheduler):
pass
@@ -360,7 +361,7 @@ def
test_meta_schedule_task_scheduler_gradient_based_with_null_search_strategy()
the scheduler should continue working as normal for other tasks
"""
- @ms.derived_object
+ @derived_object
class NullSearchStrategy(ms.search_strategy.PySearchStrategy):
def __init__(self, rounds_with_empty_candidates):
self.rounds_with_empty_candidates = rounds_with_empty_candidates
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py
index 97f803fc48..8430072223 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_tune_tir.py
@@ -23,6 +23,7 @@ import pytest
import tvm
import tvm.testing
+from tvm.ir.utils import derived_object
from tvm.s_tir import meta_schedule as ms
from tvm.s_tir.meta_schedule.testing.custom_builder_runner import
run_module_via_rpc
from tvm.s_tir.meta_schedule.testing.local_rpc import LocalRPC
@@ -147,7 +148,7 @@ def test_tune_run_module_via_rpc():
@pytest.mark.skip("Integration test")
def test_tune_block_cpu():
- @ms.derived_object
+ @derived_object
class RemoveBlock(ms.schedule_rule.PyScheduleRule):
def _initialize_with_tune_context(self, context: ms.TuneContext) ->
None:
pass