This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ce87b82cc48de16c72a60737a1d92d174c2a1989
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed May 27 22:36:21 2026 -0400

    feat(tirx): add M=128 dispatch + layout for .16x*b tcgen05.ld/st (#646)

    * feat(tirx): extend .16x*b tcgen05.ld/st dispatch to M=128 fragments

    The .16x*b atom natively covers M=64 (each warp handles 16 rows = the
    first half of its 32-lane TMEM partition). To cover all 128 TMEM rows
    of a warpgroup, callers previously had to either issue two raw
    ``Tx.ptx.tcgen05.ld(... shape="16x*b")`` calls themselves (with the
    second addr OR'd with 0x100000 = row+16) or alias the TMEM buffer with
    a shifted ``allocated_addr`` and call ``Tx.copy_async`` twice. Both
    leak the half-slab abstraction.

    This change adds M=128 support directly to the layout factory and the
    copy_async dispatch:

    - ``tcgen05_atom_layout(shape, (128, K), dtype)`` accepts ``rows=128``
      for the .16x*b shapes. The factory inserts a ``v_slab`` ``m`` iter at
      the next free m-bit (stride = M=64 per-thread reg count) and doubles
      wid_in_wg's row stride from 16 to 32. The result is a 128-row
      fragment whose per-thread register vector has the low-slab regs in
      [0, M64_regs) and the high-slab regs in [M64_regs, 2*M64_regs).
    - ``_emit_16xnb_path`` detects ``local_buf.shape[0] ∈ {64, 128}`` via
      the match helper and emits one PTX issue per 16-row slab (``row=0``
      for the low slab, ``row=16`` for the high slab), splitting the
      per-thread reg vector contiguously between the two calls.

    Bit-exact tests: the existing 92-case M=64 sweep continues to pass; an
    additional 18-case M=128 round-trip sweep (3 .16x*b shapes × 3 reps ×
    2 16-bit dtypes) verifies the new dispatch path.

    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

    * feat(tirx): plumb tcgen05 datapath through tmem_pool.alloc + dispatch

    The ``.16x*b`` dispatch's TMEM-side layout check was hard-coded to the
    identity Layout D (M=128, ``(128, K) : (1@TLane, 1@TCol)``).
    The PTX ``tcgen05.ld.16x*b`` with ``row=0`` actually accesses the
    *scattered* lane set ``{0..15, 32..47, 64..79, 96..111}`` (Layout F per
    PTX ISA §9.7.16.10.5), so passing the dispatch a (128, K) Layout D
    buffer with slice ``[0:64, ...]`` silently returns scattered data when
    the caller's mental model is "rows 0..63 contiguous". The leak only
    worked because all in-tree kernels happen to use M=128 MMA + .16x*b
    M=128 readback, where the scatter doesn't bite.

    This change makes the datapath part of the TMEM buffer's TileLayout:

    - ``tmem_datapath_layout(datapath, rows, cols)`` returns the TileLayout
      for Layout D (M=128 identity) and Layout F (M=64 scatter via two row
      iters with TLane strides 1 and 32). Other PTX layouts (A, B, C, E, G)
      are reserved for future expansion.
    - ``tmem_pool.alloc`` accepts a ``datapath=`` kwarg that derives the
      buffer's layout from ``tmem_datapath_layout``. ``layout=`` and
      ``datapath=`` are mutually exclusive; omitting both falls back to the
      permissive Layout D default.
    - The ``.16x*b`` and ``.32x32b`` dispatches now structurally classify
      the TMEM buffer (``_classify_tmem_datapath``) and reject mismatched
      pairings via ``_check_tmem_layout_for_atom`` against the table:
      - D × .32x32b: ✓
      - D × .16x*b M=64: ✓ (legacy permissive, half-slab read)
      - D × .16x*b M=128: ✓ (two PTX issues)
      - F × .16x*b M=64: ✓ (canonical pairing)
      - F × .16x*b M=128: ✗
      - F × .32x32b: ✗

    Tests: 304 passed (was 266) including 36 new Layout F round-trips and
    two negative tests for F × .16x*b M=128 and F × .32x32b. The existing
    M=64 sweep against Layout D buffers remains permissive (no test change
    needed) since that's the legacy contract.

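The Layout F scatter and the pairing table described above can be sketched in plain Python. This is only an illustration under stated assumptions: the names ``layout_f_lane``, ``COMPAT`` and ``is_compatible`` are hypothetical and are not part of ``tvm.tirx``; the row-to-lane formula ``(r // 16) * 32 + (r % 16)`` and the table entries are taken from this commit's description.

```python
# Standalone sketch; helper names are hypothetical, not the tvm.tirx API.


def layout_f_lane(row: int) -> int:
    """Map a logical row of an M=64 Layout F buffer to a physical TMEM lane."""
    if not 0 <= row < 64:
        raise ValueError(f"Layout F has 64 logical rows, got row={row}")
    warp, intra = divmod(row, 16)  # warp selector, lane within the 16-row slab
    return warp * 32 + intra       # each warp owns a 32-lane TMEM partition


# Pairing table copied from the commit message, keyed by
# (datapath, atom kind, register-side fragment rows).
COMPAT = {
    ("D", "32x32b", 128): True,
    ("D", "16x*b", 64): True,
    ("D", "16x*b", 128): True,
    ("F", "16x*b", 64): True,
    ("F", "16x*b", 128): False,
    ("F", "32x32b", 128): False,
}


def is_compatible(datapath: str, atom: str, frag_rows: int) -> bool:
    """Look up a pairing; combinations not in the table count as incompatible."""
    return COMPAT.get((datapath, atom, frag_rows), False)


if __name__ == "__main__":
    lanes = sorted(layout_f_lane(r) for r in range(64))
    expected = [base + i for base in (0, 32, 64, 96) for i in range(16)]
    assert lanes == expected  # {0..15, 32..47, 64..79, 96..111}
    assert is_compatible("F", "16x*b", 64)
    assert not is_compatible("F", "32x32b", 128)
```

The sketch only models the logical-row to lane bookkeeping; the real dispatch compares canonicalized ``TileLayout`` objects structurally rather than consulting a formula.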
    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

    * fix(tirx): correct Layout F iter ordering for SplitCoord semantics

    The first version of ``tmem_datapath_layout("F", ...)`` encoded the
    scattered M=64 mapping as
    ``S[(16, 4, cols) : (1@TLane, 32@TLane, 1@TCol)]``, under the mistaken
    assumption (taken from the comments in ``tcgen05_atom_layout``) that
    the *first* iter takes the *lowest* row bits. ``SplitCoord``
    (src/tirx/ir/layout/utils.cc) actually walks the shape from the last
    dim back to the first, so for shape ``(s0, s1)`` the first iter
    receives ``coord // s1`` (the *high* bits) and the second receives
    ``coord % s1`` (the low bits), i.e. row-major flattening.

    With the wrong ordering, ``apply(r, c)`` decomposed ``r`` as
    ``(r // 4, r % 4)`` and mapped to
    ``TLane = (r // 4) * 1 + (r % 4) * 32``, which doesn't match the
    canonical scatter ``(r // 16) * 32 + (r % 16)`` that ``.16x*b`` M=64
    PTX accesses. The round-trip test didn't catch it because both write
    and read use the same factory output, so the dispatch's structural
    compatibility check just sees "two layouts that match each other"; the
    PTX behavior is fixed by hardware and the layout label is essentially
    decorative for this dispatch path.

    This change:

    - Re-orders the iter list to
      ``S[(4, 16, cols) : (32@TLane, 1@TLane, 1@TCol)]`` so that
      ``SplitCoord(r, (4, 16))`` returns ``[r // 16, r % 16]`` and the warp
      selector (iter 0) gets the TLane stride 32 while the within-slab lane
      (iter 1) gets TLane stride 1. ``layout.apply(r, c)`` now returns
      ``TLane = (r // 16) * 32 + (r % 16)``, matching the
      ``_frag_row_to_tmem_lane`` formula used by the existing
      ``_run_load_test`` host-side validation.
    - Updates the factory's docstring with the row-major SplitCoord
      ordering so future readers don't have to re-derive it from the C++.
    - Adds two unit tests
      (``test_tmem_datapath_layout_F_row_to_lane_mapping``,
      ``test_tmem_datapath_layout_D_row_to_lane_mapping``) that call
      ``layout.apply(r, c, shape=...)`` directly and assert the
      ``(row, col) → (TLane, TCol)`` mapping bit-exactly across every M=64
      logical row. These are the right discriminator between the two
      possible iter orderings.

    Tests: 132 passed (was 130; +2 layout unit tests).

    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>

    * fix(tirx): align tcgen05_atom_layout with SplitCoord semantics

    The per-shape ``row_iters_fp32`` and ``col_iters_fp32`` lists in
    ``tcgen05_atom_layout`` were written low-to-high to match the natural
    mathematical decomposition (e.g. ``.16x256b``'s
    ``row = t1 + 8*va + 16*wid_in_wg``: t1 is the LOW iter in the list,
    wid is the HIGH). ``TileLayout`` decomposes a flat coord via
    ``SplitCoord`` (src/tirx/ir/layout/utils.cc), which uses row-major
    ordering: the FIRST iter receives
    ``coord // (product of remaining extents)`` (i.e. the *high* bits), and
    the LAST iter receives ``coord % extent`` (the *low* bits).

    As a result, the produced TileLayout's ``apply(row, col)`` returned
    axis values that *disagreed* with PTX. For example, for ``.16x256b``
    M=64 N=1, the factory mapped ``(row=1, col=0)`` to
    ``(laneid=0, m=0, wid=1)``, but PTX puts that frag element at
    ``(laneid=4, m=0, wid=0)`` (thread 4 reg 0, per ``_decompose_fp32``).
    The round-trip tests didn't catch it because the dispatch ignores the
    layout label and emits raw PTX; the layout was essentially decorative.

    This change:

    - Adds an explicit reversal of ``row_iters`` and ``col_iters`` (both
      individually, after ``_scale``) just before constructing the final
      ``Iter`` list, so the FIRST iter in each dim receives the HIGH bits
      per ``SplitCoord``. The C_pack iter for 16-bit dtypes is now appended
      at the LOW end of the col axis (last iter), matching its
      per-thread-pack-bit semantics.
    - Adds ``test_tcgen05_atom_layout_apply_matches_decompose_fp32``, a
      9-case sweep (3 ``.16x*b`` shapes × 3 reps) that walks every
      (thread, reg) combination, derives the expected (row, col) via
      ``_decompose_fp32``, and asserts ``layout.apply(row, col)`` returns
      the matching (laneid, wid_in_wg, m) triple. This pins down the layout
      so it can't silently drift again; round-trip alone wouldn't catch a
      reversal bug.

    Tests: 315 passed in copy_async/ suite (was 304; +9 atom-layout apply
    sweep, +2 datapath layout apply unit tests added in the prior commit).
    Existing M=64 / M=128 / Layout F bit-exact round-trips unchanged.
---
 3rdparty/cutlass_fpA_intB_gemm                     |   2 +-
 3rdparty/tvm-ffi                                   |   2 +-
 python/tvm/tirx/lang/alloc_pool.py                 |  30 ++-
 python/tvm/tirx/layout.py                          | 165 ++++++++++++++--
 .../tile_primitive/cuda/copy_async/tcgen05_ldst.py | 211 +++++++++++++++++----
 .../cuda/copy_async/test_tmem_16xnb.py             | 197 ++++++++++++++++++-
 6 files changed, 541 insertions(+), 66 deletions(-)

diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index 953121f189..72b9883c98 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit 953121f18946cedf88c2ccb6439944956ad495a8
+Subproject commit 72b9883c986a2ff427ca61ac0b14ad59be1dc862
diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi
index 3c35034fd1..1fed0ae042 160000
--- a/3rdparty/tvm-ffi
+++ b/3rdparty/tvm-ffi
@@ -1 +1 @@
-Subproject commit 3c35034fd1026011736e19a4e0e1ed0f22058c42
+Subproject commit 1fed0ae0421e614d45662e8ee6bcae353d3ab2ea
diff --git a/python/tvm/tirx/lang/alloc_pool.py b/python/tvm/tirx/lang/alloc_pool.py
index 3a9ae82b30..e7996ff432 100644
--- a/python/tvm/tirx/lang/alloc_pool.py
+++ b/python/tvm/tirx/lang/alloc_pool.py
@@ -300,7 +300,35 @@ class TMEMPool:
             )
         return total_bits // (32 * rows)
 
-    def alloc(self, shape, dtype="float32", *, layout=None, cols=None):
+    def alloc(self, shape, dtype="float32", *, layout=None, cols=None, datapath=None):
+
"""Allocate a TMEM buffer. + + Parameters + ---------- + shape, dtype, cols + Standard buffer shape / dtype / column count. + layout + Explicit ``TileLayout``. Mutually exclusive with ``datapath``. + datapath : str | None + Optional tcgen05 datapath letter (``"D"`` for M=128 full datapath, + ``"F"`` for M=64 non-``.ws`` scattered). When provided, the buffer's + layout is derived from ``tmem_datapath_layout(datapath, *shape)`` + so the row index reflects the *physical* TMEM lane occupation + (PTX ISA §9.7.16.10.5). The downstream ``.16x*b`` / ``.32x32b`` + dispatches structurally check this layout to catch mismatched + atoms (e.g. a ``.16x*b`` M=128 read against a Layout F buffer). + Defaults to ``None``, which means Layout D's identity row→lane + mapping — keep this for shape ``(128, X)`` buffers that hold + an M=128 MMA accumulator. + """ + from tvm.tirx.layout import tmem_datapath_layout + + if layout is not None and datapath is not None: + raise ValueError("TMEMPool.alloc: pass at most one of layout= and datapath=") + if datapath is not None: + assert len(shape) == 2, "TMEMPool.alloc: datapath= requires a 2-D shape" + layout = tmem_datapath_layout(datapath, shape[0], shape[1]) + ir = _get_ir() cols = self._resolve_cols(shape, dtype, cols, layout) col_start = self.offset diff --git a/python/tvm/tirx/layout.py b/python/tvm/tirx/layout.py index e17a0d61f8..c3cc25748b 100644 --- a/python/tvm/tirx/layout.py +++ b/python/tvm/tirx/layout.py @@ -566,7 +566,94 @@ except NameError: # pragma: no cover __all__ = [] # type: ignore[var-annotated] __all__ += list(_AXIS_NAMES) __all__ += ["R", "S"] -__all__ += ["wg_local_layout", "tcgen05_atom_layout"] +__all__ += ["wg_local_layout", "tcgen05_atom_layout", "tmem_datapath_layout"] + + +# ============================================================================ +# TMEM datapath layouts (PTX ISA §9.7.16.10.5) +# ============================================================================ +# +# ``tcgen05.mma`` writes its output 
matrix C into TMEM using one of several +# **datapath layouts** depending on the MMA's M dimension and ``.ws`` mode. +# Each layout determines *which* physical TMEM lanes (rows) the matrix +# occupies; the leak in the original ``_default_tmem_layout`` was that it +# always used the identity ``(rows, cols) : (1@TLane, 1@TCol)`` mapping, +# which is correct only for Layout D (M=128 full datapath). For Layout F +# (M=64 non-``.ws``) the MMA writes scattered lanes +# ``{0..15, 32..47, 64..79, 96..111}`` — half of each warp's 32-lane +# partition — and the readback path (``.16x*b`` M=64 atom) has the matching +# scatter built into the PTX. To keep the buffer's logical row indexing in +# sync with the physical scatter, the buffer's TileLayout must encode the +# scatter directly. +# +# We surface this via the factory below. Callers pass the datapath letter +# (``"D"`` / ``"F"``) and the logical ``(rows, cols)``; the factory returns +# the appropriate TileLayout. ``tmem_pool.alloc(..., datapath="F")`` plumbs +# this into the buffer's layout so the dispatch can structurally verify +# atom ↔ datapath compatibility instead of silently accepting mismatches. +# +# Supported today: +# - ``"D"``: M=128, ``.cta_group::1``, full datapath. Identity row→lane. +# - ``"F"``: M=64, non-``.ws``, half datapath (4×1 lane utilization). +# Logical row r → physical lane (r // 16) * 32 + (r % 16). +# +# Layouts A / B / C / E / G are reserved for future expansion. + + +_TMEM_DATAPATH_ROWS = {"D": 128, "F": 64} + + +def tmem_datapath_layout(datapath: str, rows: int, cols: int) -> "TileLayout": + """Return the ``TileLayout`` for a tcgen05 MMA datapath. + + See PTX ISA §9.7.16.10.5 for the datapath enumeration. The returned + layout is shape-compatible with a buffer of ``(rows, cols)`` and + encodes the logical-row → physical-TMEM-lane mapping that the + corresponding MMA writes to (and that the matching ``.16x*b`` / + ``.32x32b`` atom expects to read). 
+ + Parameters + ---------- + datapath : str + One of ``"D"`` (M=128, ``.cta_group::1``, full datapath) or + ``"F"`` (M=64, non-``.ws``, half datapath). Other layouts are not + yet supported by this factory. + rows : int + Logical row count of the TMEM buffer. Must match the datapath's M + dimension: 128 for D, 64 for F. + cols : int + Logical column count. + + Returns + ------- + TileLayout + Buffer-shape-compatible layout for ``(rows, cols)``. + """ + if datapath not in _TMEM_DATAPATH_ROWS: + raise ValueError( + f"tmem_datapath_layout: unknown datapath {datapath!r}; " + f"supported: {sorted(_TMEM_DATAPATH_ROWS)}" + ) + expected = _TMEM_DATAPATH_ROWS[datapath] + if rows != expected: + raise ValueError( + f"tmem_datapath_layout: datapath={datapath!r} expects rows={expected}, got {rows}" + ) + tlane = Axis.get("TLane") + tcol = Axis.get("TCol") + if datapath == "D": + # M=128, identity row→lane: row r ∈ [0, 128) → physical lane r. + return TileLayout(S[(rows, cols) : (1 @ tlane, 1 @ tcol)]) + # Layout F: M=64 scattered. Logical row r = wid * 16 + intra (wid ∈ [0,4), + # intra ∈ [0,16)) → physical lane wid * 32 + intra, i.e. + # ``r // 16`` is the warp selector and ``r % 16`` is the within-slab lane. + # ``TileLayout`` decomposes a scalar row index via ``SplitCoord`` + # (src/tirx/ir/layout/utils.cc), which uses row-major ordering: with + # shape ``(s0, s1)`` the FIRST iter receives ``coord // s1`` (the high + # bits) and the SECOND receives ``coord % s1`` (the low bits). So we + # pin the warp selector to iter 0 (extent 4, TLane stride 32) and the + # within-slab lane to iter 1 (extent 16, TLane stride 1). + return TileLayout(S[(4, 16, cols) : (32 @ tlane, 1 @ tlane, 1 @ tcol)]) def wg_local_layout(cols, rows=128): @@ -593,8 +680,19 @@ _TCGEN05_ATOM_REPS = { # fragment is 128 rows × (factor * rep) fp32 cols with factor=1. _TCGEN05_COL_FACTOR_FP32 = {"32x32b": 1, "16x64b": 2, "16x128b": 4, "16x256b": 8} -# Number of fragment rows per warpgroup for each instr_shape. 
-_TCGEN05_FRAG_ROWS = {"32x32b": 128, "16x64b": 64, "16x128b": 64, "16x256b": 64} +# Allowed fragment row counts per warpgroup for each instr_shape. ``.32x32b`` +# is fixed at M=128; ``.16x*b`` natively covers M=64 (one 16-row slab per +# warp, using lanes 0..15 of each warp's 32-lane TMEM partition) and can be +# extended to M=128 by issuing the atom twice with row offsets 0 and 16 +# (covering lanes 0..15 + 16..31, i.e. the warp's full slab). The M=128 +# variant doubles per-thread registers and treats the extra slab as the +# highest m-bit. +_TCGEN05_FRAG_ROWS = { + "32x32b": (128,), + "16x64b": (64, 128), + "16x128b": (64, 128), + "16x256b": (64, 128), +} def tcgen05_atom_layout( @@ -666,10 +764,10 @@ def tcgen05_atom_layout( f"tcgen05_atom_layout tensor_shape must be 2-D (rows, cols), got {tensor_shape!r}" ) rows, cols = tensor_shape - expected_rows = _TCGEN05_FRAG_ROWS[instr_shape] - if rows != expected_rows: + allowed_rows = _TCGEN05_FRAG_ROWS[instr_shape] + if rows not in allowed_rows: raise ValueError( - f"tcgen05_atom_layout {instr_shape!r} expects rows={expected_rows}, got {rows}" + f"tcgen05_atom_layout {instr_shape!r} expects rows ∈ {allowed_rows}, got {rows}" ) elem_per_32b = 32 // bits @@ -710,20 +808,27 @@ def tcgen05_atom_layout( ] return TileLayout.from_iters(iters, [], {}) + # Iter lists are written high-to-low: ``TileLayout`` decomposes a flat + # coordinate via ``SplitCoord`` (src/tirx/ir/layout/utils.cc) using + # row-major ordering, where the FIRST iter receives the *high* bits and + # the LAST iter receives the *low* bits. So R_w (highest-stride row + # contribution) comes first in row_iters_fp32 and R_t1/t2 (lowest) + # comes last; same for col. if shape == "16x64b": # Per-warp tile (fp32 view): (16 rows, 2N cols). Per-lane regs = N. # Lane (t0, t1, t2): t0 = laneid & 1, t1 = (laneid >> 1) & 1, t2 = laneid >> 2. 
# Row = t2 + 8*t0 + 16*wid_in_wg # Col (fp32) = t1 + 2*r, r ∈ [0, N) row_iters_fp32 = [ - (8, 4, laneid), # R_t2: laneid bits 2..4 → R bits 0..2 - (2, 1, laneid), # R_t0: laneid bit 0 → R bit 3 (4, 1, wid), # R_w: wid_in_wg → R bits 4..5 + (2, 1, laneid), # R_t0: laneid bit 0 → R bit 3 + (8, 4, laneid), # R_t2: laneid bits 2..4 → R bits 0..2 ] col_iters_fp32 = [ - (2, 2, laneid), # C_t1: laneid bit 1 → C bit 0 (N, 1, "m"), # C_r: register slot → C bits 1.. + (2, 2, laneid), # C_t1: laneid bit 1 → C bit 0 ] + m_used_M64 = N elif shape == "16x128b": # Per-warp tile (fp32 view): (16 rows, 4N cols). Per-lane regs = 2N. # Lane (t0, t1): t0 = laneid & 3, t1 = laneid >> 2. @@ -731,29 +836,50 @@ def tcgen05_atom_layout( # Row = t1 + 8*ra + 16*wid_in_wg # Col (fp32) = t0 + 4*rb row_iters_fp32 = [ - (8, 4, laneid), # R_t1: laneid bits 2..4 → R bits 0..2 - (2, 1, "m"), # R_ra: reg bit 0 → R bit 3 (4, 1, wid), # R_w + (2, 1, "m"), # R_ra: reg bit 0 → R bit 3 + (8, 4, laneid), # R_t1: laneid bits 2..4 → R bits 0..2 ] col_iters_fp32 = [ - (4, 1, laneid), # C_t0: laneid bits 0..1 → C bits 0..1 (N, 2, "m"), # C_rb: reg bits 1.. → C bits 2.. + (4, 1, laneid), # C_t0: laneid bits 0..1 → C bits 0..1 ] + m_used_M64 = 2 * N else: # 16x256b # Per-warp tile (fp32 view): (16 rows, 8N cols). Per-lane regs = 4N. # Lane (t0, t1) as for 16x128b. Reg r = v0p + 2*va + 4*vb. # Row = t1 + 8*va + 16*wid_in_wg # Col (fp32) = v0p + 2*t0 + 8*vb row_iters_fp32 = [ - (8, 4, laneid), # R_t1 - (2, 2, "m"), # R_va: reg bit 1 → R bit 3 (4, 1, wid), # R_w + (2, 2, "m"), # R_va: reg bit 1 → R bit 3 + (8, 4, laneid), # R_t1 ] col_iters_fp32 = [ - (2, 1, "m"), # C_v0p: reg bit 0 → C bit 0 - (4, 1, laneid), # C_t0 (N, 4, "m"), # C_vb: reg bits 2.. → C bits 3.. 
+ (4, 1, laneid), # C_t0 + (2, 1, "m"), # C_v0p: reg bit 0 → C bit 0 ] + m_used_M64 = 4 * N + + if rows == 128: + # M=128 covers both 16-row half-slabs of each warp's 32-lane TMEM + # partition (the M=64 atom covers only lanes 0..15; the high half + # 16..31 needs a second PTX issue with row offset 16). We surface + # the combined fragment as a single (128, K) tile by inserting a + # v_slab iter right *after* R_w (i.e. as the next-highest row bit). + # v_slab claims one m-bit at the next free offset + # (stride = m_used_M64) so reg indices [0, m_used_M64) hold the + # low slab and [m_used_M64, 2*m_used_M64) hold the high slab — the + # split the dispatch uses when emitting the two PTX calls. The + # inserted iter also doubles wid_in_wg's row stride from 16 to 32, + # so the four warps now tile rows 0..31 / 32..63 / 64..95 / 96..127. + new_row_iters = [] + for ext, stride, axis in row_iters_fp32: + new_row_iters.append((ext, stride, axis)) + if axis is wid: + new_row_iters.append((2, m_used_M64, "m")) + row_iters_fp32 = new_row_iters def _scale(iters): out = [] @@ -768,10 +894,11 @@ def tcgen05_atom_layout( col_iters = _scale(col_iters_fp32) # For the 16-bit packed variant each fp32 register holds two adjacent - # column elements (low / high halves), so we prepend a C_pack iter of - # extent ``elem_per_32b`` and m-stride 1 to the column group. + # column elements (low / high halves). Add a C_pack iter of extent + # ``elem_per_32b`` and m-stride 1 at the *low* end of the col axis — + # i.e. as the LAST col iter under SplitCoord's high-to-low ordering. 
if elem_per_32b > 1: - col_iters = [(elem_per_32b, 1, "m"), *col_iters] + col_iters.append((elem_per_32b, 1, "m")) iters = [Iter(ext, stride, axis) for ext, stride, axis in row_iters + col_iters] return TileLayout.from_iters(iters, [], {}) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py index 345eef4519..0c958dde3e 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py @@ -27,7 +27,15 @@ from tvm.arith import Analyzer from tvm.runtime import DataType from tvm.script import tirx as Tx from tvm.tirx import Buffer, PrimFunc -from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout, tid_in_wg +from tvm.tirx.layout import ( + S, + TCol, + TileLayout, + TLane, + tcgen05_atom_layout, + tid_in_wg, + tmem_datapath_layout, +) from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch from tvm.tirx.stmt import TilePrimitiveCall @@ -42,17 +50,18 @@ _TCGEN05_COL_FACTOR_FP32 = {"16x64b": 2, "16x128b": 4, "16x256b": 8} def _match_tcgen05_atom_layout(buf): - """Return ``(instr_shape, rep)`` if ``buf.layout`` matches a tcgen05 - ``.16x*b`` atom layout for some supported ``instr_shape``. + """Return ``(instr_shape, rep, frag_rows)`` if ``buf.layout`` matches a + tcgen05 ``.16x*b`` atom layout for some supported ``instr_shape``. - The local buffer shape ``(64, K)`` together with the dtype determines the - candidate ``rep`` for each ``instr_shape``; we just probe the three shapes - and structurally compare. ``None`` if no atom layout matches. + The local buffer shape ``(frag_rows, K)`` (``frag_rows`` ∈ {64, 128}) + together with the dtype determines the candidate ``rep`` for each + ``instr_shape``; we just probe the three shapes × two frag_rows and + structurally compare. ``None`` if no atom layout matches. 
""" if len(buf.shape) != 2: return None rows, cols = int(buf.shape[0]), int(buf.shape[1]) - if rows != 64: + if rows not in (64, 128): return None dtype = buf.dtype layout_c = buf.layout.canonicalize() @@ -68,10 +77,95 @@ def _match_tcgen05_atom_layout(buf): # Recover rep from cols (same arithmetic the factory uses). elem_per_32b = 32 // DataType(dtype).bits rep = cols // (_TCGEN05_COL_FACTOR_FP32[shape] * elem_per_32b) - return shape, rep + return shape, rep, rows + return None + + +def _classify_tmem_datapath(tmem_buf): + """Return ``"D"`` / ``"F"`` if ``tmem_buf.layout`` matches a known tcgen05 + datapath (PTX ISA §9.7.16.10.5), else ``None``. + + Layout D (M=128, identity row→lane) is the default returned by + ``_default_tmem_layout``. Layout F (M=64 non-``.ws``, scattered) is the + explicit opt-in produced by ``tmem_pool.alloc(..., datapath="F")``. + The dispatch uses this to pair each ``.16x*b`` / ``.32x32b`` atom with a + compatible layout — see ``_check_tmem_layout_for_atom``. + """ + if tmem_buf.layout is None: + return None + buf_layout = tmem_buf.layout.canonicalize() + rows = int(tmem_buf.shape[0]) + if rows == 128: + cand = tmem_datapath_layout("D", 128, tmem_buf.shape[1]).canonicalize() + try: + tvm.ir.assert_structural_equal(buf_layout, cand) + return "D" + except (AssertionError, ValueError): + return None + if rows == 64: + cand = tmem_datapath_layout("F", 64, tmem_buf.shape[1]).canonicalize() + try: + tvm.ir.assert_structural_equal(buf_layout, cand) + return "F" + except (AssertionError, ValueError): + return None return None +# Compatibility matrix between the TMEM buffer's datapath layout and the +# tcgen05 ld/st atom requested by ``Tx.copy_async``: +# +# datapath × atom | accepted? 
| rationale +# ---------------------------- | --------- | -------------------------------- +# D (M=128 full) × .32x32b | yes | full 128 lanes, all 32 per warp +# D (M=128 full) × .16x*b M=64| yes | reads first half-slab (lanes +# | | 0..15 of each warp partition) +# | | — the rest of acc is wasted +# | | for this atom but valid data +# D (M=128 full) × .16x*b M=128| yes | reads all 128 lanes via row=0 +# | | and row=16 PTX issues +# F (M=64 scatter)× .16x*b M=64| yes | canonical pairing — F's row +# | | indexing matches the atom's +# | | scatter access +# F (M=64 scatter)× .16x*b M=128| no | F only writes the low slab; the +# | | high slab (row=16) is garbage +# F (M=64 scatter)× .32x32b | no | F only utilizes 16 of each +# | | warp's 32 lanes +_TMEM_ATOM_COMPAT = { + ("D", "32x32b", 128): True, + ("D", "16x*b", 64): True, + ("D", "16x*b", 128): True, + ("F", "32x32b", 128): False, + ("F", "16x*b", 64): True, + ("F", "16x*b", 128): False, +} + + +def _check_tmem_layout_for_atom(tmem_buf, atom_kind, frag_rows): + """Raise ``ValueError`` if the TMEM buffer's datapath layout is + incompatible with the requested ``tcgen05`` atom. + + ``atom_kind`` is ``"32x32b"`` or ``"16x*b"``; ``frag_rows`` is the + register-side fragment row count (128 for ``.32x32b`` and ``.16x*b`` + M=128 variants, 64 for ``.16x*b`` M=64). If the buffer's layout is + unrecognized (i.e. it isn't Layout D or Layout F), the dispatch falls + back to the structural assertions below. + """ + datapath = _classify_tmem_datapath(tmem_buf) + if datapath is None: + return None + allowed = _TMEM_ATOM_COMPAT.get((datapath, atom_kind, frag_rows), False) + if not allowed: + raise ValueError( + f"tcgen05 dispatch: TMEM buffer with datapath={datapath!r} is " + f"incompatible with atom={atom_kind!r} (frag_rows={frag_rows}). " + f"See PTX ISA §9.7.16.10.5 for datapath/atom pairings; the " + f"buffer was allocated via tmem_pool.alloc(..., " + f"datapath={datapath!r})." 
+ ) + return datapath + + def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None: op_call = TilePrimitiveCall.downcast(op_call) dst_buffer_region, src_buffer_region = op_call.dst, op_call.src @@ -108,10 +202,11 @@ def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> P atom_match = _match_tcgen05_atom_layout(local_buf) if atom_match is not None: - shape, num = atom_match + shape, num, frag_rows = atom_match return _emit_16xnb_path( shape=shape, num=num, + frag_rows=frag_rows, direction=direction, tmem_buf=tmem_buf, local_buf=local_buf, @@ -138,6 +233,9 @@ def _emit_32x32b_path( ) -> PrimFunc: """Original M=128 fragment path using ``tcgen05.{ld,st}.32x32b.xN``.""" # local: 128xWIDTH <-> tmem: 128xSHAPE[1] + # ``.32x32b`` accesses 32 lanes per warp — the full warp partition — so + # the TMEM buffer must be Layout D (M=128 full datapath). Reject Layout F. + _check_tmem_layout_for_atom(tmem_buf, "32x32b", 128) assert analyzer.can_prove_equal(local_buf.shape[0], 128) assert analyzer.can_prove_equal(tmem_buf.shape[0], 128) @@ -198,6 +296,7 @@ def _emit_16xnb_path( *, shape, num, + frag_rows, direction, tmem_buf, local_buf, @@ -206,55 +305,86 @@ def _emit_16xnb_path( elem_per_32b, analyzer, ) -> PrimFunc: - """M=64 fragment path using ``tcgen05.{ld,st}.<shape>.x<num>`` (one of - ``.16x64b``, ``.16x128b``, ``.16x256b``). - - Each of the warpgroup's 4 warps issues the atom once with - ``row_offset=0``; the PTX TMEM access restriction places warp ``i`` on - TMEM lanes ``i*32..i*32+31``, of which the atom uses the first 16 to - cover one 16-row slab of the 64-row fragment. Collectively, the four - warps cover all 64 rows. + """``.16x*b`` fragment path using ``tcgen05.{ld,st}.<shape>.x<num>`` (one + of ``.16x64b``, ``.16x128b``, ``.16x256b``). 
+ + Each of the warpgroup's 4 warps issues the atom with ``row_offset=0`` to + cover lanes 0..15 of its 32-lane TMEM partition (one 16-row slab); the + four warps collectively span M=64 rows. When ``frag_rows == 128`` the + dispatch emits a second issue with ``row_offset=16`` to also cover lanes + 16..31 of each warp's partition, doubling the fragment's row coverage to + M=128. The two atoms share the same column footprint; the layout factory + surfaces the combined per-thread register vector with the second slab's + regs in the high half of the m-axis (so the dispatch can split regs + contiguously between the two PTX calls). """ # Per-atom column footprint in fp32 columns: # .16x64b → 2N .16x128b → 4N .16x256b → 8N col_factor_fp32 = {"16x64b": 2, "16x128b": 4, "16x256b": 8}[shape] - # Per-thread register count (in 32-bit units): + # Per-thread register count per 16-row slab (in 32-bit units): # .16x64b.xN → N .16x128b.xN → 2N .16x256b.xN → 4N - regs_per_thread = {"16x64b": num, "16x128b": 2 * num, "16x256b": 4 * num}[shape] + regs_per_thread_per_slab = {"16x64b": num, "16x128b": 2 * num, "16x256b": 4 * num}[shape] + n_slabs = frag_rows // 64 # 1 for M=64, 2 for M=128 + assert n_slabs in (1, 2) + regs_per_thread = regs_per_thread_per_slab * n_slabs # Logical column width that the local buffer view exposes (in element units). width_elems = col_factor_fp32 * num * elem_per_32b # Per-thread storage in element units (same total bits as the register vector). 
per_thread_elems = regs_per_thread * elem_per_32b - # Local-side: shape (64, K_cols) - assert analyzer.can_prove_equal(local_buf.shape[0], 64), ( - f".16x*b path expects local_buf rows=64, got {local_buf.shape[0]}" + # Local-side: shape (frag_rows, K_cols) + assert analyzer.can_prove_equal(local_buf.shape[0], frag_rows), ( + f".16x*b path expects local_buf rows={frag_rows}, got {local_buf.shape[0]}" ) assert analyzer.can_prove_equal(local_buf.shape[1], width_elems), ( f".16x*b path expects local_buf cols={width_elems}, got {local_buf.shape[1]}" ) - # TMEM-side: shape (128, W); the M=64 fragment occupies the first 16 lanes of - # each warp's 32-lane slab. - assert analyzer.can_prove_equal(tmem_buf.shape[0], 128), ( - f".16x*b path expects tmem_buf rows=128, got {tmem_buf.shape[0]}" - ) - tmem_layout = TileLayout(S[(128, tmem_buf.shape[1]) : (1 @ TLane, 1 @ TCol)]).canonicalize() - tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), tmem_layout) + # TMEM-side: structurally classify the buffer's datapath (D or F) and + # reject incompatible pairings. The PTX is identical in either case (the + # warp partition rule and the atom's lane access pattern are baked into + # the hardware); the layout classification just keeps the buffer's + # logical row indexing in sync with the physical TMEM occupation. + datapath = _check_tmem_layout_for_atom(tmem_buf, "16x*b", frag_rows) + + if datapath == "F": + # Layout F: buffer shape (64, W), scattered row→lane. + assert analyzer.can_prove_equal(tmem_buf.shape[0], 64), ( + f".16x*b Layout F expects tmem_buf rows=64, got {tmem_buf.shape[0]}" + ) + tmem_rows = 64 + else: + # Layout D (or untagged legacy buffers): shape (128, W), identity. + # The legacy structural check below still fires for untagged buffers + # so we don't silently accept arbitrary layouts. 
+ assert analyzer.can_prove_equal(tmem_buf.shape[0], 128), ( + f".16x*b path expects tmem_buf rows=128, got {tmem_buf.shape[0]}" + ) + if datapath is None: + tmem_layout = TileLayout( + S[(128, tmem_buf.shape[1]) : (1 @ TLane, 1 @ TCol)] + ).canonicalize() + tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), tmem_layout) + tmem_rows = 128 tmem_st, tmem_extent = get_st_extent(tmem_region) local_st, local_extent = get_st_extent(local_region) - # Local slice must be the full (64, K_cols) view. + # Local slice must be the full (frag_rows, K_cols) view. assert analyzer.can_prove_equal(local_st[0], 0) - assert analyzer.can_prove_equal(local_extent[0], 64) + assert analyzer.can_prove_equal(local_extent[0], frag_rows) assert analyzer.can_prove_equal(local_extent[1], width_elems) - # TMEM slice must start at row 0 (warp 0 of the WG is at lane 0) and span - # 64 rows (collectively the 4 warps' first 16-lane chunks). + # TMEM slice must start at row 0 and span ``frag_rows`` rows. For Layout + # F the buffer is already (64, W) so frag_rows=64 covers the full slice; + # for Layout D + frag_rows=64 the slice reads the *first* half-slab and + # the rest of the buffer's 128 rows is invisible to this atom. For + # Layout D + frag_rows=128 the slice covers all 128 physical lanes via + # two PTX issues (row=0 + row=16). assert analyzer.can_prove_equal(tmem_st[0], 0) - assert analyzer.can_prove_equal(tmem_extent[0], 64) + assert analyzer.can_prove_equal(tmem_extent[0], frag_rows) assert analyzer.can_prove_equal(tmem_extent[1], width_elems) + del tmem_rows # only used for the structural check above col_off = tmem_st[1] assert analyzer.can_prove_equal(tvm.tirx.floormod(col_off, elem_per_32b), 0) @@ -282,11 +412,14 @@ def _emit_16xnb_path( # for the register-pointer arguments of the PTX builtin. 
local_storage = local_buf.view(per_thread_elems, layout=TileLayout(S[per_thread_elems])) # noqa: E501 local_32b = local_storage.view("uint32") - op( - tmem_buf.allocated_addr[0], - *[local_32b[local_col_off_elems // elem_per_32b + i] for i in range(regs_per_thread)], # noqa: E501 - shape=shape, num=num, row=0, col=col_off_32b, - ) + local_reg_base = local_col_off_elems // elem_per_32b + for slab in range(n_slabs): + reg_base = slab * regs_per_thread_per_slab + op( + tmem_buf.allocated_addr[0], + *[local_32b[local_reg_base + reg_base + i] for i in range(regs_per_thread_per_slab)], # noqa: E501 + shape=shape, num=num, row=slab * 16, col=col_off_32b, + ) # fmt: on return impl diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py index 09a15f3bd1..0bea7587ce 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py @@ -42,7 +42,14 @@ import pytest import tvm import tvm.testing from tvm.script import tirx as Tx -from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout +from tvm.tirx.layout import ( + S, + TCol, + TileLayout, + TLane, + tcgen05_atom_layout, + tmem_datapath_layout, +) from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg @@ -189,15 +196,58 @@ def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype): _run_roundtrip_16b(shape, rep, dtype) -def _run_roundtrip_16b(shape: str, rep: int, dtype: str): +# ``.16x*b`` atom can also span M=128 by emitting two issues per copy_async +# (row=0 + row=16), covering the full 32-lane TMEM partition of each warp. +# We only need to spot-check that the dispatch fires correctly and the per- +# thread reg ↔ TMEM mapping round-trips bit-exactly — the M=64 sweep above +# already covers the (lane, reg) decomposition, so a sparse rep set suffices. 
+@pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) +@pytest.mark.parametrize("rep", [1, 2, 4]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, dtype): + if rep not in _SHAPE_REPS[shape]: + pytest.skip(f"rep {rep} not valid for {shape}") + _run_roundtrip_16b(shape, rep, dtype, frag_rows_override=128) + + +# Layout F (M=64 non-``.ws``, scattered) round-trip: the buffer is declared +# with the scatter-encoded TileLayout that ``tmem_datapath_layout("F", ...)`` +# produces. ``.16x*b`` M=64 PTX has the matching scatter built in, so the +# round-trip is bit-exact in the same way as Layout D + M=64. +@pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) +@pytest.mark.parametrize("rep", [1, 2, 4]) +@pytest.mark.parametrize("dtype", ["float16", "bfloat16"]) +def test_tcgen05_16xnb_roundtrip_16b_layout_F(shape, rep, dtype): + if rep not in _SHAPE_REPS[shape]: + pytest.skip(f"rep {rep} not valid for {shape}") + _run_roundtrip_16b(shape, rep, dtype, tmem_datapath="F") + + +def _run_roundtrip_16b( + shape: str, + rep: int, + dtype: str, + *, + frag_rows_override=None, + tmem_datapath: str = "D", +): bits = tvm.runtime.DataType(dtype).bits assert bits == 16 elem_per_32b = 2 K_cols_fp32 = _COL_FACTOR_FP32[shape] * rep K_cols_elem = K_cols_fp32 * elem_per_32b regs_per_thread = _REGS_FACTOR[shape] * rep + if frag_rows_override is not None: + # M=128 doubles per-thread registers (second 16-row slab per warp). + assert frag_rows_override == 128 and _FRAG_ROWS[shape] == 64 + regs_per_thread *= 2 per_thread_elems = regs_per_thread * elem_per_32b - frag_rows = _FRAG_ROWS[shape] + frag_rows = frag_rows_override if frag_rows_override is not None else _FRAG_ROWS[shape] + if tmem_datapath == "F": + # Layout F is only valid with M=64 (per the datapath table); M=128 + # would need to read the high slab, which Layout F doesn't expose.
+ assert frag_rows == 64, "Layout F + M=128 is an invalid pairing" + tmem_rows = 64 if tmem_datapath == "F" else 128 # The 16-bit round-trip writes and reads exclusively through .16x*b atoms, # so the TMEM column footprint is whatever ``K_cols_fp32`` says — no @@ -205,6 +255,7 @@ def _run_roundtrip_16b(shape: str, rep: int, dtype: str): tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32)) stage_width_elem = tmem_col_width_32b * elem_per_32b atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype) + tmem_layout = tmem_datapath_layout(tmem_datapath, tmem_rows, stage_width_elem) @Tx.prim_func def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None: @@ -237,11 +288,11 @@ def _run_roundtrip_16b(shape: str, rep: int, dtype: str): Tx.tvm_storage_sync("shared") tmem = Tx.decl_buffer( - (128, stage_width_elem), + (tmem_rows, stage_width_elem), dtype, scope="tmem", allocated_addr=tmem_addr[0], - layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]), + layout=tmem_layout, ) # Load per-thread A → reg_in @@ -298,6 +349,142 @@ def _next_pow2(x: int) -> int: return 1 << (x - 1).bit_length() +# Unit test: pin down the (row, col) → (TLane, TCol) mapping that the +# ``tmem_datapath_layout`` factory encodes. A self-consistent round-trip +# (write + read with the same factory output) can't catch a layout that +# encodes a *wrong* scatter — the labels would still match structurally +# even if the row→lane formula doesn't match PTX's actual behavior. This +# test bypasses compilation and checks the layout's ``apply`` method +# directly against ``_frag_row_to_tmem_lane`` for every M=64 logical row. +def test_tmem_datapath_layout_F_row_to_lane_mapping(): + """Layout F: every logical row r ∈ [0, 64) must land at physical TMEM + lane ``(r // 16) * 32 + (r % 16)`` — the canonical scatter that the + ``.16x*b`` M=64 PTX accesses (warp i on lanes ``i * 32 .. i * 32 + 15``). 
+ """ + cols = 32 + layout = tmem_datapath_layout("F", 64, cols) + for r in range(64): + for c in [0, 1, 7, 16, 31]: + # Use ``apply(coord, shape=[64, cols])`` so (r, c) gets flattened + # row-major before SplitCoord into the shard iters. + axis_values = layout.apply(r, c, shape=[64, cols]) + expected_lane = (r // 16) * 32 + (r % 16) + assert int(axis_values["TLane"]) == expected_lane, ( + f"(r={r}, c={c}) mapped to TLane {int(axis_values['TLane'])}, " + f"expected {expected_lane} (= (r//16)*32 + (r%16))" + ) + assert int(axis_values["TCol"]) == c, ( + f"(r={r}, c={c}) mapped to TCol {int(axis_values['TCol'])}, expected {c}" + ) + + +@pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"]) +@pytest.mark.parametrize("rep", [1, 2, 4]) +def test_tcgen05_atom_layout_apply_matches_decompose_fp32(shape, rep): + """``tcgen05_atom_layout`` is supposed to be the inverse of + ``_decompose_fp32`` — i.e. for every (row, col) in the M=64 fragment, + ``layout.apply(row, col)`` must return the (laneid, wid_in_wg, m) + tuple that PTX puts at frag element ``(row, col)``. + + The factory's per-shape iter lists are written low-to-high (natural + decomposition); the reversal added below is what aligns the resulting + TileLayout with ``SplitCoord`` (high-to-low). Without the reversal the + factory used to silently produce a layout that disagreed with PTX — + the round-trip tests didn't catch it because the dispatch ignores the + layout label and emits raw PTX. This sweep is the structural fence.
+ """ + if rep not in _SHAPE_REPS[shape]: + pytest.skip(f"rep {rep} not valid for {shape}") + cols = _COL_FACTOR_FP32[shape] * rep # K_cols_fp32 + layout = tcgen05_atom_layout(shape, (64, cols), "float32") + for thread in range(128): + laneid = thread & 31 + wid_in_wg = thread >> 5 + regs_per_thread = _REGS_FACTOR[shape] * rep + for reg in range(regs_per_thread): + row, col = _decompose_fp32(shape, thread, reg) + axis_values = layout.apply(row, col, shape=[64, cols]) + assert int(axis_values.get("laneid", 0)) == laneid, ( + f"shape={shape} rep={rep}: (row={row}, col={col}) " + f"mapped to laneid {int(axis_values.get('laneid', 0))}, expected {laneid}" + ) + assert int(axis_values.get("wid_in_wg", 0)) == wid_in_wg, ( + f"shape={shape} rep={rep}: (row={row}, col={col}) " + f"mapped to wid_in_wg {int(axis_values.get('wid_in_wg', 0))}, expected {wid_in_wg}" + ) + assert int(axis_values.get("m", 0)) == reg, ( + f"shape={shape} rep={rep}: (row={row}, col={col}) " + f"mapped to m {int(axis_values.get('m', 0))}, expected {reg}" + ) + + +def test_tmem_datapath_layout_D_row_to_lane_mapping(): + """Layout D: identity row→lane (no scatter).""" + cols = 32 + layout = tmem_datapath_layout("D", 128, cols) + for r in [0, 1, 15, 16, 31, 32, 63, 64, 127]: + axis_values = layout.apply(r, 0, shape=[128, cols]) + assert int(axis_values["TLane"]) == r, ( + f"r={r} mapped to TLane {int(axis_values['TLane'])}, expected {r}" + ) + + +# Negative tests: the datapath/atom pairing matrix in ``tcgen05_ldst.py`` +# must reject mismatched combinations. We construct a Layout F TMEM buffer +# (64 rows, scattered) and try to read it with a ``.16x*b`` M=128 atom, +# which would interpret the second slab (lanes 16..31 of each warp) as +# meaningful data — but Layout F leaves that slab undefined. Compilation +# must raise a clear error, not silently emit a broken kernel. 
+@pytest.mark.parametrize("atom_kind,frag_rows", [("16x*b", 128), ("32x32b", 128)]) +def test_layout_F_rejects_incompatible_atoms(atom_kind, frag_rows): + """Layout F + (.16x*b M=128 or .32x32b) must raise at compile time.""" + if atom_kind == "16x*b": + shape = "16x256b" + rep = 1 + # Local fragment shape for M=128 .16x256b rep=1 = (128, 8) fp32. + atom_view = tcgen05_atom_layout(shape, (128, 8), "float32") + local_extent_rows = 128 + local_cols = 8 + else: # .32x32b path: local (128, 32) fp32 + atom_view = TileLayout(S[(128, 32) : (1 @ axis_tid_in_wg, 1)]) + local_extent_rows = 128 + local_cols = 32 + + tmem_layout = tmem_datapath_layout("F", 64, max(32, local_cols)) + tmem_rows = 64 + stage_width_elem = max(32, local_cols) + + @Tx.prim_func + def kernel() -> None: + Tx.device_entry() + Tx.warp_id([128 // 32]) + Tx.cta_id([2]) + wg_id = Tx.warpgroup_id([1]) + Tx.warp_id_in_wg([4]) + Tx.lane_id([32]) + Tx.thread_id([128]) + tmem_addr = Tx.alloc_shared([1], "uint32") + if wg_id == 0: + with Tx.warpgroup(): + Tx.tvm_storage_sync("shared") + tmem = Tx.decl_buffer( + (tmem_rows, stage_width_elem), + "float32", + scope="tmem", + allocated_addr=tmem_addr[0], + layout=tmem_layout, + ) + frag = Tx.alloc_local((local_extent_rows * local_cols // 128,), "float32") + frag_view = frag.view(local_extent_rows, local_cols, layout=atom_view) + Tx.copy_async(frag_view[:, :], tmem[0:local_extent_rows, 0:local_cols]) + + target = tvm.target.Target("cuda") + with target: + mod = tvm.IRModule({"main": kernel}) + with pytest.raises((ValueError, RuntimeError), match="datapath"): + tvm.compile(mod, target=target, tir_pipeline="tirx") + + def _run_load_test(shape: str, rep: int, dtype: str): """Stage A into TMEM via .32x32b, then read it back as the fragment via .<shape>.x<rep> (through ``Tx.alloc_tcgen05_frag``), and compare each
