This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit e3271628f450f4a534818524cfa7559a1dc099f0
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed May 27 16:23:47 2026 -0400

    feat(tirx): add .16x{64,128,256}b tcgen05.ld/st dispatch + factory (#644)
    
    Adds a unified ``tcgen05_atom_layout(instr_shape, tensor_shape, dtype)``
    factory and matching ``Tx.alloc_tcgen05_frag(...)`` wrapper so users can
    allocate per-thread register fragments for any of the supported PTX
    ``.shape1`` atoms (.32x32b, .16x64b, .16x128b, .16x256b) with one call.
    ``Tx.copy_async`` inspects the local-side layout and dispatches to the
    right tcgen05 PTX atom.
    
    Layout derivation: the per-shape (row, col) decomposition is derived from
    the CUTLASS DstLayout traits
    (3rdparty/cutlass/include/cute/atom/copy_traits_sm100.hpp).
    For .16x*b shapes (M=64 fragments) each warp's atom covers one 16-row slab
    of the warpgroup's 64-row fragment, driven by the PTX 9.7.16.8.1 access
    restriction that puts warp i on TMEM lanes i*32..i*32+31. For .32x32b
    (M=128) the factory returns the canonical (128, K):(1@tid_in_wg, 1) layout
    already accepted by the existing dispatch path.
    
    TMEM is kept dense for 16-bit dtypes (2 elements per 32-bit cell, matching
    the existing .32x32b convention) rather than going through .pack::16b /
    .unpack::16b — those would waste half the TMEM cell width.
    
    Bit-exact micro-tests (test_tmem_16xnb.py, 92 cases):
    - fp32 load through tcgen05_atom_layout: stage A via .32x32b.st (chunked
      for K>128 fp32 cols), load via .<instr_shape>.x<rep>, per-thread dump,
      host reconstructs (row, col) from the layout formula.
    - fp32 store round-trip (.<instr_shape>.st → .32x32b.ld).
    - 16-bit (fp16/bf16) self-consistent round-trip (.<instr_shape>.st →
      .<instr_shape>.ld preserves per-thread bits).
    - Explicit Tx.alloc_tcgen05_frag wrapper compile + PTX-emission check.
    
    Regression: full tests/python/tirx/operator/tile_primitive/cuda/ suite
    (955 tests) passes unchanged.
    
    Files
    - python/tvm/tirx/layout.py: factory ``tcgen05_atom_layout`` + supported
      rep tables + per-shape iter decompositions.
    - python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py:
      split copy_tmem_local_impl into ``_emit_32x32b_path`` (unchanged) and
      ``_emit_16xnb_path`` (new); structurally match the local layout to pick
      the dispatch path.
    - python/tvm/tirx/script/builder/ir.py: ``alloc_tcgen05_frag`` wrapper.
    - tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py:
      bit-exact micro-test suite.
---
 python/tvm/tirx/layout.py                          | 200 ++++++
 .../tile_primitive/cuda/copy_async/tcgen05_ldst.py | 178 +++++-
 python/tvm/tirx/script/builder/ir.py               |  57 ++
 .../cuda/copy_async/test_tmem_16xnb.py             | 709 +++++++++++++++++++++
 4 files changed, 1140 insertions(+), 4 deletions(-)

diff --git a/python/tvm/tirx/layout.py b/python/tvm/tirx/layout.py
index ed55b4f802..e17a0d61f8 100644
--- a/python/tvm/tirx/layout.py
+++ b/python/tvm/tirx/layout.py
@@ -566,6 +566,7 @@ except NameError:  # pragma: no cover
     __all__ = []  # type: ignore[var-annotated]
 __all__ += list(_AXIS_NAMES)
 __all__ += ["R", "S"]
+__all__ += ["wg_local_layout", "tcgen05_atom_layout"]
 
 
 def wg_local_layout(cols, rows=128):
@@ -577,6 +578,205 @@ def wg_local_layout(cols, rows=128):
     return TileLayout(S[(rows, cols) : (1 @ Axis.tid_in_wg, 1)])
 
 
+# Allowed (.shape, .num) combinations for tcgen05.ld/st atoms.
+# Source: PTX ISA Table 49 (tcgen05-num-shapes-ld).
+_TCGEN05_ATOM_REPS = {
+    "32x32b":  (1, 2, 4, 8, 16, 32, 64, 128),
+    "16x64b":  (1, 2, 4, 8, 16, 32, 64, 128),
+    "16x128b": (1, 2, 4, 8, 16, 32, 64),
+    "16x256b": (1, 2, 4, 8, 16, 32),
+}
+
+
+# fp32 columns covered per unit of ``rep`` (the ``.xN`` qualifier) for each
+# instr_shape. For .16x*b atoms the warpgroup fragment is 64 rows x
+# (factor * rep) fp32 cols; for .32x32b it is 128 rows x rep fp32 cols
+# (factor=1).
+_TCGEN05_COL_FACTOR_FP32 = {
+    "32x32b": 1,
+    "16x64b": 2,
+    "16x128b": 4,
+    "16x256b": 8,
+}
+
+# Number of fragment rows per warpgroup for each instr_shape.
+_TCGEN05_FRAG_ROWS = {
+    "32x32b": 128,
+    "16x64b": 64,
+    "16x128b": 64,
+    "16x256b": 64,
+}
+
+
+def tcgen05_atom_layout(
+    instr_shape: str, tensor_shape: tuple[int, int], dtype
+) -> "TileLayout":
+    """Register-side ``TileLayout`` for ``tcgen05.ld``/``tcgen05.st`` atoms.
+
+    Describes the per-warpgroup register tile that ``Tx.copy_async`` produces
+    when reading a TMEM fragment via ``tcgen05.{ld,st}.<instr_shape>.xN``.
+    ``rep`` (the ``.xN`` qualifier) is inferred from ``tensor_shape``.
+
+    Fragment row count is determined by ``instr_shape``: ``.32x32b`` covers an
+    M=128 fragment (128 rows per warpgroup), and ``.16x{64,128,256}b`` covers
+    an M=64 fragment (64 rows per warpgroup).
+
+    TMEM is kept **dense** for 16-bit dtypes: two 16-bit elements per 32-bit
+    TMEM cell (matching the existing ``.32x32b`` convention). The PTX op is
+    issued with the plain ``.b32`` form (no ``.pack::16b`` qualifier), and
+    the returned layout describes the per-thread register file with two
+    packed 16-bit elements per 32-bit register.
+
+    Parameters
+    ----------
+    instr_shape : str
+        The PTX atom's ``.shape`` qualifier. One of ``"32x32b"``, ``"16x64b"``,
+        ``"16x128b"``, ``"16x256b"``.
+    tensor_shape : tuple[int, int]
+        The logical fragment shape in **element units**. Must be
+        ``(frag_rows, K)`` where ``frag_rows`` is ``128`` for ``.32x32b`` and
+        ``64`` for the other shapes, and::
+
+            K == rep * factor_fp32 * elem_per_32b
+
+        where ``factor_fp32`` is ``1`` / ``2`` / ``4`` / ``8`` for
+        ``.32x32b`` / ``.16x64b`` / ``.16x128b`` / ``.16x256b``,
+        ``elem_per_32b`` is ``1`` for 32-bit and ``2`` for 16-bit dtypes, and
+        ``rep`` is in PTX Table 49's supported set for the chosen instr_shape
+        (a power of two, with a per-shape maximum).
+    dtype : str | tvm.DataType
+        Element dtype, e.g. ``"float32"``, ``"float16"`` or ``"bfloat16"``.
+        Only the bit width is validated: any 16- or 32-bit dtype is accepted.
+
+    Returns
+    -------
+    TileLayout
+        A ``(frag_rows, K)``-shaped tile layout: ``(128, K):(1@tid_in_wg, 1)``
+        for ``.32x32b``; for the ``.16x*b`` shapes a ``(64, K)`` layout built
+        from fine-grained iters describing each per-(lane, register) position.
+
+    Raises
+    ------
+    ValueError
+        If ``instr_shape`` is unknown, ``dtype`` is not 16- or 32-bit,
+        ``tensor_shape`` is not 2-D, the row count does not match
+        ``instr_shape``, or the inferred ``rep`` is not supported.
+
+    Examples
+    --------
+    ``tcgen05_atom_layout("16x64b", (64, 64), "float32")`` → ``.16x64b.x32``
+    (rep=32, fp32).
+
+    ``tcgen05_atom_layout("16x128b", (64, 256), "float16")`` → ``.16x128b.x32``
+    (rep=32, fp16; two fp16 elements packed per 32-bit register and per 32-bit
+    TMEM cell).
+    """
+    if instr_shape not in _TCGEN05_ATOM_REPS:
+        raise ValueError(
+            f"tcgen05_atom_layout instr_shape must be one of "
+            f"{list(_TCGEN05_ATOM_REPS)}, got {instr_shape!r}"
+        )
+    bits = tvm.runtime.DataType(dtype).bits
+    if bits not in (16, 32):
+        raise ValueError(
+            f"tcgen05_atom_layout dtype must be a 32-bit or 16-bit type, "
+            f"got {dtype} ({bits} bits)"
+        )
+    if len(tensor_shape) != 2:
+        raise ValueError(
+            f"tcgen05_atom_layout tensor_shape must be 2-D (rows, cols), "
+            f"got {tensor_shape!r}"
+        )
+    rows, cols = tensor_shape
+    expected_rows = _TCGEN05_FRAG_ROWS[instr_shape]
+    if rows != expected_rows:
+        raise ValueError(
+            f"tcgen05_atom_layout {instr_shape!r} expects "
+            f"rows={expected_rows}, got {rows}"
+        )
+
+    elem_per_32b = 32 // bits
+    col_factor_elem = _TCGEN05_COL_FACTOR_FP32[instr_shape] * elem_per_32b
+    if cols % col_factor_elem != 0:
+        raise ValueError(
+            f"tcgen05_atom_layout cols={cols} not divisible by the per-rep column "
+            f"factor {col_factor_elem} for instr_shape={instr_shape!r} "
+            f"dtype={dtype}; "
+            f"valid cols are k * {col_factor_elem} for k in "
+            f"{_TCGEN05_ATOM_REPS[instr_shape]}"
+        )
+    rep = cols // col_factor_elem
+    if rep not in _TCGEN05_ATOM_REPS[instr_shape]:
+        raise ValueError(
+            f"tcgen05_atom_layout inferred rep={rep} (from cols={cols}) is not "
+            f"in "
+            f"the PTX Table 49 supported set for {instr_shape}: "
+            f"{_TCGEN05_ATOM_REPS[instr_shape]}"
+        )
+
+    laneid = Axis.laneid
+    wid = Axis.wid_in_wg
+    N = rep
+    shape = instr_shape
+    # All m-strides below are written in fp32-reg units; we multiply by
+    # elem_per_32b at the end and prepend a C_pack iter for the 16-bit case
+    # (each fp32 reg packs ``elem_per_32b`` elements at adjacent column
+    # positions).
+
+    if shape == "32x32b":
+        # M=128 fragment, simple thread-rows layout:
+        #   (rows=128, cols=K) : (1@tid_in_wg, 1)
+        # Each of 128 warpgroup threads owns one row; cols are contiguous in
+        # the per-thread storage. For 16-bit dtypes the K cols are packed two
+        # per 32-bit register (handled by the per-thread storage element count
+        # naturally — m-stride 1 in element units).
+        iters = [
+            Iter(rows, 1, Axis.tid_in_wg),
+            Iter(cols, 1, "m"),
+        ]
+        return TileLayout.from_iters(iters, [], {})
+
+    # NOTE(review): in the .16x*b branches below, the first tuple of each
+    # group is annotated as the least-significant bit field ("R bits 0..2",
+    # "C bit 0"), and the 16-bit C_pack iter is *prepended* as the lowest
+    # column bit. The .32x32b branch above instead lists the row iter before
+    # the column iter (row-major, most-significant first). Confirm which order
+    # ``TileLayout.from_iters`` expects. ``_emit_16xnb_path`` addresses
+    # registers through a flat per-thread view and only structurally matches
+    # this layout, so the dispatch itself does not exercise the ordering.
+    if shape == "16x64b":
+        # Per-warp tile (fp32 view): (16 rows, 2N cols). Per-lane regs = N.
+        # Lane (t0, t1, t2): t0 = laneid & 1, t1 = (laneid >> 1) & 1,
+        # t2 = laneid >> 2.
+        #   Row = t2 + 8*t0 + 16*wid_in_wg
+        #   Col (fp32) = t1 + 2*r,   r ∈ [0, N)
+        row_iters_fp32 = [
+            (8, 4, laneid),    # R_t2: laneid bits 2..4 → R bits 0..2
+            (2, 1, laneid),    # R_t0: laneid bit 0    → R bit 3
+            (4, 1, wid),       # R_w:  wid_in_wg       → R bits 4..5
+        ]
+        col_iters_fp32 = [
+            (2, 2, laneid),    # C_t1: laneid bit 1    → C bit 0
+            (N, 1, "m"),       # C_r:  register slot   → C bits 1..
+        ]
+    elif shape == "16x128b":
+        # Per-warp tile (fp32 view): (16 rows, 4N cols). Per-lane regs = 2N.
+        # Lane (t0, t1): t0 = laneid & 3, t1 = laneid >> 2.
+        # Reg r = ra + 2*rb, ra ∈ {0,1}, rb ∈ [0, N).
+        #   Row = t1 + 8*ra + 16*wid_in_wg
+        #   Col (fp32) = t0 + 4*rb
+        row_iters_fp32 = [
+            (8, 4, laneid),    # R_t1: laneid bits 2..4 → R bits 0..2
+            (2, 1, "m"),       # R_ra: reg bit 0        → R bit 3
+            (4, 1, wid),       # R_w
+        ]
+        col_iters_fp32 = [
+            (4, 1, laneid),    # C_t0: laneid bits 0..1 → C bits 0..1
+            (N, 2, "m"),       # C_rb: reg bits 1..     → C bits 2..
+        ]
+    else:  # 16x256b
+        # Per-warp tile (fp32 view): (16 rows, 8N cols). Per-lane regs = 4N.
+        # Lane (t0, t1) as for 16x128b. Reg r = v0p + 2*va + 4*vb.
+        #   Row = t1 + 8*va + 16*wid_in_wg
+        #   Col (fp32) = v0p + 2*t0 + 8*vb
+        row_iters_fp32 = [
+            (8, 4, laneid),    # R_t1
+            (2, 2, "m"),       # R_va: reg bit 1 → R bit 3
+            (4, 1, wid),       # R_w
+        ]
+        col_iters_fp32 = [
+            (2, 1, "m"),       # C_v0p: reg bit 0  → C bit 0
+            (4, 1, laneid),    # C_t0
+            (N, 4, "m"),       # C_vb:  reg bits 2.. → C bits 3..
+        ]
+
+    def _scale(iters):
+        # Convert m-strides from fp32-register units to element units; lane and
+        # warp strides are left untouched.
+        out = []
+        for ext, stride, axis in iters:
+            if axis == "m":
+                out.append((ext, stride * elem_per_32b, axis))
+            else:
+                out.append((ext, stride, axis))
+        return out
+
+    row_iters = _scale(row_iters_fp32)
+    col_iters = _scale(col_iters_fp32)
+
+    # For the 16-bit packed variant each fp32 register holds two adjacent
+    # column elements (low / high halves), so we prepend a C_pack iter of
+    # extent ``elem_per_32b`` and m-stride 1 to the column group.
+    if elem_per_32b > 1:
+        col_iters = [(elem_per_32b, 1, "m"), *col_iters]
+
+    iters = [Iter(ext, stride, axis) for ext, stride, axis in row_iters + col_iters]
+    return TileLayout.from_iters(iters, [], {})
+
+
 # ------------------------------------------------------------------
 # Helper types to support `PrimExpr @ Axis` and `sum` for offsets
 # ------------------------------------------------------------------
diff --git 
a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py 
b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py
index 4700d4e0da..345eef4519 100644
--- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py
+++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py
@@ -27,7 +27,7 @@ from tvm.arith import Analyzer
 from tvm.runtime import DataType
 from tvm.script import tirx as Tx
 from tvm.tirx import Buffer, PrimFunc
-from tvm.tirx.layout import S, TCol, TileLayout, TLane, tid_in_wg
+from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout, 
tid_in_wg
 from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, 
register_dispatch
 from tvm.tirx.stmt import TilePrimitiveCall
 
@@ -35,6 +35,42 @@ from ..common import get_st_extent
 from ..copy import _is_valid_copy, _scope_allowed
 from ..exec_scope_utils import exec_scope_ok
 
+# Per-rep fp32-column factor for each ``.16x*b`` instr_shape (mirrors
+# ``_TCGEN05_COL_FACTOR_FP32`` in ``tvm.tirx.layout`` minus ``.32x32b``, which
+# is handled by the dedicated M=128 path; .16x64b → 2, .16x128b → 4,
+# .16x256b → 8). Source: PTX ISA Table 49.
+_TCGEN05_COL_FACTOR_FP32 = {"16x64b": 2, "16x128b": 4, "16x256b": 8}
+
+
+def _match_tcgen05_atom_layout(buf):
+    """Return ``(instr_shape, rep)`` if ``buf.layout`` is a ``.16x*b`` atom layout.
+
+    The local buffer shape ``(64, K)`` together with the dtype determines the
+    candidate ``rep`` for each ``instr_shape``; we probe the three ``.16x*b``
+    shapes and compare structurally against ``tcgen05_atom_layout``.
+
+    Returns ``None`` when no atom layout matches. This includes non-2-D buffers
+    and buffers with non-constant extents, so the caller can fall back to the
+    ``.32x32b`` path. Assumes ``buf.layout`` is not ``None``; the caller asserts
+    this before calling.
+    """
+    if len(buf.shape) != 2:
+        return None
+    try:
+        rows, cols = int(buf.shape[0]), int(buf.shape[1])
+    except (TypeError, ValueError):
+        # Symbolic extents can never equal a concrete atom layout; let the
+        # .32x32b path handle (or reject) the buffer instead of crashing here.
+        return None
+    if rows != 64:
+        return None
+    dtype = buf.dtype
+    # Only used once a candidate matched, i.e. for 16/32-bit dtypes.
+    elem_per_32b = 32 // DataType(dtype).bits
+    layout_c = buf.layout.canonicalize()
+    for shape, col_factor_fp32 in _TCGEN05_COL_FACTOR_FP32.items():
+        try:
+            cand = tcgen05_atom_layout(shape, (rows, cols), dtype).canonicalize()
+        except ValueError:
+            # (rows, cols, dtype) is not a valid fragment for this shape.
+            continue
+        if tvm.ir.structural_equal(layout_c, cand):
+            # Recover rep from cols (same arithmetic the factory uses).
+            return shape, cols // (col_factor_fp32 * elem_per_32b)
+    return None
+
 
 def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> 
PrimFunc | None:
     op_call = TilePrimitiveCall.downcast(op_call)
@@ -56,11 +92,51 @@ def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: 
DispatchContext) -> P
     assert tmem_buf.layout is not None
     assert local_buf.layout is not None
     assert tmem_buf.dtype == local_buf.dtype
+    assert tmem_buf.allocated_addr is not None
 
     analyzer = Analyzer()
     elem_size = DataType(local_buf.dtype).bits
     elem_per_32b = 32 // elem_size
     assert len(local_buf.shape) == len(tmem_buf.shape) == 2
+
+    # Try the .16x* (M=64) path first by structural-matching the register-side
+    # layout against ``tcgen05_atom_layout(instr_shape, (64, K), dtype)``. The
+    # TMEM-side layout is the standard (128, W):(1@TLane, 1@TCol); the M=64
+    # fragment lives at lanes 0..15 of each warp's accessible slab (per PTX
+    # 9.7.16.8.1), so each warp issues with row_offset=0 and collectively the
+    # 4 warps cover all 64 rows.
+    atom_match = _match_tcgen05_atom_layout(local_buf)
+
+    if atom_match is not None:
+        shape, num = atom_match
+        return _emit_16xnb_path(
+            shape=shape,
+            num=num,
+            direction=direction,
+            tmem_buf=tmem_buf,
+            local_buf=local_buf,
+            tmem_region=tmem_region,
+            local_region=local_region,
+            elem_per_32b=elem_per_32b,
+            analyzer=analyzer,
+        )
+
+    # Fall through to the existing .32x32b (M=128) path.
+    return _emit_32x32b_path(
+        direction=direction,
+        tmem_buf=tmem_buf,
+        local_buf=local_buf,
+        tmem_region=tmem_region,
+        local_region=local_region,
+        elem_per_32b=elem_per_32b,
+        analyzer=analyzer,
+    )
+
+
+def _emit_32x32b_path(
+    *, direction, tmem_buf, local_buf, tmem_region, local_region, 
elem_per_32b, analyzer
+) -> PrimFunc:
+    """Original M=128 fragment path using ``tcgen05.{ld,st}.32x32b.xN``."""
     # local: 128xWIDTH <-> tmem: 128xSHAPE[1]
     assert analyzer.can_prove_equal(local_buf.shape[0], 128)
     assert analyzer.can_prove_equal(tmem_buf.shape[0], 128)
@@ -87,10 +163,7 @@ def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: 
DispatchContext) -> P
     # local layout
     TileLayout(S[(128, width) : (1 @ tid_in_wg, 1)]).canonicalize()
 
-    # tmem allocated addr is not None
-    assert tmem_buf.allocated_addr is not None
     tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), tmem_layout)
-    # tvm.ir.assert_structural_equal(local_buf.layout.canonicalize(), 
local_layout)
     # local: [0:128, 0:WIDTH] <-> tmem: [0:128, st:st+WIDTH]
     assert analyzer.can_prove_equal(tmem_st[0], 0)
     assert analyzer.can_prove_equal(tmem_extent[0], 128)
@@ -121,6 +194,103 @@ def copy_tmem_local_impl(op_call: TilePrimitiveCall, 
sctx: DispatchContext) -> P
     return impl
 
 
+def _emit_16xnb_path(
+    *,
+    shape,
+    num,
+    direction,
+    tmem_buf,
+    local_buf,
+    tmem_region,
+    local_region,
+    elem_per_32b,
+    analyzer,
+) -> PrimFunc:
+    """M=64 fragment path using ``tcgen05.{ld,st}.<shape>.x<num>`` (one of
+    ``.16x64b``, ``.16x128b``, ``.16x256b``).
+
+    Each of the warpgroup's 4 warps issues the atom once with
+    ``row_offset=0``; the PTX TMEM access restriction places warp ``i`` on
+    TMEM lanes ``i*32..i*32+31``, of which the atom uses the first 16 to
+    cover one 16-row slab of the 64-row fragment. Collectively, the four
+    warps cover all 64 rows.
+
+    Only full-fragment copies are supported. The local region must be the
+    whole ``(64, K)`` buffer, and the TMEM region must be rows ``0:64`` with
+    ``K`` columns starting at a 32-bit-aligned column. Anything else fails
+    the assertions below.
+    """
+    # Per-rep column footprint in fp32 columns (.16x64b → 2N, .16x128b → 4N,
+    # .16x256b → 8N); same table ``_match_tcgen05_atom_layout`` uses.
+    col_factor_fp32 = _TCGEN05_COL_FACTOR_FP32[shape]
+    # Per-thread register count (in 32-bit units):
+    #   .16x64b.xN  → N        .16x128b.xN → 2N      .16x256b.xN → 4N
+    regs_per_thread = {"16x64b": num, "16x128b": 2 * num, "16x256b": 4 * num}[shape]
+    # Logical column width of the local buffer view (element units).
+    width_elems = col_factor_fp32 * num * elem_per_32b
+    # Per-thread storage in element units (same total bits as the registers).
+    per_thread_elems = regs_per_thread * elem_per_32b
+
+    # Local-side: shape (64, K_cols)
+    assert analyzer.can_prove_equal(local_buf.shape[0], 64), (
+        f".16x*b path expects local_buf rows=64, got {local_buf.shape[0]}"
+    )
+    assert analyzer.can_prove_equal(local_buf.shape[1], width_elems), (
+        f".16x*b path expects local_buf cols={width_elems}, "
+        f"got {local_buf.shape[1]}"
+    )
+
+    # TMEM-side: shape (128, W); the M=64 fragment occupies the first 16 lanes
+    # of each warp's 32-lane slab.
+    assert analyzer.can_prove_equal(tmem_buf.shape[0], 128), (
+        f".16x*b path expects tmem_buf rows=128, got {tmem_buf.shape[0]}"
+    )
+    tmem_layout = TileLayout(
+        S[(128, tmem_buf.shape[1]) : (1 @ TLane, 1 @ TCol)]
+    ).canonicalize()
+    tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), tmem_layout)
+
+    tmem_st, tmem_extent = get_st_extent(tmem_region)
+    local_st, local_extent = get_st_extent(local_region)
+
+    # Local slice must be the full (64, K_cols) view. Registers are addressed
+    # through a flat per-thread view below, which has no notion of a logical
+    # column offset, so the slice must also start at column 0.
+    assert analyzer.can_prove_equal(local_st[0], 0)
+    assert analyzer.can_prove_equal(local_st[1], 0)
+    assert analyzer.can_prove_equal(local_extent[0], 64)
+    assert analyzer.can_prove_equal(local_extent[1], width_elems)
+
+    # TMEM slice must start at row 0 (warp 0 of the WG is at lane 0) and span
+    # 64 rows (collectively the 4 warps' first 16-lane chunks).
+    assert analyzer.can_prove_equal(tmem_st[0], 0)
+    assert analyzer.can_prove_equal(tmem_extent[0], 64)
+    assert analyzer.can_prove_equal(tmem_extent[1], width_elems)
+
+    # TMEM columns are addressed in 32-bit cells; for 16-bit dtypes the element
+    # column offset must therefore be even.
+    col_off = tmem_st[1]
+    assert analyzer.can_prove_equal(tvm.tirx.floormod(col_off, elem_per_32b), 0)
+    col_off_32b = tvm.tirx.floordiv(col_off, elem_per_32b)
+
+    is_load = direction == "tmem2local"
+    op = Tx.ptx.tcgen05.ld if is_load else Tx.ptx.tcgen05.st
+    # We intentionally do *not* emit ``.pack::16b`` / ``.unpack::16b`` for
+    # 16-bit dtypes. That qualifier would store one 16-bit element per 32-bit
+    # TMEM cell (LOW half only, HIGH half wasted) — fine for some CUTLASS
+    # epilogues but a 2x TMEM waste vs. the existing ``.32x32b`` convention,
+    # which packs two 16-bit elements per cell. By using the plain ``.b32``
+    # form we keep TMEM dense (2 elements per 32-bit cell); the per-thread
+    # register file holds two packed 16-bit values per 32-bit register, and
+    # the layout factory's iters describe that packing.
+
+    # fmt: off
+    @Tx.prim_func(check_well_formed=False)
+    def impl():
+        with Tx.warp():
+            # Per-thread 1-D flat view of the local storage, then a uint32 view
+            # for the register-pointer arguments of the PTX builtin.
+            local_storage = local_buf.view(
+                per_thread_elems, layout=TileLayout(S[per_thread_elems])
+            )
+            local_32b = local_storage.view("uint32")
+            op(
+                tmem_buf.allocated_addr[0],
+                *[local_32b[i] for i in range(regs_per_thread)],
+                shape=shape, num=num, row=0, col=col_off_32b,
+            )
+    # fmt: on
+    return impl
+
+
 # === Variant: copy_async/tmem<->local (priority=10) ===
 #
 # When: one buffer is in tmem (tensor memory, Blackwell SM100+) and the other
diff --git a/python/tvm/tirx/script/builder/ir.py 
b/python/tvm/tirx/script/builder/ir.py
index da24e71a7d..0c6b8607b8 100644
--- a/python/tvm/tirx/script/builder/ir.py
+++ b/python/tvm/tirx/script/builder/ir.py
@@ -1868,6 +1868,62 @@ smem = alloc_shared
 tmem = functools.partial(alloc_buffer, scope="tmem")
 
 
+def alloc_tcgen05_frag(instr_shape, tensor_shape, dtype):
+    """Allocate a register fragment for ``tcgen05.{ld,st}`` atoms.
+
+    Validates the request via ``tcgen05_atom_layout``, sizes the per-thread
+    storage, allocates ``local`` scope memory, and returns a 2-D view of shape
+    ``tensor_shape`` with the matching ``tcgen05_atom_layout``. Pass the result
+    to ``Tx.copy_async`` (with a ``(128, W)``-shaped TMEM buffer) to trigger the
+    corresponding dispatch path.
+
+    Parameters
+    ----------
+    instr_shape : str
+        ``"32x32b"`` (M=128 fragment, 128 row warpgroup tile, layout
+        ``(128, K):(1@tid_in_wg, 1)``); or ``"16x64b"`` / ``"16x128b"`` /
+        ``"16x256b"`` (M=64 fragments, 64 row warpgroup tile with the
+        per-shape per-lane register decomposition).
+    tensor_shape : tuple[int, int]
+        Logical fragment shape ``(frag_rows, K)`` in element units.
+        ``frag_rows`` is ``128`` for ``.32x32b`` and ``64`` for the ``.16x*b``
+        shapes.
+    dtype : str
+        ``"float32"``, ``"float16"``, or ``"bfloat16"``.
+
+    Returns
+    -------
+    Buffer
+        2-D view of shape ``tensor_shape`` whose layout matches
+        ``tcgen05_atom_layout(instr_shape, tensor_shape, dtype)``.
+
+    Raises
+    ------
+    ValueError
+        From ``tcgen05_atom_layout`` for unsupported shape / dtype
+        combinations, or if the fragment does not split evenly across the
+        warpgroup's 128 threads.
+
+    Examples
+    --------
+    M=128 readback (existing dispatch):
+        ``frag = Tx.alloc_tcgen05_frag("32x32b", (128, 64), "float32")``
+        ``Tx.copy_async(frag[:, :], tmem[:, 0:64])``
+
+    M=64 readback (.16x64b dispatch):
+        ``frag = Tx.alloc_tcgen05_frag("16x64b", (64, 64), "float32")``
+        ``Tx.copy_async(frag[:, :], tmem[0:64, 0:64])``
+    """
+    from tvm.tirx.layout import tcgen05_atom_layout  # local import to avoid cycle
+
+    # Validate first so malformed inputs (wrong rank, rows, rep or dtype) get
+    # the factory's descriptive ValueError rather than an unpack or sizing
+    # error below.
+    layout = tcgen05_atom_layout(instr_shape, tensor_shape, dtype)
+
+    rows, cols = tensor_shape
+    # The warpgroup's 128 threads split the rows x cols fragment evenly, so
+    # each thread stores rows * cols / 128 elements (16-bit dtypes pack two per
+    # 32-bit register, but ``alloc_local`` is sized in elements).
+    if (rows * cols) % 128 != 0:
+        raise ValueError(
+            f"alloc_tcgen05_frag tensor_shape={tensor_shape} dtype={dtype!r} "
+            f"does not evenly divide across 128 threads"
+        )
+    per_thread_elems = (rows * cols) // 128
+
+    flat = alloc_local((per_thread_elems,), dtype)
+    return flat.view(rows, cols, layout=layout)
+
+
 if TYPE_CHECKING:
     ScalarT = TypeVar("ScalarT")
 
@@ -4021,6 +4077,7 @@ __all__ += [
     "alloc_local",
     "alloc_scalar",
     "alloc_shared",
+    "alloc_tcgen05_frag",
     "cluster",
     "cluster_id",
     "cta",
diff --git 
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py 
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
new file mode 100644
index 0000000000..09a15f3bd1
--- /dev/null
+++ 
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
@@ -0,0 +1,709 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements.  See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership.  The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License.  You may obtain a copy of the License at
+#
+#   http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied.  See the License for the
+# specific language governing permissions and limitations
+# under the License.
+# pylint: disable=invalid-name, missing-function-docstring
+"""Bit-exact tests for the ``.16x{64,128,256}b`` ``tcgen05.{ld,st}`` dispatch.
+
+For each ``(shape, rep, dtype, direction)`` we:
+
+1. Fill a (128, FULL_W) host buffer ``A`` with random values.
+2. Stage ``A`` into TMEM via the existing ``.32x32b`` ld/st round-trip.
+3. Issue the new ``.16x*b`` atom via ``Tx.copy_async`` to read a (64, K_cols)
+   fragment from TMEM into a register tile shaped by ``tcgen05_atom_layout``.
+4. Dump the register tile to a ``(128, regs_per_thread)`` global buffer indexed
+   ``B[tid_in_wg, r]``.
+5. Reconstruct the expected ``B[t, r]`` on the host from the per-(lane, reg) →
+   (frag_row, frag_col) formula. The M=64 fragment occupies TMEM lanes
+   ``warp_id * 32 + (0..15)``, so ``frag_row R`` maps to TMEM lane
+   ``(R // 16) * 32 + (R % 16)``.
+
+For the store direction we run the inverse: prefill the register tile via host →
+``B`` → ``.32x32b.ld``-staged read, write to TMEM via the new ``.16x*b.st``,
+then read TMEM back via ``.32x32b.ld`` into a (128, FULL_W) buffer and check
+that the M=64 fragment's row positions hold the expected register data.
+"""
+
+import numpy as np
+import pytest
+
+import tvm
+import tvm.testing
+from tvm.script import tirx as Tx
+from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout
+from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg
+
+
+# --------------------------------------------------------------------------
+# Shape metadata + host-side layout reconstruction
+# --------------------------------------------------------------------------
+
+# (.shape, .num) ranges supported by PTX Table 49.
+_SHAPE_REPS = {
+    "32x32b": (1, 2, 4, 8, 16, 32, 64, 128),
+    "16x64b": (1, 2, 4, 8, 16, 32, 64, 128),
+    "16x128b": (1, 2, 4, 8, 16, 32, 64),
+    "16x256b": (1, 2, 4, 8, 16, 32),
+}
+
+# Per-warp fp32 column span = factor * rep.
+_COL_FACTOR_FP32 = {"32x32b": 1, "16x64b": 2, "16x128b": 4, "16x256b": 8}
+
+# Per-thread 32-bit register count = factor * rep.
+_REGS_FACTOR = {"32x32b": 1, "16x64b": 1, "16x128b": 2, "16x256b": 4}
+
+# Per-warpgroup fragment row count.
+_FRAG_ROWS = {"32x32b": 128, "16x64b": 64, "16x128b": 64, "16x256b": 64}
+
+
+def _decompose_fp32(shape: str, t: int, r: int) -> tuple[int, int]:
+    """Return ``(frag_row, frag_col)`` in fp32 element units for the fp32 atom.
+
+    ``t`` is the thread index within the warpgroup (lane = ``t & 31``, warp =
+    ``t >> 5``) and ``r`` is the 32-bit register index within that thread.
+    """
+    laneid = t & 31
+    wid_in_wg = t >> 5
+    if shape == "32x32b":
+        # M=128 fragment: each thread t owns full row t with N consecutive cols.
+        row = t
+        col = r
+    elif shape == "16x64b":
+        t0 = laneid & 1
+        t1 = (laneid >> 1) & 1
+        t2 = laneid >> 2
+        row = t2 + 8 * t0 + 16 * wid_in_wg
+        col = t1 + 2 * r
+    elif shape == "16x128b":
+        t0 = laneid & 3
+        t1 = laneid >> 2
+        ra = r & 1
+        rb = r >> 1
+        row = t1 + 8 * ra + 16 * wid_in_wg
+        col = t0 + 4 * rb
+    elif shape == "16x256b":
+        t0 = laneid & 3
+        t1 = laneid >> 2
+        v0p = r & 1
+        va = (r >> 1) & 1
+        vb = r >> 2
+        row = t1 + 8 * va + 16 * wid_in_wg
+        col = v0p + 2 * t0 + 8 * vb
+    else:
+        raise ValueError(shape)
+    return row, col
+
+
+def _frag_row_to_tmem_lane(shape: str, R: int) -> int:
+    """Map fragment row R to its physical TMEM lane.
+
+    For ``.32x32b`` (M=128) the mapping is identity: row R lives at TMEM lane R.
+    For ``.16x*b`` (M=64) the fragment occupies the first 16 lanes of each
+    warp's 32-lane slab, so ``R`` ∈ [0, 64) lives at lane
+    ``(R // 16) * 32 + (R % 16)``.
+    """
+    if shape == "32x32b":
+        return R
+    return (R // 16) * 32 + (R % 16)
+
+
+def _expected_reg_value_fp32(
+    A: np.ndarray, shape: str, rep: int, tmem_col_off: int, t: int, r: int
+) -> np.uint32:
+    """fp32 path: return the bit pattern (as uint32) that thread ``t`` register
+    ``r`` should hold after ``.<shape>.x<rep>`` reads ``A`` (staged into TMEM)
+    at column offset ``tmem_col_off``. ``rep`` is not used in the computation.
+    """
+    row, col = _decompose_fp32(shape, t, r)
+    tmem_lane = _frag_row_to_tmem_lane(shape, row)
+    val = np.float32(A[tmem_lane, tmem_col_off + col])
+    return val.view(np.uint32)
+
+
+def _expected_reg_value_16b(
+    A: np.ndarray, shape: str, rep: int, tmem_col_off: int, t: int, r: int, dtype_np
+) -> np.uint32:
+    """16-bit path (fp16 / bf16, dense TMEM, plain ``.b32`` form): return the
+    bit pattern thread ``t`` register ``r`` should hold.
+
+    Each 32-bit register packs two 16-bit elements from adjacent columns
+    ``(2*col_fp32, 2*col_fp32 + 1)``: the lower column in the low half and the
+    higher column in the high half. ``rep`` is not used in the computation.
+    """
+    row, col_fp32 = _decompose_fp32(shape, t, r)
+    tmem_lane = _frag_row_to_tmem_lane(shape, row)
+    lo = dtype_np(A[tmem_lane, tmem_col_off + 2 * col_fp32])
+    hi = dtype_np(A[tmem_lane, tmem_col_off + 2 * col_fp32 + 1])
+    lo_u16 = lo.view(np.uint16)
+    hi_u16 = hi.view(np.uint16)
+    return np.uint32(int(lo_u16) | (int(hi_u16) << 16))
+
+
+# --------------------------------------------------------------------------
+# Test 1: load direction
+# --------------------------------------------------------------------------
+
+
+@pytest.mark.parametrize("shape", list(_SHAPE_REPS))
+@pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32])  # subset; larger reps below
+@pytest.mark.parametrize("dtype", ["float32"])
+def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype):
+    """Bit-exact verification of ``tcgen05.<shape>.x<rep>.b32`` load.
+
+    ``shape`` ranges over every key of ``_SHAPE_REPS``, so ``.32x32b`` is
+    covered too; reps that are invalid for a shape are skipped.
+    """
+    if rep not in _SHAPE_REPS[shape]:
+        pytest.skip(f"rep {rep} not valid for {shape}")
+    _run_load_test(shape, rep, dtype)
+
+
@pytest.mark.parametrize(
    "shape, rep",
    [("16x64b", 64), ("16x64b", 128), ("16x128b", 64)],
)
def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep):
    """Run the fp32 load check for high ``rep`` values.

    These ``(shape, rep)`` pairs are outside the ``[1, 2, 4, 8, 16, 32]`` rep
    list of the cross-product load test.
    """
    _run_load_test(shape=shape, rep=rep, dtype="float32")
+
+
@pytest.mark.parametrize("shape", list(_SHAPE_REPS))
@pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype):
    """Self-consistent round-trip for the 16-bit (fp16 / bf16) dispatch path.

    The fp32 load test already pins the ``(lane, reg) -> (frag_row, frag_col)``
    mapping bit-exactly against ``.32x32b`` staging.

    For 16-bit data the staging conventions do not line up:
    ``.32x32b.st`` puts two fp16 values in one 32-bit TMEM cell, while
    ``.16x*b.ld.pack::16b`` takes two fp16 values from the LOW halves of
    adjacent 32-bit cells. This test therefore only requires that per-thread
    data survives ``.16x*b.st.unpack::16b`` followed by ``.16x*b.ld.pack::16b``
    bit-for-bit. That shows the register layout agrees between the load and
    store atom families.
    """
    if rep in _SHAPE_REPS[shape]:
        _run_roundtrip_16b(shape, rep, dtype)
    else:
        pytest.skip(f"rep {rep} not valid for {shape}")
+
+
def _run_roundtrip_16b(shape: str, rep: int, dtype: str):
    """Round-trip per-thread 16-bit data through TMEM using only ``.<shape>`` atoms.

    Steps:
    - Each of the 128 threads copies its row ``A[tid_in_wg, :]`` into a local
      register buffer.
    - That buffer is viewed with ``tcgen05_atom_layout(shape, ...)``, written
      to TMEM with ``Tx.copy_async``, and read back into a second buffer with
      the same layout.
    - The result is dumped to ``B``.

    The host then asserts ``B`` equals ``A`` bit-for-bit (compared as uint16).
    Only 16-bit dtypes are accepted; this is enforced by the ``assert``.
    """
    bits = tvm.runtime.DataType(dtype).bits
    assert bits == 16
    elem_per_32b = 2
    K_cols_fp32 = _COL_FACTOR_FP32[shape] * rep
    K_cols_elem = K_cols_fp32 * elem_per_32b
    regs_per_thread = _REGS_FACTOR[shape] * rep
    per_thread_elems = regs_per_thread * elem_per_32b
    frag_rows = _FRAG_ROWS[shape]

    # The 16-bit round-trip writes and reads exclusively through .16x*b atoms,
    # so the TMEM column footprint is whatever ``K_cols_fp32`` says — no
    # .32x32b staging constraint applies here.
    # Allocation width: at least 32 columns, rounded up to a power of two.
    tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32))
    stage_width_elem = tmem_col_width_32b * elem_per_32b
    atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype)

    @Tx.prim_func
    def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None:
        # Per-thread input/output: A[tid_in_wg, i] feeds register slot i of the
        # warpgroup-collective fragment; B[tid_in_wg, i] is what comes back
        # after a .16x*b.st → .16x*b.ld round-trip.
        A = Tx.match_buffer(A_ptr, (128, per_thread_elems), dtype)
        B = Tx.match_buffer(B_ptr, (128, per_thread_elems), dtype)

        Tx.device_entry()
        warp_id = Tx.warp_id([128 // 32])
        Tx.cta_id([2])
        wg_id = Tx.warpgroup_id([1])
        Tx.warp_id_in_wg([4])
        Tx.lane_id([32])
        tid_in_wg = Tx.thread_id([128])

        tmem_addr = Tx.alloc_shared([1], "uint32")

        if wg_id == 0:
            with Tx.warpgroup():
                # Warp 0 allocates the TMEM columns; the address lands in tmem_addr.
                if warp_id == 0:
                    with Tx.warp():
                        Tx.ptx.tcgen05.alloc(
                            Tx.address_of(tmem_addr),
                            n_cols=tmem_col_width_32b,
                            cta_group=1,
                        )

                # Make the allocated address visible before any thread reads it.
                Tx.tvm_storage_sync("shared")

                tmem = Tx.decl_buffer(
                    (128, stage_width_elem),
                    dtype,
                    scope="tmem",
                    allocated_addr=tmem_addr[0],
                    layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]),
                )

                # Load per-thread A → reg_in
                reg_in = Tx.alloc_local((per_thread_elems,), dtype)
                with Tx.thread():
                    for i in range(per_thread_elems):
                        reg_in[i] = A[tid_in_wg, i]
                Tx.cuda.cta_sync()

                # reg_in -> TMEM via .<shape>.x<rep>.st.unpack::16b
                frag_in = reg_in.view(frag_rows, K_cols_elem, layout=atom_view)
                Tx.copy_async(tmem[0:frag_rows, 0:K_cols_elem], frag_in[:, :])
                Tx.ptx.tcgen05.wait.st()
                Tx.cuda.cta_sync()

                # TMEM -> reg_out via .<shape>.x<rep>.ld.pack::16b
                reg_out = Tx.alloc_local((per_thread_elems,), dtype)
                frag_out = reg_out.view(frag_rows, K_cols_elem, layout=atom_view)
                Tx.copy_async(frag_out[:, :], tmem[0:frag_rows, 0:K_cols_elem])
                Tx.ptx.tcgen05.wait.ld()
                Tx.cuda.cta_sync()

                # reg_out -> B
                with Tx.thread():
                    for i in range(per_thread_elems):
                        B[tid_in_wg, i] = reg_out[i]

                # Warp 0 gives the TMEM columns back.
                if warp_id == 0:
                    with Tx.warp():
                        Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1)
                        Tx.ptx.tcgen05.dealloc(
                            tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1
                        )

    target = tvm.target.Target("cuda")
    with target:
        mod = tvm.IRModule({"main": kernel})
        mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
        A_np = tvm.testing.generate_random_array(dtype, (128, per_thread_elems))
        B_np = np.zeros((128, per_thread_elems), dtype=dtype)
        DEV = tvm.cuda(0)
        A = tvm.runtime.tensor(A_np, DEV)
        B = tvm.runtime.tensor(B_np, DEV)
        mod(A, B)
        # Round-trip should preserve every per-thread bit pattern.
        # Compare as uint16 so e.g. NaN payloads compare bit-exactly.
        A_view = A.numpy().view(np.uint16)
        B_view = B.numpy().view(np.uint16)
        np.testing.assert_array_equal(B_view, A_view)
+
+
+def _next_pow2(x: int) -> int:
+    if x <= 1:
+        return 1
+    return 1 << (x - 1).bit_length()
+
+
def _run_load_test(shape: str, rep: int, dtype: str):
    """Check the ``.<shape>.x<rep>`` tcgen05 load against the layout formula.

    Steps:
    - Stage ``A`` into TMEM via ``.32x32b`` stores, chunked when the fp32
      column count exceeds 128.
    - Read the fragment back via ``.<shape>.x<rep>`` into a register buffer
      viewed with ``tcgen05_atom_layout``. This is the layout the
      ``Tx.alloc_tcgen05_frag`` wrapper also goes through.
    - Dump each thread's registers to ``B``.
    - Compare ``B`` bit-exactly against the value derived on the host from
      the layout formula.

    Skips (via ``pytest.skip``) before any GPU work when the staging width is
    not a multiple of the vector length, or when bfloat16 verification is
    requested but ``ml_dtypes`` is unavailable.
    """
    bits = tvm.runtime.DataType(dtype).bits
    elem_per_32b = 32 // bits

    # Resolve the host-side 16-bit scalar type up front, before compiling or
    # launching anything. The bf16 type must not be referenced eagerly:
    # ``np.dtype("bfloat16")`` raises TypeError unless ml_dtypes has registered
    # it. numpy has no stable bfloat16 dtype across versions, so use ml_dtypes.
    dtype_np = None
    if bits != 32:
        if dtype == "float16":
            dtype_np = np.float16
        else:
            try:
                from ml_dtypes import bfloat16 as dtype_np  # noqa: PLC0415
            except ImportError:
                pytest.skip("bfloat16 verification needs ml_dtypes")

    # Per-warp fp32 col span × number of warps in one warpgroup covers the
    # fragment column footprint. The TMEM allocation is sized for the same
    # element-column count.
    K_cols_fp32 = _COL_FACTOR_FP32[shape] * rep
    K_cols_elem = K_cols_fp32 * elem_per_32b
    regs_per_thread = _REGS_FACTOR[shape] * rep  # 32-bit register count
    per_thread_elems = regs_per_thread * elem_per_32b
    frag_rows = _FRAG_ROWS[shape]

    # Allocation width: at least 32 columns, rounded up to a power of two.
    tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32))

    # Staging via .32x32b caps at num=128 (= 128 fp32 cols) per atom call. For
    # configs whose K_cols_fp32 exceeds 128 we split the stage into multiple
    # chunks of CHUNK_FP32 fp32 cols each.
    CHUNK_FP32 = 128
    num_chunks = tmem_col_width_32b // CHUNK_FP32 if tmem_col_width_32b > CHUNK_FP32 else 1
    chunk_width_32b = tmem_col_width_32b if num_chunks == 1 else CHUNK_FP32
    chunk_width_elem = chunk_width_32b * elem_per_32b
    stage_width_elem = tmem_col_width_32b * elem_per_32b

    # Vector length for global<->local copies (in elements).
    VEC_LEN = 128 // bits
    if stage_width_elem % VEC_LEN != 0:
        pytest.skip(f"stage_width_elem {stage_width_elem} % VEC_LEN {VEC_LEN} != 0")

    chunk_view = TileLayout(S[(128, chunk_width_elem) : (1 @ axis_tid_in_wg, 1)])
    # The factory + wrapper both go through ``tcgen05_atom_layout``; we use it
    # explicitly here so that ``frag_local`` has the canonical layout that
    # ``Tx.copy_async`` matches when dispatching to the right atom path.
    atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype)

    @Tx.prim_func
    def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None:
        # A is the host data we stage into TMEM via the standard .32x32b path.
        A = Tx.match_buffer(A_ptr, (128, stage_width_elem), dtype)
        # B is a per-thread register dump: B[tid_in_wg, reg_idx_in_elements].
        B = Tx.match_buffer(B_ptr, (128, per_thread_elems), dtype)

        A_flat = A.view(-1)

        Tx.device_entry()
        warp_id = Tx.warp_id([128 // 32])
        Tx.cta_id([2])
        wg_id = Tx.warpgroup_id([1])
        Tx.warp_id_in_wg([4])
        Tx.lane_id([32])
        tid_in_wg = Tx.thread_id([128])

        tmem_addr = Tx.alloc_shared([1], "uint32")

        if wg_id == 0:
            with Tx.warpgroup():
                if warp_id == 0:
                    with Tx.warp():
                        Tx.ptx.tcgen05.alloc(
                            Tx.address_of(tmem_addr),
                            n_cols=tmem_col_width_32b,
                            cta_group=1,
                        )

                Tx.tvm_storage_sync("shared")

                tmem = Tx.decl_buffer(
                    (128, stage_width_elem),
                    dtype,
                    scope="tmem",
                    allocated_addr=tmem_addr[0],
                    layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]),
                )

                # Per-thread chunk staging buffer (CHUNK_FP32 fp32 worth).
                stage_reg = Tx.alloc_local((chunk_width_elem,), dtype)
                stage_local = stage_reg.view(128, chunk_width_elem, layout=chunk_view)

                # Walk chunks: A[:, ck:ck+chunk] -> stage_reg -> TMEM[:, ck:ck+chunk]
                for chunk_idx in range(num_chunks):
                    col_off_elem = chunk_idx * chunk_width_elem
                    with Tx.thread():
                        for i in range(chunk_width_elem // VEC_LEN):
                            # Each thread's row offset in A_flat: stage_width_elem; within
                            # the row, this chunk starts at col_off_elem and each vector
                            # picks up VEC_LEN elements at slot i.
                            g_offset = Tx.meta_var(
                                tid_in_wg * stage_width_elem
                                + col_off_elem
                                + i * VEC_LEN
                            )
                            Tx.copy(
                                stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN],
                                A_flat[g_offset : g_offset + VEC_LEN],
                            )
                    Tx.cuda.cta_sync()
                    Tx.copy_async(
                        tmem[:, col_off_elem : col_off_elem + chunk_width_elem],
                        stage_local[:, :],
                    )
                Tx.ptx.tcgen05.wait.st()
                Tx.cuda.cta_sync()

                # TMEM[0:frag_rows, 0:K_cols] -> frag_local via .<shape>.x<rep>.ld.
                # Use ``tcgen05_atom_layout`` so dispatch matches the new path
                # (or stays on .32x32b for instr_shape="32x32b"). Keep the flat
                # ``frag_reg`` for the per-thread dump below.
                frag_reg = Tx.alloc_local((per_thread_elems,), dtype)
                frag_local = frag_reg.view(frag_rows, K_cols_elem, layout=atom_view)
                Tx.copy_async(frag_local[:, :], tmem[0:frag_rows, 0:K_cols_elem])
                Tx.ptx.tcgen05.wait.ld()
                Tx.cuda.cta_sync()

                # Dump per-thread regs to B[tid_in_wg, :]
                with Tx.thread():
                    for i in range(per_thread_elems):
                        B[tid_in_wg, i] = frag_reg[i]

                if warp_id == 0:
                    with Tx.warp():
                        Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1)
                        Tx.ptx.tcgen05.dealloc(
                            tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1
                        )

    target = tvm.target.Target("cuda")
    with target:
        mod = tvm.IRModule({"main": kernel})
        mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
        A_np = tvm.testing.generate_random_array(dtype, (128, stage_width_elem))
        B_np = np.zeros((128, per_thread_elems), dtype=dtype)
        DEV = tvm.cuda(0)
        A = tvm.runtime.tensor(A_np, DEV)
        B = tvm.runtime.tensor(B_np, DEV)
        mod(A, B)
        B_out = B.numpy()

    # Build expected B_out from the layout.
    if bits == 32:
        # Each register slot in B[t, r] holds a single fp32; compare bit-exactly.
        B_expected = np.zeros((128, per_thread_elems), dtype=np.uint32)
        for t in range(128):
            for r in range(regs_per_thread):
                B_expected[t, r] = _expected_reg_value_fp32(A_np, shape, rep, 0, t, r)
        B_view = B_out.view(np.uint32)
        np.testing.assert_array_equal(B_view, B_expected)
    else:
        # B[t, :] holds per_thread_elems 16-bit values; each 32-bit register packs
        # two of them in (low, high) order. Compare bit-exactly via uint32 view.
        B_view = B_out.view(np.uint32).reshape(128, regs_per_thread)
        B_expected = np.zeros((128, regs_per_thread), dtype=np.uint32)
        for t in range(128):
            for r in range(regs_per_thread):
                B_expected[t, r] = _expected_reg_value_16b(
                    A_np, shape, rep, 0, t, r, dtype_np
                )
        np.testing.assert_array_equal(B_view, B_expected)
+
+
+# --------------------------------------------------------------------------
+# Test 2: store direction (mirror of test 1, with .st instead of .ld)
+# --------------------------------------------------------------------------
+
+
@pytest.mark.parametrize("shape", list(_SHAPE_REPS))
@pytest.mark.parametrize("rep", [1, 4, 16])
@pytest.mark.parametrize("dtype", ["float32"])
def test_tcgen05_st_16xnb_store(shape, rep, dtype):
    """Round-trip test for the ``.<shape>.x<rep>.st`` store atom.

    Writes the fragment via ``.<shape>.x<rep>.st``, then reads all of TMEM back
    via the standard ``.32x32b`` path. It verifies that the host-known fragment
    data ends up at the expected TMEM lane positions.

    Only fp32 is parametrized here. The 16-bit case has a different staging
    convention: pack::16b reads/writes the LOW halves of adjacent cells, not
    the low/high halves of one cell. It is covered by
    ``test_tcgen05_16xnb_roundtrip_16b`` via a self-consistent
    .16x*b.st → .16x*b.ld loop. The 16-bit branch kept below is only bit-exact
    for float16; other non-fp32 dtypes are skipped before any GPU work.
    """
    if rep not in _SHAPE_REPS[shape]:
        pytest.skip(f"rep {rep} not valid for {shape}")
    bits = tvm.runtime.DataType(dtype).bits
    # Decide up front (rather than inside the verification loop, after the
    # kernel has already been compiled and run) whether this dtype is checkable.
    if bits != 32 and dtype != "float16":
        pytest.skip("16b store check restricted to float16")
    elem_per_32b = 32 // bits
    K_cols_fp32 = _COL_FACTOR_FP32[shape] * rep
    K_cols_elem = K_cols_fp32 * elem_per_32b
    regs_per_thread = _REGS_FACTOR[shape] * rep
    per_thread_elems = regs_per_thread * elem_per_32b
    frag_rows = _FRAG_ROWS[shape]

    tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32))
    if tmem_col_width_32b > 128:
        pytest.skip(f"tmem_col_width_32b {tmem_col_width_32b} > 128 not supported by .32x32b staging")  # noqa: E501
    stage_width_elem = tmem_col_width_32b * elem_per_32b
    VEC_LEN = 128 // bits
    if stage_width_elem % VEC_LEN != 0:
        pytest.skip(f"stage_width_elem {stage_width_elem} % VEC_LEN {VEC_LEN} != 0")

    # Row-major (128, stage_width_elem) global layout split into VEC_LEN vectors;
    # used below to compute each vector's flat offset in B.
    g_layout = TileLayout(
        S[(128, stage_width_elem // VEC_LEN, VEC_LEN) : (stage_width_elem, VEC_LEN, 1)]
    )
    stage_view = TileLayout(S[(128, stage_width_elem) : (1 @ axis_tid_in_wg, 1)])
    atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype)

    @Tx.prim_func
    def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None:
        # A[tid_in_wg, i] is the i-th per-thread element to feed into the atom store.
        A = Tx.match_buffer(A_ptr, (128, per_thread_elems), dtype)
        # B[lane, col] is the TMEM-staged readout after the round-trip.
        B = Tx.match_buffer(B_ptr, (128, stage_width_elem), dtype)
        B_flat = B.view(-1)

        Tx.device_entry()
        warp_id = Tx.warp_id([128 // 32])
        Tx.cta_id([2])
        wg_id = Tx.warpgroup_id([1])
        Tx.warp_id_in_wg([4])
        Tx.lane_id([32])
        tid_in_wg = Tx.thread_id([128])

        tmem_addr = Tx.alloc_shared([1], "uint32")

        if wg_id == 0:
            with Tx.warpgroup():
                if warp_id == 0:
                    with Tx.warp():
                        Tx.ptx.tcgen05.alloc(
                            Tx.address_of(tmem_addr),
                            n_cols=tmem_col_width_32b,
                            cta_group=1,
                        )

                Tx.tvm_storage_sync("shared")

                tmem = Tx.decl_buffer(
                    (128, stage_width_elem),
                    dtype,
                    scope="tmem",
                    allocated_addr=tmem_addr[0],
                    layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]),
                )

                # Load per-thread A → frag_reg
                frag_reg = Tx.alloc_local((per_thread_elems,), dtype)
                with Tx.thread():
                    for i in range(per_thread_elems):
                        frag_reg[i] = A[tid_in_wg, i]
                Tx.cuda.cta_sync()

                # frag_local -> TMEM via .<shape>.x<rep>.st
                frag_local = frag_reg.view(frag_rows, K_cols_elem, layout=atom_view)
                Tx.copy_async(tmem[0:frag_rows, 0:K_cols_elem], frag_local[:, :])
                Tx.ptx.tcgen05.wait.st()
                Tx.cuda.cta_sync()

                # TMEM -> readout via .32x32b.ld
                stage_reg = Tx.alloc_local((stage_width_elem,), dtype)
                stage_local = stage_reg.view(128, stage_width_elem, layout=stage_view)
                Tx.copy_async(stage_local[:, :], tmem[:, :])
                Tx.ptx.tcgen05.wait.ld()
                Tx.cuda.cta_sync()

                # readout -> B (full 128×stage_width_elem dump)
                with Tx.thread():
                    for i in range(stage_width_elem // VEC_LEN):
                        g_offset = Tx.meta_var(g_layout.apply(tid_in_wg, i, 0)["m"])
                        Tx.copy(
                            B_flat[g_offset : g_offset + VEC_LEN],
                            stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN],
                        )

                if warp_id == 0:
                    with Tx.warp():
                        Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1)
                        Tx.ptx.tcgen05.dealloc(
                            tmem_addr[0], n_cols=tmem_col_width_32b, cta_group=1
                        )

    target = tvm.target.Target("cuda")
    with target:
        mod = tvm.IRModule({"main": kernel})
        mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
        A_np = tvm.testing.generate_random_array(dtype, (128, per_thread_elems))
        B_np = np.zeros((128, stage_width_elem), dtype=dtype)
        DEV = tvm.cuda(0)
        A = tvm.runtime.tensor(A_np, DEV)
        B = tvm.runtime.tensor(B_np, DEV)
        mod(A, B)
        B_out = B.numpy()

    # Only fragment positions are checked. The rows the fragment does not write
    # hold whatever TMEM contained after allocation, which may be arbitrary, and
    # the .32x32b.ld readout copies them into B regardless of B_np's zeros.
    if bits == 32:
        view = B_out.view(np.uint32)
        for t in range(128):
            for r in range(regs_per_thread):
                row, col = _decompose_fp32(shape, t, r)
                tmem_lane = _frag_row_to_tmem_lane(shape, row)
                expected = np.float32(A_np[t, r]).view(np.uint32)
                assert view[tmem_lane, col] == expected, (
                    f"{shape}.x{rep} {dtype}: thread {t} reg {r} → "
                    f"(row={row}, col={col}) tmem_lane={tmem_lane} got "
                    f"{view[tmem_lane, col]:#x} want {expected:#x}"
                )
    else:
        # 16-bit (float16 only, guarded above): each 32-bit reg packs two 16-bit
        # elements at adjacent TMEM cols.
        view = B_out.view(np.uint16)
        for t in range(128):
            for r in range(regs_per_thread):
                row, col_fp32 = _decompose_fp32(shape, t, r)
                tmem_lane = _frag_row_to_tmem_lane(shape, row)
                lo = np.float16(A_np[t, 2 * r]).view(np.uint16)
                hi = np.float16(A_np[t, 2 * r + 1]).view(np.uint16)
                assert view[tmem_lane, 2 * col_fp32] == lo, (
                    f"{shape}.x{rep} {dtype}: t={t} r={r} lo "
                    f"({tmem_lane=}, {col_fp32=}) got {view[tmem_lane, 2 * col_fp32]:#x} "
                    f"want {lo:#x}"
                )
                assert view[tmem_lane, 2 * col_fp32 + 1] == hi, (
                    f"{shape}.x{rep} {dtype}: t={t} r={r} hi "
                    f"({tmem_lane=}, {col_fp32=}) got {view[tmem_lane, 2 * col_fp32 + 1]:#x} "
                    f"want {hi:#x}"
                )
+
+
+# --------------------------------------------------------------------------
+# Wrapper test: exercise Tx.alloc_tcgen05_frag directly (compile-only smoke).
+# --------------------------------------------------------------------------
+
+
@pytest.mark.parametrize(
    "shape, frag_rows, K_cols",
    [
        ("32x32b", 128, 32),  # .32x32b.x32 fp32: simple thread-rows layout
        ("32x32b", 128, 64),  # .32x32b.x64 fp32
        ("16x64b", 64, 64),   # .16x64b.x32 fp32
        ("16x128b", 64, 64),  # .16x128b.x16 fp32
        ("16x256b", 64, 64),  # .16x256b.x8 fp32
    ],
)
def test_alloc_tcgen05_frag_wrapper_compiles(shape, frag_rows, K_cols):
    """Ensure Tx.alloc_tcgen05_frag yields a buffer that ``Tx.copy_async`` accepts
    and lowers to the correct tcgen05 atom for each supported instr_shape.

    Compile-only: the kernel is built but never launched, so no data is
    checked. The final assertion is a substring search for ``shape`` in the
    generated device source. It confirms the shape string appears somewhere in
    that source, not the exact ``.x<N>`` repetition count.
    """

    @Tx.prim_func
    def kernel(A_ptr: Tx.handle) -> None:
        # A is declared only so the kernel has a parameter; it is never read.
        Tx.match_buffer(A_ptr, (128, K_cols), "float32")
        Tx.device_entry()
        warp_id = Tx.warp_id([4])
        Tx.cta_id([2])
        wg_id = Tx.warpgroup_id([1])
        Tx.warp_id_in_wg([4])
        Tx.lane_id([32])
        Tx.thread_id([128])

        tmem_addr = Tx.alloc_shared([1], "uint32")
        if wg_id == 0:
            with Tx.warpgroup():
                if warp_id == 0:
                    with Tx.warp():
                        # Allocate at least 32 TMEM columns.
                        Tx.ptx.tcgen05.alloc(
                            Tx.address_of(tmem_addr), n_cols=max(32, K_cols), cta_group=1
                        )
                Tx.tvm_storage_sync("shared")
                tmem = Tx.decl_buffer(
                    (128, K_cols),
                    "float32",
                    scope="tmem",
                    allocated_addr=tmem_addr[0],
                    layout=TileLayout(S[(128, K_cols) : (1 @ TLane, 1 @ TCol)]),
                )
                # One-liner: wrapper handles per-thread storage + layout.
                frag = Tx.alloc_tcgen05_frag(shape, (frag_rows, K_cols), "float32")
                Tx.copy_async(frag[:, :], tmem[0:frag_rows, 0:K_cols])
                Tx.ptx.tcgen05.wait.ld()
                if warp_id == 0:
                    with Tx.warp():
                        Tx.ptx.tcgen05.relinquish_alloc_permit(cta_group=1)
                        Tx.ptx.tcgen05.dealloc(
                            tmem_addr[0], n_cols=max(32, K_cols), cta_group=1
                        )

    target = tvm.target.Target("cuda")
    with target:
        mod = tvm.IRModule({"main": kernel})
        mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
    # Compiles cleanly + the generated CUDA contains the expected PTX shape.
    src = mod.mod.imports[0].inspect_source()
    assert shape in src, (
        f"expected .{shape}.x? in generated PTX, but `{shape}` not found in CUDA source"
    )
+
+
# Allow running this file directly as a script; delegates to tvm.testing.main().
if __name__ == "__main__":
    tvm.testing.main()

Reply via email to