This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 35a35b8434 [Tests][Refactor] Remove unused testing helpers (#19800)
35a35b8434 is described below

commit 35a35b8434c36cbaa8ff2b3854a4528ed83d2856
Author: Shushi Hong <[email protected]>
AuthorDate: Tue Jun 16 17:51:23 2026 -0400

    [Tests][Refactor] Remove unused testing helpers (#19800)
    
    CompareBeforeAfter, skip_parameterizations, and xfail_parameterizations
    have no remaining users anywhere in the repo. CompareBeforeAfter (a base
    class for TIR before/after transform tests) has been superseded by the
    inline assert_structural_equal(transform(Before), Expected) pattern, and
    the {skip,xfail}_parameterizations helpers (which marked specific
    parametrizations at runtime) are unused -- native pytest.param(...,
    marks=...) covers that need.
    
    Also drop the private _mark_parameterizations helper they relied on and
    the now-unused 'import textwrap'.
---
 python/tvm/testing/utils.py | 280 --------------------------------------------
 1 file changed, 280 deletions(-)

diff --git a/python/tvm/testing/utils.py b/python/tvm/testing/utils.py
index 9adeba689b..c686bc1184 100644
--- a/python/tvm/testing/utils.py
+++ b/python/tvm/testing/utils.py
@@ -74,7 +74,6 @@ import os
 import pickle
 import platform
 import sys
-import textwrap
 import time
 from pathlib import Path
 
@@ -1170,40 +1169,6 @@ def install_request_hook(depth: int) -> None:
     request_hook.init()
 
 
-def _mark_parameterizations(*params, marker_fn, reason):
-    """
-    Mark tests with a nodeid parameters that exactly matches one in params.
-    Useful for quickly marking tests as xfail when they have a large
-    combination of parameters.
-    """
-    params = set(params)
-
-    def decorator(func):
-        @functools.wraps(func)
-        def wrapper(request, *args, **kwargs):
-            if "[" in request.node.name and "]" in request.node.name:
-                # Strip out the test name and the [ and ] brackets
-                params_from_name = request.node.name[len(request.node.originalname) + 1 : -1]
-                if params_from_name in params:
-                    marker_fn(
-                        reason=f"{marker_fn.__name__} on nodeid {request.node.nodeid}: " + reason
-                    )
-
-            return func(request, *args, **kwargs)
-
-        return wrapper
-
-    return decorator
-
-
-def xfail_parameterizations(*xfail_params, reason):
-    return _mark_parameterizations(*xfail_params, marker_fn=pytest.xfail, reason=reason)
-
-
-def skip_parameterizations(*skip_params, reason):
-    return _mark_parameterizations(*skip_params, marker_fn=pytest.skip, reason=reason)
-
-
 def strtobool(val):
     """Convert a string representation of truth to true (1) or false (0).
     True values are 'y', 'yes', 't', 'true', 'on', and '1'; false values
@@ -1224,251 +1189,6 @@ def main():
     sys.exit(pytest.main([test_file, *sys.argv[1:]]))
 
 
-class CompareBeforeAfter:
-    """Utility for comparing before/after of TIR transforms
-
-    A standard framework for writing tests that take a TIR PrimFunc as
-    input, apply a transformation, then either compare against an
-    expected output or assert that the transformation raised an error.
-    A test should subclass CompareBeforeAfter, defining class members
-    `before` / `Before`, `transform`, and `expected` / `Expected`.  CompareBeforeAfter will
-    then use these members to define a test method and test fixture.
-
-    `transform` may be one of the following.
-
-    - An instance of `tvm.ir.transform.Pass`
-
-    - A method that takes no arguments and returns a `tvm.ir.transform.Pass`
-
-    - A pytest fixture that returns a `tvm.ir.transform.Pass`
-
-    `before` / `Before` may be any one of the following.
-
-    - An instance of `tvm.tirx.PrimFunc`.  This is allowed, but is not
-      the preferred method, as any errors in constructing the
-      `PrimFunc` occur while collecting the test, preventing any other
-      tests in the same file from being run.
-
-    - An TVMScript function, without the ``@T.prim_func`` decoration.
-      The ``@T.prim_func`` decoration will be applied when running the
-      test, rather than at module import.
-
-    - A method that takes no arguments and returns a `tvm.tirx.PrimFunc`
-
-    - A pytest fixture that returns a `tvm.tirx.PrimFunc`
-
-    `expected` / `Expected` may be any one of the following.  The type of
-    `expected` / `Expected` defines the test being performed.  If `expected`
-    provides a `tvm.tirx.PrimFunc`, the result of the transformation
-    must match `expected`.  If `expected` is an exception, then the
-    transformation must raise that exception type.
-
-    - Any option supported for `before` / `Before`.
-
-    - The `Exception` class object, or a class object that inherits
-      from `Exception`.
-
-    - A method that takes no arguments and returns `Exception` or a
-      class object that inherits from `Exception`.
-
-    - A pytest fixture that returns `Exception` or an class object
-      that inherits from `Exception`.
-
-    Examples
-    --------
-
-    .. code-block:: python
-
-        class TestRemoveIf(tvm.testing.CompareBeforeAfter):
-            transform = tvm.tirx.transform.StmtSimplify()
-
-            def before(A: T.Buffer(1, "int32")):
-                if True:
-                    A[0] = 42
-                else:
-                    A[0] = 5
-
-            def expected(A: T.Buffer(1, "int32")):
-                A[0] = 42
-
-    """
-
-    check_well_formed: bool = True
-
-    def __init_subclass__(cls):
-        assert len([getattr(cls, name) for name in ["before", "Before"] if hasattr(cls, name)]) <= 1
-        assert (
-            len([getattr(cls, name) for name in ["expected", "Expected"] if hasattr(cls, name)])
-            <= 1
-        )
-        for name in ["before", "Before"]:
-            if hasattr(cls, name):
-                cls.before = cls._normalize_before(getattr(cls, name))
-                break
-        for name in ["expected", "Expected"]:
-            if hasattr(cls, name):
-                cls.expected = cls._normalize_expected(getattr(cls, name))
-                break
-        if hasattr(cls, "transform"):
-            cls.transform = cls._normalize_transform(cls.transform)
-
-    @classmethod
-    def _normalize_ir_module(cls, func):
-        if isinstance(func, tvm.tirx.PrimFunc | tvm.IRModule):
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                return func
-
-        elif cls._is_method(func):
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                return func(self)
-
-        elif inspect.isclass(func):
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                func_dict = {}
-                for name, method in func.__dict__.items():
-                    if name.startswith("_"):
-                        pass
-                    elif isinstance(method, tvm.ir.function.BaseFunc):
-                        func_dict[name] = method.with_attr("global_symbol", name)
-                    else:
-                        source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(method))
-                        prim_func = tvm.script.from_source(
-                            source_code, check_well_formed=self.check_well_formed
-                        )
-                        func_dict[name] = prim_func.with_attr("global_symbol", name)
-                return tvm.IRModule(func_dict)
-
-        else:
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                source_code = "@T.prim_func\n" + textwrap.dedent(inspect.getsource(func))
-                return tvm.script.from_source(source_code, check_well_formed=self.check_well_formed)
-
-        return pytest.fixture(inner)
-
-    @classmethod
-    def _normalize_before(cls, func):
-        if hasattr(func, "_pytestfixturefunction"):
-            return func
-        else:
-            return cls._normalize_ir_module(func)
-
-    @classmethod
-    def _normalize_expected(cls, func):
-        if hasattr(func, "_pytestfixturefunction"):
-            return func
-
-        elif inspect.isclass(func) and issubclass(func, Exception):
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                return func
-
-            return pytest.fixture(inner)
-
-        else:
-            return cls._normalize_ir_module(func)
-
-    @classmethod
-    def _normalize_transform(cls, transform):
-        def apply(module_transform):
-            def inner(obj):
-                if isinstance(obj, tvm.IRModule):
-                    return module_transform(obj)
-                elif isinstance(obj, tvm.tirx.PrimFunc):
-                    mod = tvm.IRModule({"main": obj})
-                    mod = module_transform(mod)
-                    return mod["main"]
-                else:
-                    raise TypeError(f"Expected IRModule or PrimFunc, but received {type(obj)}")
-
-            return inner
-
-        if hasattr(transform, "_pytestfixturefunction"):
-            if not hasattr(cls, "_transform_orig"):
-                cls._transform_orig = transform
-
-            def inner(self, _transform_orig):
-                # pylint: disable=unused-argument
-                return apply(_transform_orig)
-
-        elif isinstance(transform, tvm.ir.transform.Pass):
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                return apply(transform)
-
-        elif cls._is_method(transform):
-
-            def inner(self):
-                # pylint: disable=unused-argument
-                return apply(transform(self))
-
-        else:
-            raise TypeError(
-                "Expected transform to be a tvm.ir.transform.Pass, or a method returning a Pass"
-            )
-
-        return pytest.fixture(inner)
-
-    @staticmethod
-    def _is_method(func):
-        return callable(func) and "self" in inspect.signature(func).parameters
-
-    def test_compare(self, before, expected, transform):
-        """Unit test to compare the expected TIR PrimFunc to actual"""
-
-        if inspect.isclass(expected) and issubclass(expected, Exception):
-            with pytest.raises(expected):
-                after = transform(before)
-
-                # This portion through pytest.fail isn't strictly
-                # necessary, but gives a better error message that
-                # includes the before/after.
-                before_str = before.script(name="before")
-                after_str = after.script(name="after")
-
-                pytest.fail(
-                    msg=(
-                        f"Expected {expected.__name__} to be raised from transformation, "
-                        f"instead received TIR\n:{before_str}\n{after_str}"
-                    )
-                )
-
-        elif isinstance(expected, tvm.tirx.PrimFunc | tvm.ir.IRModule):
-            after = transform(before)
-
-            try:
-                # overwrite global symbol so it doesn't come up in the comparison
-                if isinstance(after, tvm.tirx.PrimFunc):
-                    after = after.with_attr("global_symbol", "main")
-                    expected = expected.with_attr("global_symbol", "main")
-                tvm.ir.assert_structural_equal(after, expected)
-            except ValueError as err:
-                before_str = before.script(name="before")
-                after_str = after.script(name="after")
-                expected_str = expected.script(name="expected")
-                raise ValueError(
-                    f"TIR after transformation did not match expected:\n"
-                    f"{before_str}\n{after_str}\n{expected_str}"
-                ) from err
-
-        else:
-            raise TypeError(
-                f"tvm.testing.CompareBeforeAfter requires the `expected` fixture "
-                f"to return either `Exception`, an `Exception` subclass, "
-                f"or an instance of `tvm.tirx.PrimFunc`.  "
-                f"Instead, received {type(expected)}."
-            )
-
-
 ml_dtypes_dict = {
     "float8_e4m3fn": ml_dtypes.float8_e4m3fn,
     "float8_e5m2": ml_dtypes.float8_e5m2,

Reply via email to