This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit bb57b65fd8723cb82f7834a0b9974adaad093dca Author: Hongyi Jin <[email protected]> AuthorDate: Wed May 20 22:08:14 2026 -0700 fix(tirx/stmt_functor): add ScopeIdDefStmt to Python StmtFunctor dispatch (#637) The PR-#636 / kernels-#306 ``Tx.device_entry()`` flat-form migration introduced free-standing ``ScopeIdDefStmt`` nodes (``wg_id = Tx.warpgroup_id([N])``, ``warp_id = Tx.warp_id_in_wg([4])``, …) at the kernel-body top level. The C++ ``StmtVisitor`` / ``StmtMutator`` already handle them (visit extents via ``VisitExpr``; mutate via rebuild), but the Python ``StmtFunctor.__init__`` did not register ``"tirx.ScopeIdDefStmt"`` in its ``_dispatch_map``, so every Python visitor / mutator that walks a post-migration kernel falls through to ``visit_stmt_default_`` and fails with ``NotImplementedError: Do not have a default for ScopeIdDefStmt``. This surfaced concretely in cpusim, where ~47 visitor / mutator subclasses were affected; the fix belongs upstream rather than in the tools/cpusim-side workaround (monkey-patching ``StmtFunctor.__init__``). Changes: * ``stmt.py`` — register a Python ``ScopeIdDefStmt`` class so the FFI returns instances of it (instead of an auto-generated fallback). Re-export from ``tvm.tirx``. The C++ field is named ``def`` (a Python keyword), so access is ``getattr(stmt, "def")``. * ``stmt_functor.py`` — wire ``"tirx.ScopeIdDefStmt"`` into the ``StmtFunctor._dispatch_map`` and add ``visit_scope_id_def_stmt_`` with the same shape as the existing ``visit_*_`` methods: - Abstract on ``StmtFunctor`` (raises via ``visit_stmt_default_``). - Concrete on ``StmtVisitor`` — walk extents and preferred_extents via ``visit_expr`` (mirrors the C++ visitor). - Concrete on ``StmtMutator`` — walk extents and preferred_extents via ``visit_expr``; if any extent changed, rebuild ``ScopeIdDef`` (using the new ``_SCOPE_BINDING_TO_PARENT_CUR`` reverse map) and wrap in a fresh ``ScopeIdDefStmt`` (mirrors the C++ mutator).
* ``exec_scope.py`` — add ``_SCOPE_BINDING_TO_PARENT_CUR`` mirroring the C++ ``ScopeBinding`` enum, so the Python mutator can rebuild a ``ScopeIdDef`` from an existing one (whose ``scope`` field is the int form). --- python/tvm/tirx/__init__.py | 2 +- python/tvm/tirx/exec_scope.py | 19 ++++++++++++++ python/tvm/tirx/stmt.py | 35 +++++++++++++++++++++++++- python/tvm/tirx/stmt_functor.py | 55 +++++++++++++++++++++++++++++++++++++++++ 4 files changed, 109 insertions(+), 2 deletions(-) diff --git a/python/tvm/tirx/__init__.py b/python/tvm/tirx/__init__.py index 00a3522238..10de65a564 100644 --- a/python/tvm/tirx/__init__.py +++ b/python/tvm/tirx/__init__.py @@ -44,7 +44,7 @@ from .stmt import BufferStore, AllocBuffer, AttrStmt, DeclBuffer from .stmt import SeqStmt from .stmt import IfThenElse, Evaluate, stmt_seq, stmt_list from .stmt import BufferRegion, MatchBufferRegion, SBlock, SBlockRealize -from .stmt import TilePrimitiveCall, ExecScopeStmt +from .stmt import TilePrimitiveCall, ExecScopeStmt, ScopeIdDefStmt from .function import PrimFunc, TensorIntrin, IndexMap diff --git a/python/tvm/tirx/exec_scope.py b/python/tvm/tirx/exec_scope.py index 9a6e00eb6c..e63d6830df 100644 --- a/python/tvm/tirx/exec_scope.py +++ b/python/tvm/tirx/exec_scope.py @@ -65,6 +65,25 @@ _SCOPE_KIND_TO_NAME = { } +# Mirror of ``enum class ScopeBinding`` in tvm/tirx/exec_scope.h. Maps the +# ``int`` value of ``ScopeIdDef.scope`` back to the ``(parent, cur)`` pair +# that ``ScopeIdDef.__init__`` accepts — needed when Python code wants to +# rebuild a ``ScopeIdDef`` from an existing one (e.g. a StmtMutator +# walking and rewriting extents). 
+_SCOPE_BINDING_TO_PARENT_CUR = { + 0: ("kernel", "cluster"), + 1: ("kernel", "cta"), + 2: ("cluster", "cta"), + 3: ("cta", "warpgroup"), + 4: ("cta", "warp"), + 5: ("warpgroup", "warp"), + 6: ("warp", "thread"), + 7: ("cta", "thread"), + 8: ("warpgroup", "thread"), + 9: ("cluster", "cta_pair"), +} + + @register_object("tirx.ExecScope") class ExecScope(Object): """An execution scope, identified by one of {cluster, cta, warpgroup, warp, diff --git a/python/tvm/tirx/stmt.py b/python/tvm/tirx/stmt.py index f1072bf25a..4972c71518 100644 --- a/python/tvm/tirx/stmt.py +++ b/python/tvm/tirx/stmt.py @@ -39,7 +39,7 @@ from tvm.tirx import FloatImm from . import _ffi_api from .buffer import Buffer -from .exec_scope import ExecScope +from .exec_scope import ExecScope, ScopeIdDef from .expr import IterVar, StringImm, Var if TYPE_CHECKING: @@ -848,6 +848,39 @@ class ExecScopeStmt(Stmt): ) # type: ignore +@tvm_ffi.register_object("tirx.ScopeIdDefStmt") +class ScopeIdDefStmt(Stmt): + """ScopeIdDefStmt node. + + Leaf statement that introduces scope-identifier vars + (``wg_id = Tx.warpgroup_id([N])``, ``warp_id = Tx.warp_id_in_wg([4])``, + ``lane_id = Tx.lane_id([32])``, …) at the kernel-body top level. The + underlying ``ScopeIdDef`` carries the def vars, their extents, and + the parent/child scope binding. + + Note: the C++ field is named ``def`` (a Python keyword). Access it + via ``getattr(stmt, "def")`` or ``stmt.__getattribute__("def")`` — + the type-annotation alias here is purely for documentation. + + Parameters + ---------- + def_ : ScopeIdDef + The scope-id definition (def vars, extents, scope binding). + + span : Optional[Span] + The location of this statement in the source code. 
+ """ + + span: Span | None + + def __init__(self, def_: ScopeIdDef, span: Span | None = None) -> None: + self.__init_handle_by_constructor__( + _ffi_api.ScopeIdDefStmt, # type: ignore + def_, + span, + ) # type: ignore + + @tvm_ffi.register_object("tirx.Break") class Break(Stmt): """Break node. diff --git a/python/tvm/tirx/stmt_functor.py b/python/tvm/tirx/stmt_functor.py index 65c08921b9..c67032d4b0 100644 --- a/python/tvm/tirx/stmt_functor.py +++ b/python/tvm/tirx/stmt_functor.py @@ -54,6 +54,7 @@ class StmtFunctor: "tirx.SBlock": self.visit_block_, "tirx.SBlockRealize": self.visit_block_realize_, "tirx.ExecScopeStmt": self.visit_exec_scope_stmt_, + "tirx.ScopeIdDefStmt": self.visit_scope_id_def_stmt_, "tirx.TilePrimitiveCall": self.visit_op_call_, "tirx.AllocBuffer": self.visit_alloc_buffer_, } @@ -176,6 +177,10 @@ class StmtFunctor: """Visitor for ExecScopeStmt nodes.""" return self.visit_stmt_default_(op) + def visit_scope_id_def_stmt_(self, op): + """Visitor for ScopeIdDefStmt nodes.""" + return self.visit_stmt_default_(op) + def visit_op_call_(self, op): """Visitor for TilePrimitiveCall nodes.""" return self.visit_stmt_default_(op) @@ -338,6 +343,23 @@ class StmtVisitor(StmtFunctor): """Visitor implementation for ExecScopeStmt.""" self.visit_stmt(op.body) + def visit_scope_id_def_stmt_(self, op): + """Visitor implementation for ScopeIdDefStmt. + + Mirrors the C++ visitor: walk extents and preferred_extents via + ``visit_expr``; there is no body to recurse into (the def vars + themselves are leaves the visitor doesn't otherwise inspect). + """ + # The C++ field is named ``def``, which is a Python keyword, + # so it's accessed via ``getattr``. 
+ sid = getattr(op, "def") + if sid.extents is not None: + for e in sid.extents: + self.visit_expr(e) + if sid.preferred_extents is not None: + for e in sid.preferred_extents: + self.visit_expr(e) + def visit_op_call_(self, op): """Visitor implementation for TilePrimitiveCall.""" for arg in op.args: @@ -781,6 +803,39 @@ class StmtMutator(StmtFunctor): return tvm.tirx.ExecScopeStmt(op.exec_scope, body, op.span) + def visit_scope_id_def_stmt_(self, op): + """Mutator implementation for ScopeIdDefStmt. + + Mirrors the C++ mutator: rewrite ``extents`` and + ``preferred_extents`` via ``visit_expr``. Deferred-extent defs + (extents is None) and unchanged extents pass through. + """ + from .exec_scope import _SCOPE_BINDING_TO_PARENT_CUR, ScopeIdDef + + # ``def`` is a Python keyword; access the C++ field via ``getattr``. + sid = getattr(op, "def") + changed = False + + def _walk(arr): + nonlocal changed + if arr is None: + return None + out = [] + for e in arr: + ne = self.visit_expr(e) + if ne is not e: + changed = True + out.append(ne) + return out + + new_extents = _walk(sid.extents) + new_pref = _walk(sid.preferred_extents) + if not changed: + return op + parent, cur = _SCOPE_BINDING_TO_PARENT_CUR[sid.scope] + new_def = ScopeIdDef(sid.def_ids, new_extents, parent, cur, new_pref) + return tvm.tirx.ScopeIdDefStmt(new_def, op.span) + def visit_op_call_(self, op): """Mutator implementation for TilePrimitiveCall.""" new_args = []
