This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit ce87b82cc48de16c72a60737a1d92d174c2a1989
Author: Hongyi Jin <[email protected]>
AuthorDate: Wed May 27 22:36:21 2026 -0400

    feat(tirx): add M=128 dispatch + layout for .16x*b tcgen05.ld/st (#646)
    
    * feat(tirx): extend .16x*b tcgen05.ld/st dispatch to M=128 fragments
    
    The .16x*b atom natively covers M=64 (each warp handles 16 rows = the first
    half of its 32-lane TMEM partition). To cover all 128 TMEM rows of a
    warpgroup, callers previously had to either issue two raw
    ``Tx.ptx.tcgen05.ld(... shape="16x*b")`` calls themselves (with the second
    addr OR'd with 0x100000 = row+16) or alias the TMEM buffer with a shifted
    ``allocated_addr`` and call ``Tx.copy_async`` twice. Both leak the
    half-slab abstraction.
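
For reference, the ``0x100000`` constant follows from the PTX TMEM address format, which packs the lane (row) into bits 31..16 and the column into bits 15..0. Below is a minimal sketch of that arithmetic. ``tmem_addr`` is a hypothetical helper for illustration, not a TVM or PTX API.

```python
# Hypothetical helper: pack a (lane, column) pair into a 32-bit TMEM address,
# assuming the PTX layout of lane in bits 31..16 and column in bits 15..0.
def tmem_addr(lane: int, col: int) -> int:
    assert 0 <= lane < 128, "TMEM has 128 lanes"
    assert 0 <= col < 512, "TMEM has 512 32-bit columns"
    return (lane << 16) | col


# Advancing 16 lanes means adding 16 to the lane field.
ROW_PLUS_16 = 16 << 16
assert ROW_PLUS_16 == 0x100000

# Warp 1 owns lanes 32..63. Its low half-slab starts at lane 32, where lane
# bit 4 is clear, so OR-ing in ROW_PLUS_16 lands on lane 48 (the high half).
low_slab = tmem_addr(lane=32, col=64)
high_slab = low_slab | ROW_PLUS_16
```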
    
    This change adds M=128 support directly to the layout factory and the
    copy_async dispatch:
    
    - ``tcgen05_atom_layout(shape, (128, K), dtype)`` accepts ``rows=128`` for
      the .16x*b shapes. The factory inserts a ``v_slab`` ``m`` iter at the
      next free m-bit (stride = M=64 per-thread reg count) and doubles
      wid_in_wg's row stride from 16 to 32. The result is a 128-row fragment
      whose per-thread register vector has the low-slab regs in [0, M64_regs)
      and the high-slab regs in [M64_regs, 2*M64_regs).
    
    - ``_emit_16xnb_path`` detects ``local_buf.shape[0] ∈ {64, 128}`` via the
      match helper and emits one PTX issue per 16-row slab (``row=0`` for the
      low slab, ``row=16`` for the high slab), splitting the per-thread reg
      vector contiguously between the two calls.
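
The register split described above can be modeled in a few lines of plain Python. This sketch assumes the per-thread vector is a flat list whose first ``m64_regs`` entries belong to the low slab. ``split_slab_regs`` is a hypothetical name, not the dispatch's actual code.

```python
def split_slab_regs(regs, m64_regs):
    """Split an M=128 per-thread register vector into one (row_offset, regs)
    pair per PTX issue: the low slab uses row=0, the high slab row=16."""
    assert len(regs) == 2 * m64_regs, "M=128 doubles the M=64 register count"
    return [
        (0, regs[:m64_regs]),   # lanes 0..15 of the warp's 32-lane partition
        (16, regs[m64_regs:]),  # lanes 16..31 of the same partition
    ]


# .16x128b.x2 needs 2N = 4 registers per slab, so 8 registers at M=128.
issues = split_slab_regs(list(range(8)), m64_regs=4)
```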
    
    Bit-exact tests: the existing 92-case M=64 sweep continues to pass; an
    additional 18-case M=128 round-trip sweep (3 .16x*b shapes × 3 reps × 2
    16-bit dtypes) verifies the new dispatch path.
    
    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
    
    * feat(tirx): plumb tcgen05 datapath through tmem_pool.alloc + dispatch
    
    The ``.16x*b`` dispatch's TMEM-side layout check was hard-coded to the
    identity Layout D (M=128, ``(128, K) : (1@TLane, 1@TCol)``). The PTX
    ``tcgen05.ld.16x*b`` with ``row=0`` actually accesses the *scattered*
    lane set ``{0..15, 32..47, 64..79, 96..111}`` (Layout F per PTX ISA
    §9.7.16.10.5), so passing the dispatch a (128, K) Layout D buffer with
    slice ``[0:64, ...]`` silently returns scattered data when the caller's
    mental model is "rows 0..63 contiguous". The leak only worked because
    all in-tree kernels happen to use M=128 MMA + .16x*b M=128 readback,
    where the scatter doesn't bite.
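
To make the scatter concrete, the sketch below enumerates the lanes that one warpgroup-wide ``.16x*b`` issue touches. It assumes, as described above, that warp ``w`` of the warpgroup is restricted to lanes ``[32*w, 32*w + 32)``. ``lanes_16xnb`` is a hypothetical helper.

```python
def lanes_16xnb(row_offset=0):
    """Physical TMEM lanes touched by one .16x*b issue across the 4 warps."""
    # Each warp covers 16 consecutive lanes of its own 32-lane partition,
    # starting at row_offset (0 for the low slab, 16 for the high slab).
    return [32 * w + row_offset + i for w in range(4) for i in range(16)]


low = lanes_16xnb(0)
high = lanes_16xnb(16)
# "Rows 0..63 contiguous" would be list(range(64)); row=0 reads something else.
```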
    
    This change makes the datapath part of the TMEM buffer's TileLayout:
    
    - ``tmem_datapath_layout(datapath, rows, cols)`` returns the TileLayout
      for Layout D (M=128 identity) and Layout F (M=64 scatter via two row
      iters with TLane strides 1 and 32). Other PTX layouts (A, B, C, E, G)
      are reserved for future expansion.
    
    - ``tmem_pool.alloc`` accepts a ``datapath=`` kwarg that derives the
      buffer's layout from ``tmem_datapath_layout``. ``layout=`` and
      ``datapath=`` are mutually exclusive; omitting both falls back to the
      permissive Layout D default.
    
    - The ``.16x*b`` and ``.32x32b`` dispatches now structurally classify
      the TMEM buffer (``_classify_tmem_datapath``) and reject mismatched
      pairings via ``_check_tmem_layout_for_atom`` against the table:
        - D × .32x32b: ✓
        - D × .16x*b M=64: ✓ (legacy permissive — half-slab read)
        - D × .16x*b M=128: ✓ (two PTX issues)
        - F × .16x*b M=64: ✓ (canonical pairing)
        - F × .16x*b M=128: ✗
        - F × .32x32b: ✗
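
Here is a self-contained model of the pairing check in the table above. The tuple keys mirror the table rows. ``check_pairing`` is a hypothetical stand-in for ``_check_tmem_layout_for_atom``. It treats an unrecognized layout as "no opinion", which follows the behavior described for that helper.

```python
# (datapath, atom_kind, frag_rows) -> accepted? Mirrors the table above.
COMPAT = {
    ("D", "32x32b", 128): True,
    ("D", "16x*b", 64): True,    # legacy permissive half-slab read
    ("D", "16x*b", 128): True,   # two PTX issues (row=0, row=16)
    ("F", "16x*b", 64): True,    # canonical pairing
    ("F", "16x*b", 128): False,
    ("F", "32x32b", 128): False,
}


def check_pairing(datapath, atom_kind, frag_rows):
    """Raise ValueError for a rejected pairing; return the datapath otherwise."""
    if datapath is None:
        # Unrecognized TMEM layout: leave it to the dispatch's other checks.
        return None
    if not COMPAT.get((datapath, atom_kind, frag_rows), False):
        raise ValueError(
            f"datapath={datapath!r} is incompatible with "
            f"atom={atom_kind!r} (frag_rows={frag_rows})"
        )
    return datapath
```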
    
    Tests: 304 passed (was 266) including 36 new Layout F round-trips and
    two negative tests for F × .16x*b M=128 and F × .32x32b. The existing
    M=64 sweep against Layout D buffers remains permissive (no test change
    needed) since that's the legacy contract.
    
    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
    
    * fix(tirx): correct Layout F iter ordering for SplitCoord semantics
    
    The first version of ``tmem_datapath_layout("F", ...)`` encoded the
    scattered M=64 mapping as ``S[(16, 4, cols) : (1@TLane, 32@TLane, 1@TCol)]``,
    under the mistaken assumption (taken from the comments in
    ``tcgen05_atom_layout``) that the *first* iter takes the *lowest* row
    bits. ``SplitCoord`` (src/tirx/ir/layout/utils.cc) actually walks the
    shape from the last dim back to the first, so for shape ``(s0, s1)`` the
    first iter receives ``coord // s1`` (the *high* bits) and the second
    receives ``coord % s1`` (the low bits) — row-major flattening.
    
    With the wrong ordering, ``apply(r, c)`` decomposed ``r`` as
    ``(r // 4, r % 4)`` and mapped to ``TLane = (r // 4) * 1 + (r % 4) * 32``,
    which doesn't match the canonical scatter
    ``(r // 16) * 32 + (r % 16)`` that ``.16x*b`` M=64 PTX accesses. The
    round-trip test didn't catch it because both write and read use the
    same factory output, so the dispatch's structural compatibility check
    just sees "two layouts that match each other" — the PTX behavior is
    fixed by hardware and the layout label is essentially decorative for
    this dispatch path.
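
The difference between the two orderings is easy to reproduce with a pure-Python model of the row-major split. ``split_coord`` and ``apply_rows`` are hypothetical helpers that follow the description above: they walk the shape from the last dim back to the first. They are not bindings to the C++ ``SplitCoord``.

```python
def split_coord(coord, shape):
    """Row-major split: the first extent receives the high bits."""
    parts = []
    for extent in reversed(shape):
        parts.append(coord % extent)
        coord //= extent
    return parts[::-1]


def apply_rows(r, shape, tlane_strides):
    """Map logical row r to a TLane offset through (extent, stride) iters."""
    return sum(p * s for p, s in zip(split_coord(r, shape), tlane_strides))


# First version: S[(16, 4, ...) : (1@TLane, 32@TLane, ...)]
wrong = [apply_rows(r, (16, 4), (1, 32)) for r in range(64)]
# Fixed version: S[(4, 16, ...) : (32@TLane, 1@TLane, ...)]
right = [apply_rows(r, (4, 16), (32, 1)) for r in range(64)]
```

With the first ordering, logical row 1 lands on lane 32. With the fixed ordering, it lands on lane 1.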
    
    This change:
    - Re-orders the iter list to ``S[(4, 16, cols) : (32@TLane, 1@TLane, 1@TCol)]``
      so that ``SplitCoord(r, (4, 16))`` returns ``[r // 16, r % 16]`` and
      the warp selector (iter 0) gets the TLane stride 32 while the
      within-slab lane (iter 1) gets TLane stride 1. ``layout.apply(r, c)``
      now returns ``TLane = (r // 16) * 32 + (r % 16)``, matching the
      ``_frag_row_to_tmem_lane`` formula used by the existing
      ``_run_load_test`` host-side validation.
    
    - Updates the factory's docstring with the row-major SplitCoord
      ordering so future readers don't have to re-derive it from the C++.
    
    - Adds two unit tests
      (``test_tmem_datapath_layout_F_row_to_lane_mapping``,
      ``test_tmem_datapath_layout_D_row_to_lane_mapping``) that call
      ``layout.apply(r, c, shape=...)`` directly and assert the
      ``(row, col) → (TLane, TCol)`` mapping bit-exactly across every M=64
      logical row. These are the right discriminator between the two
      possible iter orderings.
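
The two new unit tests amount to a bit-exact sweep over every logical ``(row, col)``. Below is a self-contained sketch of that check. It assumes the Layout D / F iters can be modeled as separate row and column groups of ``(extent, stride, axis)`` triples. ``tmem_apply``, ``frag_row_to_tmem_lane`` and ``COLS`` are hypothetical stand-ins, not the TVM ``layout.apply`` API.

```python
COLS = 32  # arbitrary logical column count for the sweep


def split_coord(coord, extents):
    """Row-major split (first extent gets the high bits)."""
    parts = []
    for extent in reversed(extents):
        parts.append(coord % extent)
        coord //= extent
    return parts[::-1]


def tmem_apply(row, col, row_iters, col_iters):
    """Map a logical (row, col) to (TLane, TCol) via (extent, stride, axis) iters."""
    out = {"TLane": 0, "TCol": 0}
    for coord, iters in ((row, row_iters), (col, col_iters)):
        parts = split_coord(coord, [ext for ext, _, _ in iters])
        for part, (_, stride, axis) in zip(parts, iters):
            out[axis] += part * stride
    return out["TLane"], out["TCol"]


LAYOUT_D = ([(128, 1, "TLane")], [(COLS, 1, "TCol")])
LAYOUT_F = ([(4, 32, "TLane"), (16, 1, "TLane")], [(COLS, 1, "TCol")])


def frag_row_to_tmem_lane(r):
    """Canonical .16x*b M=64 scatter: warp selector r // 16, in-slab lane r % 16."""
    return (r // 16) * 32 + (r % 16)


def sweep():
    for r in range(64):
        for c in range(COLS):
            assert tmem_apply(r, c, *LAYOUT_F) == (frag_row_to_tmem_lane(r), c)
    for r in range(128):
        for c in range(COLS):
            assert tmem_apply(r, c, *LAYOUT_D) == (r, c)
    return True
```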
    
    Tests: 132 passed (was 130 — +2 layout unit tests).
    
    Co-Authored-By: Claude Opus 4.7 (1M context) <[email protected]>
    
    * fix(tirx): align tcgen05_atom_layout with SplitCoord semantics
    
    The per-shape ``row_iters_fp32`` and ``col_iters_fp32`` lists in
    ``tcgen05_atom_layout`` were written low-to-high to match the natural
    mathematical decomposition (e.g. ``.16x256b``'s
    ``row = t1 + 8*va + 16*wid_in_wg``: t1 is the LOW iter in the list, wid is
    the HIGH). ``TileLayout`` decomposes a flat coord via ``SplitCoord``
    (src/tirx/ir/layout/utils.cc), which uses row-major ordering — the FIRST
    iter receives ``coord // (product of remaining extents)`` (i.e. the *high*
    bits), and the LAST iter receives ``coord % extent`` (the *low* bits).
    
    As a result, the produced TileLayout's ``apply(row, col)`` returned axis
    values that *disagreed* with PTX. e.g. for ``.16x256b`` M=64 N=1, the
    factory mapped ``(row=1, col=0)`` to ``(laneid=0, m=0, wid=1)`` — but PTX
    puts that frag element at ``(laneid=4, m=0, wid=0)`` (thread 4 reg 0, per
    ``_decompose_fp32``). The round-trip tests didn't catch it because the
    dispatch ignores the layout label and emits raw PTX; the layout was
    essentially decorative.
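
The ``(row=1, col=0)`` counterexample can be reproduced from the ``.16x256b`` formulas quoted in the factory comments (``t0 = laneid & 3``, ``t1 = laneid >> 2``, ``reg = v0p + 2*va + 4*vb``). In this sketch, ``frag_coord_16x256b`` is a hypothetical restatement of those formulas, not ``_decompose_fp32`` itself. The old row iters are modeled with the same row-major split as ``SplitCoord``.

```python
def frag_coord_16x256b(laneid, wid, reg):
    """PTX-side (row, col) in the fp32 view for a given (laneid, wid_in_wg, reg)."""
    t0, t1 = laneid & 3, laneid >> 2
    v0p, va, vb = reg & 1, (reg >> 1) & 1, reg >> 2
    return t1 + 8 * va + 16 * wid, v0p + 2 * t0 + 8 * vb


def owner(row, col, n=1):
    """Brute-force inverse: which (laneid, wid, reg) holds frag element (row, col)."""
    for laneid in range(32):
        for wid in range(4):
            for reg in range(4 * n):  # .16x256b.xN has 4N regs per thread
                if frag_coord_16x256b(laneid, wid, reg) == (row, col):
                    return laneid, wid, reg
    return None


def split_coord(coord, extents):
    """Row-major split (first extent gets the high bits)."""
    parts = []
    for extent in reversed(extents):
        parts.append(coord % extent)
        coord //= extent
    return parts[::-1]


# Old low-to-high row iters: (8, laneid), (2, m), (4, wid). Row 1 splits to
# [0, 0, 1], so the last iter (wid) absorbs the low bit.
old_row_split = split_coord(1, (8, 2, 4))
```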
    
    This change:
    - Rewrites the per-shape ``row_iters_fp32`` and ``col_iters_fp32``
      lists high-to-low (each dim individually), so the FIRST iter in each
      dim receives the HIGH bits per ``SplitCoord``; a comment above the
      lists records the convention. The C_pack iter for 16-bit dtypes is
      now appended at the LOW end of the col axis (last iter), matching its
      per-thread-pack-bit semantics.
    
    - Adds ``test_tcgen05_atom_layout_apply_matches_decompose_fp32`` — a
      9-case sweep (3 ``.16x*b`` shapes × 3 reps) that walks every
      (thread, reg) combination, derives the expected (row, col) via
      ``_decompose_fp32``, and asserts ``layout.apply(row, col)`` returns
      the matching (laneid, wid_in_wg, m) triple. This pins down the layout
      so it can't silently drift again — round-trip alone wouldn't catch a
      reversal bug.
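
Below is a standalone sketch of what the new sweep checks, restricted to the fp32 view (no 16-bit C_pack iter) and M=64. The iter lists are the high-to-low ``(extent, stride, axis)`` triples from the factory. ``apply_iters``, ``atom`` and ``sweep`` are hypothetical helpers, and the rep values used below are arbitrary. The forward formulas are taken from the per-shape comments in ``tcgen05_atom_layout``, not from ``_decompose_fp32`` itself.

```python
def split_coord(coord, extents):
    """Row-major split (first extent gets the high bits)."""
    parts = []
    for ext in reversed(extents):
        parts.append(coord % ext)
        coord //= ext
    return parts[::-1]


def apply_iters(row, col, row_iters, col_iters):
    """Accumulate (laneid, wid, m) from a logical (row, col)."""
    vals = {"lane": 0, "wid": 0, "m": 0}
    for coord, iters in ((row, row_iters), (col, col_iters)):
        parts = split_coord(coord, [ext for ext, _, _ in iters])
        for part, (_, stride, axis) in zip(parts, iters):
            vals[axis] += part * stride
    return vals["lane"], vals["wid"], vals["m"]


def atom(shape, n):
    """Return (row_iters, col_iters, regs_per_thread, forward) for the fp32 view."""
    if shape == "16x64b":
        rows = [(4, 1, "wid"), (2, 1, "lane"), (8, 4, "lane")]
        cols = [(n, 1, "m"), (2, 2, "lane")]

        def fwd(lane, wid, reg):
            t0, t1, t2 = lane & 1, (lane >> 1) & 1, lane >> 2
            return t2 + 8 * t0 + 16 * wid, t1 + 2 * reg

        return rows, cols, n, fwd
    if shape == "16x128b":
        rows = [(4, 1, "wid"), (2, 1, "m"), (8, 4, "lane")]
        cols = [(n, 2, "m"), (4, 1, "lane")]

        def fwd(lane, wid, reg):
            t0, t1, ra, rb = lane & 3, lane >> 2, reg & 1, reg >> 1
            return t1 + 8 * ra + 16 * wid, t0 + 4 * rb

        return rows, cols, 2 * n, fwd
    rows = [(4, 1, "wid"), (2, 2, "m"), (8, 4, "lane")]  # 16x256b
    cols = [(n, 4, "m"), (4, 1, "lane"), (2, 1, "m")]

    def fwd(lane, wid, reg):
        t0, t1 = lane & 3, lane >> 2
        v0p, va, vb = reg & 1, (reg >> 1) & 1, reg >> 2
        return t1 + 8 * va + 16 * wid, v0p + 2 * t0 + 8 * vb

    return rows, cols, 4 * n, fwd


def sweep(shape, n):
    """Walk every (thread, reg) and check the iters invert the PTX mapping."""
    rows, cols, regs, fwd = atom(shape, n)
    for lane in range(32):
        for wid in range(4):
            for reg in range(regs):
                row, col = fwd(lane, wid, reg)
                got = apply_iters(row, col, rows, cols)
                assert got == (lane, wid, reg), (shape, n, lane, wid, reg)
    return True
```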
    
    Tests: 315 passed in copy_async/ suite (was 304 — +9 atom-layout apply
    sweep, +2 datapath layout apply unit tests added in the prior commit).
    Existing M=64 / M=128 / Layout F bit-exact round-trips unchanged.
---
 3rdparty/cutlass_fpA_intB_gemm                     |   2 +-
 3rdparty/tvm-ffi                                   |   2 +-
 python/tvm/tirx/lang/alloc_pool.py                 |  30 ++-
 python/tvm/tirx/layout.py                          | 165 ++++++++++++++--
 .../tile_primitive/cuda/copy_async/tcgen05_ldst.py | 211 +++++++++++++++++----
 .../cuda/copy_async/test_tmem_16xnb.py             | 197 ++++++++++++++++++-
 6 files changed, 541 insertions(+), 66 deletions(-)

diff --git a/3rdparty/cutlass_fpA_intB_gemm b/3rdparty/cutlass_fpA_intB_gemm
index 953121f189..72b9883c98 160000
--- a/3rdparty/cutlass_fpA_intB_gemm
+++ b/3rdparty/cutlass_fpA_intB_gemm
@@ -1 +1 @@
-Subproject commit 953121f18946cedf88c2ccb6439944956ad495a8
+Subproject commit 72b9883c986a2ff427ca61ac0b14ad59be1dc862
diff --git a/3rdparty/tvm-ffi b/3rdparty/tvm-ffi
index 3c35034fd1..1fed0ae042 160000
--- a/3rdparty/tvm-ffi
+++ b/3rdparty/tvm-ffi
@@ -1 +1 @@
-Subproject commit 3c35034fd1026011736e19a4e0e1ed0f22058c42
+Subproject commit 1fed0ae0421e614d45662e8ee6bcae353d3ab2ea
diff --git a/python/tvm/tirx/lang/alloc_pool.py b/python/tvm/tirx/lang/alloc_pool.py
index 3a9ae82b30..e7996ff432 100644
--- a/python/tvm/tirx/lang/alloc_pool.py
+++ b/python/tvm/tirx/lang/alloc_pool.py
@@ -300,7 +300,35 @@ class TMEMPool:
         )
         return total_bits // (32 * rows)
 
-    def alloc(self, shape, dtype="float32", *, layout=None, cols=None):
+    def alloc(self, shape, dtype="float32", *, layout=None, cols=None, datapath=None):
+        """Allocate a TMEM buffer.
+
+        Parameters
+        ----------
+        shape, dtype, cols
+            Standard buffer shape / dtype / column count.
+        layout
+            Explicit ``TileLayout``. Mutually exclusive with ``datapath``.
+        datapath : str | None
+            Optional tcgen05 datapath letter (``"D"`` for M=128 full datapath,
+            ``"F"`` for M=64 non-``.ws`` scattered). When provided, the buffer's
+            layout is derived from ``tmem_datapath_layout(datapath, *shape)``
+            so the row index reflects the *physical* TMEM lane occupation
+            (PTX ISA §9.7.16.10.5). The downstream ``.16x*b`` / ``.32x32b``
+            dispatches structurally check this layout to catch mismatched
+            atoms (e.g. a ``.16x*b`` M=128 read against a Layout F buffer).
+            Defaults to ``None``, which means Layout D's identity row→lane
+            mapping — keep this for shape ``(128, X)`` buffers that hold
+            an M=128 MMA accumulator.
+        """
+        from tvm.tirx.layout import tmem_datapath_layout
+
+        if layout is not None and datapath is not None:
+            raise ValueError("TMEMPool.alloc: pass at most one of layout= and datapath=")
+        if datapath is not None:
+            assert len(shape) == 2, "TMEMPool.alloc: datapath= requires a 2-D shape"
+            layout = tmem_datapath_layout(datapath, shape[0], shape[1])
+
         ir = _get_ir()
         cols = self._resolve_cols(shape, dtype, cols, layout)
         col_start = self.offset
diff --git a/python/tvm/tirx/layout.py b/python/tvm/tirx/layout.py
index e17a0d61f8..c3cc25748b 100644
--- a/python/tvm/tirx/layout.py
+++ b/python/tvm/tirx/layout.py
@@ -566,7 +566,94 @@ except NameError:  # pragma: no cover
     __all__ = []  # type: ignore[var-annotated]
 __all__ += list(_AXIS_NAMES)
 __all__ += ["R", "S"]
-__all__ += ["wg_local_layout", "tcgen05_atom_layout"]
+__all__ += ["wg_local_layout", "tcgen05_atom_layout", "tmem_datapath_layout"]
+
+
+# ============================================================================
+# TMEM datapath layouts (PTX ISA §9.7.16.10.5)
+# ============================================================================
+#
+# ``tcgen05.mma`` writes its output matrix C into TMEM using one of several
+# **datapath layouts** depending on the MMA's M dimension and ``.ws`` mode.
+# Each layout determines *which* physical TMEM lanes (rows) the matrix
+# occupies; the leak in the original ``_default_tmem_layout`` was that it
+# always used the identity ``(rows, cols) : (1@TLane, 1@TCol)`` mapping,
+# which is correct only for Layout D (M=128 full datapath). For Layout F
+# (M=64 non-``.ws``) the MMA writes scattered lanes
+# ``{0..15, 32..47, 64..79, 96..111}`` — half of each warp's 32-lane
+# partition — and the readback path (``.16x*b`` M=64 atom) has the matching
+# scatter built into the PTX. To keep the buffer's logical row indexing in
+# sync with the physical scatter, the buffer's TileLayout must encode the
+# scatter directly.
+#
+# We surface this via the factory below. Callers pass the datapath letter
+# (``"D"`` / ``"F"``) and the logical ``(rows, cols)``; the factory returns
+# the appropriate TileLayout. ``tmem_pool.alloc(..., datapath="F")`` plumbs
+# this into the buffer's layout so the dispatch can structurally verify
+# atom ↔ datapath compatibility instead of silently accepting mismatches.
+#
+# Supported today:
+#   - ``"D"``: M=128, ``.cta_group::1``, full datapath. Identity row→lane.
+#   - ``"F"``: M=64, non-``.ws``, half datapath (4×1 lane utilization).
+#     Logical row r → physical lane (r // 16) * 32 + (r % 16).
+#
+# Layouts A / B / C / E / G are reserved for future expansion.
+
+
+_TMEM_DATAPATH_ROWS = {"D": 128, "F": 64}
+
+
+def tmem_datapath_layout(datapath: str, rows: int, cols: int) -> "TileLayout":
+    """Return the ``TileLayout`` for a tcgen05 MMA datapath.
+
+    See PTX ISA §9.7.16.10.5 for the datapath enumeration. The returned
+    layout is shape-compatible with a buffer of ``(rows, cols)`` and
+    encodes the logical-row → physical-TMEM-lane mapping that the
+    corresponding MMA writes to (and that the matching ``.16x*b`` /
+    ``.32x32b`` atom expects to read).
+
+    Parameters
+    ----------
+    datapath : str
+        One of ``"D"`` (M=128, ``.cta_group::1``, full datapath) or
+        ``"F"`` (M=64, non-``.ws``, half datapath). Other layouts are not
+        yet supported by this factory.
+    rows : int
+        Logical row count of the TMEM buffer. Must match the datapath's M
+        dimension: 128 for D, 64 for F.
+    cols : int
+        Logical column count.
+
+    Returns
+    -------
+    TileLayout
+        Buffer-shape-compatible layout for ``(rows, cols)``.
+    """
+    if datapath not in _TMEM_DATAPATH_ROWS:
+        raise ValueError(
+            f"tmem_datapath_layout: unknown datapath {datapath!r}; "
+            f"supported: {sorted(_TMEM_DATAPATH_ROWS)}"
+        )
+    expected = _TMEM_DATAPATH_ROWS[datapath]
+    if rows != expected:
+        raise ValueError(
+            f"tmem_datapath_layout: datapath={datapath!r} expects rows={expected}, got {rows}"
+        )
+    tlane = Axis.get("TLane")
+    tcol = Axis.get("TCol")
+    if datapath == "D":
+        # M=128, identity row→lane: row r ∈ [0, 128) → physical lane r.
+        return TileLayout(S[(rows, cols) : (1 @ tlane, 1 @ tcol)])
+    # Layout F: M=64 scattered. Logical row r = wid * 16 + intra (wid ∈ [0,4),
+    # intra ∈ [0,16)) → physical lane wid * 32 + intra, i.e.
+    # ``r // 16`` is the warp selector and ``r % 16`` is the within-slab lane.
+    # ``TileLayout`` decomposes a scalar row index via ``SplitCoord``
+    # (src/tirx/ir/layout/utils.cc), which uses row-major ordering: with
+    # shape ``(s0, s1)`` the FIRST iter receives ``coord // s1`` (the high
+    # bits) and the SECOND receives ``coord % s1`` (the low bits). So we
+    # pin the warp selector to iter 0 (extent 4, TLane stride 32) and the
+    # within-slab lane to iter 1 (extent 16, TLane stride 1).
+    return TileLayout(S[(4, 16, cols) : (32 @ tlane, 1 @ tlane, 1 @ tcol)])
 
 
 def wg_local_layout(cols, rows=128):
@@ -593,8 +680,19 @@ _TCGEN05_ATOM_REPS = {
 # fragment is 128 rows × (factor * rep) fp32 cols with factor=1.
 _TCGEN05_COL_FACTOR_FP32 = {"32x32b": 1, "16x64b": 2, "16x128b": 4, "16x256b": 8}
 
-# Number of fragment rows per warpgroup for each instr_shape.
-_TCGEN05_FRAG_ROWS = {"32x32b": 128, "16x64b": 64, "16x128b": 64, "16x256b": 64}
+# Allowed fragment row counts per warpgroup for each instr_shape. ``.32x32b``
+# is fixed at M=128; ``.16x*b`` natively covers M=64 (one 16-row slab per
+# warp, using lanes 0..15 of each warp's 32-lane TMEM partition) and can be
+# extended to M=128 by issuing the atom twice with row offsets 0 and 16
+# (covering lanes 0..15 + 16..31, i.e. the warp's full slab). The M=128
+# variant doubles per-thread registers and treats the extra slab as the
+# highest m-bit.
+_TCGEN05_FRAG_ROWS = {
+    "32x32b": (128,),
+    "16x64b": (64, 128),
+    "16x128b": (64, 128),
+    "16x256b": (64, 128),
+}
 
 
 def tcgen05_atom_layout(
@@ -666,10 +764,10 @@ def tcgen05_atom_layout(
             f"tcgen05_atom_layout tensor_shape must be 2-D (rows, cols), got {tensor_shape!r}"
         )
     rows, cols = tensor_shape
-    expected_rows = _TCGEN05_FRAG_ROWS[instr_shape]
-    if rows != expected_rows:
+    allowed_rows = _TCGEN05_FRAG_ROWS[instr_shape]
+    if rows not in allowed_rows:
         raise ValueError(
-            f"tcgen05_atom_layout {instr_shape!r} expects rows={expected_rows}, got {rows}"
+            f"tcgen05_atom_layout {instr_shape!r} expects rows ∈ {allowed_rows}, got {rows}"
         )
 
     elem_per_32b = 32 // bits
@@ -710,20 +808,27 @@ def tcgen05_atom_layout(
         ]
         return TileLayout.from_iters(iters, [], {})
 
+    # Iter lists are written high-to-low: ``TileLayout`` decomposes a flat
+    # coordinate via ``SplitCoord`` (src/tirx/ir/layout/utils.cc) using
+    # row-major ordering, where the FIRST iter receives the *high* bits and
+    # the LAST iter receives the *low* bits. So R_w (highest-stride row
+    # contribution) comes first in row_iters_fp32 and R_t1/t2 (lowest)
+    # comes last; same for col.
     if shape == "16x64b":
         # Per-warp tile (fp32 view): (16 rows, 2N cols). Per-lane regs = N.
         # Lane (t0, t1, t2): t0 = laneid & 1, t1 = (laneid >> 1) & 1, t2 = laneid >> 2.
         #   Row = t2 + 8*t0 + 16*wid_in_wg
         #   Col (fp32) = t1 + 2*r,   r ∈ [0, N)
         row_iters_fp32 = [
-            (8, 4, laneid),    # R_t2: laneid bits 2..4 → R bits 0..2
-            (2, 1, laneid),    # R_t0: laneid bit 0    → R bit 3
             (4, 1, wid),       # R_w:  wid_in_wg       → R bits 4..5
+            (2, 1, laneid),    # R_t0: laneid bit 0    → R bit 3
+            (8, 4, laneid),    # R_t2: laneid bits 2..4 → R bits 0..2
         ]
         col_iters_fp32 = [
-            (2, 2, laneid),    # C_t1: laneid bit 1    → C bit 0
             (N, 1, "m"),       # C_r:  register slot   → C bits 1..
+            (2, 2, laneid),    # C_t1: laneid bit 1    → C bit 0
         ]
+        m_used_M64 = N
     elif shape == "16x128b":
         # Per-warp tile (fp32 view): (16 rows, 4N cols). Per-lane regs = 2N.
         # Lane (t0, t1): t0 = laneid & 3, t1 = laneid >> 2.
@@ -731,29 +836,50 @@ def tcgen05_atom_layout(
         #   Row = t1 + 8*ra + 16*wid_in_wg
         #   Col (fp32) = t0 + 4*rb
         row_iters_fp32 = [
-            (8, 4, laneid),    # R_t1: laneid bits 2..4 → R bits 0..2
-            (2, 1, "m"),       # R_ra: reg bit 0        → R bit 3
             (4, 1, wid),       # R_w
+            (2, 1, "m"),       # R_ra: reg bit 0        → R bit 3
+            (8, 4, laneid),    # R_t1: laneid bits 2..4 → R bits 0..2
         ]
         col_iters_fp32 = [
-            (4, 1, laneid),    # C_t0: laneid bits 0..1 → C bits 0..1
             (N, 2, "m"),       # C_rb: reg bits 1..     → C bits 2..
+            (4, 1, laneid),    # C_t0: laneid bits 0..1 → C bits 0..1
         ]
+        m_used_M64 = 2 * N
     else:  # 16x256b
         # Per-warp tile (fp32 view): (16 rows, 8N cols). Per-lane regs = 4N.
         # Lane (t0, t1) as for 16x128b. Reg r = v0p + 2*va + 4*vb.
         #   Row = t1 + 8*va + 16*wid_in_wg
         #   Col (fp32) = v0p + 2*t0 + 8*vb
         row_iters_fp32 = [
-            (8, 4, laneid),    # R_t1
-            (2, 2, "m"),       # R_va: reg bit 1 → R bit 3
             (4, 1, wid),       # R_w
+            (2, 2, "m"),       # R_va: reg bit 1 → R bit 3
+            (8, 4, laneid),    # R_t1
         ]
         col_iters_fp32 = [
-            (2, 1, "m"),       # C_v0p: reg bit 0  → C bit 0
-            (4, 1, laneid),    # C_t0
             (N, 4, "m"),       # C_vb:  reg bits 2.. → C bits 3..
+            (4, 1, laneid),    # C_t0
+            (2, 1, "m"),       # C_v0p: reg bit 0  → C bit 0
         ]
+        m_used_M64 = 4 * N
+
+    if rows == 128:
+        # M=128 covers both 16-row half-slabs of each warp's 32-lane TMEM
+        # partition (the M=64 atom covers only lanes 0..15; the high half
+        # 16..31 needs a second PTX issue with row offset 16). We surface
+        # the combined fragment as a single (128, K) tile by inserting a
+        # v_slab iter right *after* R_w (i.e. as the next-highest row bit).
+        # v_slab claims one m-bit at the next free offset
+        # (stride = m_used_M64) so reg indices [0, m_used_M64) hold the
+        # low slab and [m_used_M64, 2*m_used_M64) hold the high slab — the
+        # split the dispatch uses when emitting the two PTX calls. The
+        # inserted iter also doubles wid_in_wg's row stride from 16 to 32,
+        # so the four warps now tile rows 0..31 / 32..63 / 64..95 / 96..127.
+        new_row_iters = []
+        for ext, stride, axis in row_iters_fp32:
+            new_row_iters.append((ext, stride, axis))
+            if axis is wid:
+                new_row_iters.append((2, m_used_M64, "m"))
+        row_iters_fp32 = new_row_iters
 
     def _scale(iters):
         out = []
@@ -768,10 +894,11 @@ def tcgen05_atom_layout(
     col_iters = _scale(col_iters_fp32)
 
     # For the 16-bit packed variant each fp32 register holds two adjacent
-    # column elements (low / high halves), so we prepend a C_pack iter of
-    # extent ``elem_per_32b`` and m-stride 1 to the column group.
+    # column elements (low / high halves). Add a C_pack iter of extent
+    # ``elem_per_32b`` and m-stride 1 at the *low* end of the col axis —
+    # i.e. as the LAST col iter under SplitCoord's high-to-low ordering.
     if elem_per_32b > 1:
-        col_iters = [(elem_per_32b, 1, "m"), *col_iters]
+        col_iters.append((elem_per_32b, 1, "m"))
 
     iters = [Iter(ext, stride, axis) for ext, stride, axis in row_iters + col_iters]
     return TileLayout.from_iters(iters, [], {})
diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py
index 345eef4519..0c958dde3e 100644
--- a/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py
+++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy_async/tcgen05_ldst.py
@@ -27,7 +27,15 @@ from tvm.arith import Analyzer
 from tvm.runtime import DataType
 from tvm.script import tirx as Tx
 from tvm.tirx import Buffer, PrimFunc
-from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout, tid_in_wg
+from tvm.tirx.layout import (
+    S,
+    TCol,
+    TileLayout,
+    TLane,
+    tcgen05_atom_layout,
+    tid_in_wg,
+    tmem_datapath_layout,
+)
 from tvm.tirx.operator.tile_primitive import DispatchContext, predicate, register_dispatch
 from tvm.tirx.stmt import TilePrimitiveCall
 
@@ -42,17 +50,18 @@ _TCGEN05_COL_FACTOR_FP32 = {"16x64b": 2, "16x128b": 4, "16x256b": 8}
 
 
 def _match_tcgen05_atom_layout(buf):
-    """Return ``(instr_shape, rep)`` if ``buf.layout`` matches a tcgen05
-    ``.16x*b`` atom layout for some supported ``instr_shape``.
+    """Return ``(instr_shape, rep, frag_rows)`` if ``buf.layout`` matches a
+    tcgen05 ``.16x*b`` atom layout for some supported ``instr_shape``.
 
-    The local buffer shape ``(64, K)`` together with the dtype determines the
-    candidate ``rep`` for each ``instr_shape``; we just probe the three shapes
-    and structurally compare. ``None`` if no atom layout matches.
+    The local buffer shape ``(frag_rows, K)`` (``frag_rows`` ∈ {64, 128})
+    together with the dtype determines the candidate ``rep`` for each
+    ``instr_shape``; we just probe the three shapes × two frag_rows and
+    structurally compare. ``None`` if no atom layout matches.
     """
     if len(buf.shape) != 2:
         return None
     rows, cols = int(buf.shape[0]), int(buf.shape[1])
-    if rows != 64:
+    if rows not in (64, 128):
         return None
     dtype = buf.dtype
     layout_c = buf.layout.canonicalize()
@@ -68,10 +77,95 @@ def _match_tcgen05_atom_layout(buf):
         # Recover rep from cols (same arithmetic the factory uses).
         elem_per_32b = 32 // DataType(dtype).bits
         rep = cols // (_TCGEN05_COL_FACTOR_FP32[shape] * elem_per_32b)
-        return shape, rep
+        return shape, rep, rows
+    return None
+
+
+def _classify_tmem_datapath(tmem_buf):
+    """Return ``"D"`` / ``"F"`` if ``tmem_buf.layout`` matches a known tcgen05
+    datapath (PTX ISA §9.7.16.10.5), else ``None``.
+
+    Layout D (M=128, identity row→lane) is the default returned by
+    ``_default_tmem_layout``. Layout F (M=64 non-``.ws``, scattered) is the
+    explicit opt-in produced by ``tmem_pool.alloc(..., datapath="F")``.
+    The dispatch uses this to pair each ``.16x*b`` / ``.32x32b`` atom with a
+    compatible layout — see ``_check_tmem_layout_for_atom``.
+    """
+    if tmem_buf.layout is None:
+        return None
+    buf_layout = tmem_buf.layout.canonicalize()
+    rows = int(tmem_buf.shape[0])
+    if rows == 128:
+        cand = tmem_datapath_layout("D", 128, tmem_buf.shape[1]).canonicalize()
+        try:
+            tvm.ir.assert_structural_equal(buf_layout, cand)
+            return "D"
+        except (AssertionError, ValueError):
+            return None
+    if rows == 64:
+        cand = tmem_datapath_layout("F", 64, tmem_buf.shape[1]).canonicalize()
+        try:
+            tvm.ir.assert_structural_equal(buf_layout, cand)
+            return "F"
+        except (AssertionError, ValueError):
+            return None
     return None
 
 
+# Compatibility matrix between the TMEM buffer's datapath layout and the
+# tcgen05 ld/st atom requested by ``Tx.copy_async``:
+#
+#   datapath × atom              | accepted? | rationale
+#   ---------------------------- | --------- | --------------------------------
+#   D (M=128 full)  × .32x32b    | yes       | full 128 lanes, all 32 per warp
+#   D (M=128 full)  × .16x*b M=64| yes       | reads first half-slab (lanes
+#                                |           |   0..15 of each warp partition)
+#                                |           |   — the rest of acc is wasted
+#                                |           |   for this atom but valid data
+#   D (M=128 full)  × .16x*b M=128| yes      | reads all 128 lanes via row=0
+#                                |           |   and row=16 PTX issues
+#   F (M=64 scatter)× .16x*b M=64| yes       | canonical pairing — F's row
+#                                |           |   indexing matches the atom's
+#                                |           |   scatter access
+#   F (M=64 scatter)× .16x*b M=128| no       | F only writes the low slab; the
+#                                |           |   high slab (row=16) is garbage
+#   F (M=64 scatter)× .32x32b    | no       | F only utilizes 16 of each
+#                                |           |   warp's 32 lanes
+_TMEM_ATOM_COMPAT = {
+    ("D", "32x32b", 128): True,
+    ("D", "16x*b", 64): True,
+    ("D", "16x*b", 128): True,
+    ("F", "32x32b", 128): False,
+    ("F", "16x*b", 64): True,
+    ("F", "16x*b", 128): False,
+}
+
+
+def _check_tmem_layout_for_atom(tmem_buf, atom_kind, frag_rows):
+    """Raise ``ValueError`` if the TMEM buffer's datapath layout is
+    incompatible with the requested ``tcgen05`` atom.
+
+    ``atom_kind`` is ``"32x32b"`` or ``"16x*b"``; ``frag_rows`` is the
+    register-side fragment row count (128 for ``.32x32b`` and ``.16x*b``
+    M=128 variants, 64 for ``.16x*b`` M=64). If the buffer's layout is
+    unrecognized (i.e. it isn't Layout D or Layout F), the dispatch falls
+    back to the structural assertions below.
+    """
+    datapath = _classify_tmem_datapath(tmem_buf)
+    if datapath is None:
+        return None
+    allowed = _TMEM_ATOM_COMPAT.get((datapath, atom_kind, frag_rows), False)
+    if not allowed:
+        raise ValueError(
+            f"tcgen05 dispatch: TMEM buffer with datapath={datapath!r} is "
+            f"incompatible with atom={atom_kind!r} (frag_rows={frag_rows}). "
+            f"See PTX ISA §9.7.16.10.5 for datapath/atom pairings; the "
+            f"buffer was allocated via tmem_pool.alloc(..., "
+            f"datapath={datapath!r})."
+        )
+    return datapath
+
+
 def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc | None:
     op_call = TilePrimitiveCall.downcast(op_call)
     dst_buffer_region, src_buffer_region = op_call.dst, op_call.src
@@ -108,10 +202,11 @@ def copy_tmem_local_impl(op_call: TilePrimitiveCall, sctx: DispatchContext) -> P
     atom_match = _match_tcgen05_atom_layout(local_buf)
 
     if atom_match is not None:
-        shape, num = atom_match
+        shape, num, frag_rows = atom_match
         return _emit_16xnb_path(
             shape=shape,
             num=num,
+            frag_rows=frag_rows,
             direction=direction,
             tmem_buf=tmem_buf,
             local_buf=local_buf,
@@ -138,6 +233,9 @@ def _emit_32x32b_path(
 ) -> PrimFunc:
     """Original M=128 fragment path using ``tcgen05.{ld,st}.32x32b.xN``."""
     # local: 128xWIDTH <-> tmem: 128xSHAPE[1]
+    # ``.32x32b`` accesses 32 lanes per warp — the full warp partition — so
+    # the TMEM buffer must be Layout D (M=128 full datapath). Reject Layout F.
+    _check_tmem_layout_for_atom(tmem_buf, "32x32b", 128)
     assert analyzer.can_prove_equal(local_buf.shape[0], 128)
     assert analyzer.can_prove_equal(tmem_buf.shape[0], 128)
 
@@ -198,6 +296,7 @@ def _emit_16xnb_path(
     *,
     shape,
     num,
+    frag_rows,
     direction,
     tmem_buf,
     local_buf,
@@ -206,55 +305,86 @@ def _emit_16xnb_path(
     elem_per_32b,
     analyzer,
 ) -> PrimFunc:
-    """M=64 fragment path using ``tcgen05.{ld,st}.<shape>.x<num>`` (one of
-    ``.16x64b``, ``.16x128b``, ``.16x256b``).
-
-    Each of the warpgroup's 4 warps issues the atom once with
-    ``row_offset=0``; the PTX TMEM access restriction places warp ``i`` on
-    TMEM lanes ``i*32..i*32+31``, of which the atom uses the first 16 to
-    cover one 16-row slab of the 64-row fragment. Collectively, the four
-    warps cover all 64 rows.
+    """``.16x*b`` fragment path using ``tcgen05.{ld,st}.<shape>.x<num>`` (one
+    of ``.16x64b``, ``.16x128b``, ``.16x256b``).
+
+    Each of the warpgroup's 4 warps issues the atom with ``row_offset=0`` to
+    cover lanes 0..15 of its 32-lane TMEM partition (one 16-row slab); the
+    four warps collectively span M=64 rows. When ``frag_rows == 128`` the
+    dispatch emits a second issue with ``row_offset=16`` to also cover lanes
+    16..31 of each warp's partition, doubling the fragment's row coverage to
+    M=128. The two atoms share the same column footprint; the layout factory
+    surfaces the combined per-thread register vector with the second slab's
+    regs in the high half of the m-axis (so the dispatch can split regs
+    contiguously between the two PTX calls).
     """
     # Per-atom column footprint in fp32 columns:
     #   .16x64b  → 2N    .16x128b → 4N    .16x256b → 8N
     col_factor_fp32 = {"16x64b": 2, "16x128b": 4, "16x256b": 8}[shape]
-    # Per-thread register count (in 32-bit units):
+    # Per-thread register count per 16-row slab (in 32-bit units):
     #   .16x64b.xN  → N        .16x128b.xN → 2N      .16x256b.xN → 4N
-    regs_per_thread = {"16x64b": num, "16x128b": 2 * num, "16x256b": 4 * num}[shape]
+    regs_per_thread_per_slab = {"16x64b": num, "16x128b": 2 * num, "16x256b": 4 * num}[shape]
+    n_slabs = frag_rows // 64  # 1 for M=64, 2 for M=128
+    assert n_slabs in (1, 2)
+    regs_per_thread = regs_per_thread_per_slab * n_slabs
     # Logical column width that the local buffer view exposes (in element units).
     width_elems = col_factor_fp32 * num * elem_per_32b
     # Per-thread storage in element units (same total bits as the register vector).
     per_thread_elems = regs_per_thread * elem_per_32b
 
-    # Local-side: shape (64, K_cols)
-    assert analyzer.can_prove_equal(local_buf.shape[0], 64), (
-        f".16x*b path expects local_buf rows=64, got {local_buf.shape[0]}"
+    # Local-side: shape (frag_rows, K_cols)
+    assert analyzer.can_prove_equal(local_buf.shape[0], frag_rows), (
+        f".16x*b path expects local_buf rows={frag_rows}, got {local_buf.shape[0]}"
     )
     assert analyzer.can_prove_equal(local_buf.shape[1], width_elems), (
         f".16x*b path expects local_buf cols={width_elems}, got {local_buf.shape[1]}"
     )
 
-    # TMEM-side: shape (128, W); the M=64 fragment occupies the first 16 lanes of
-    # each warp's 32-lane slab.
-    assert analyzer.can_prove_equal(tmem_buf.shape[0], 128), (
-        f".16x*b path expects tmem_buf rows=128, got {tmem_buf.shape[0]}"
-    )
-    tmem_layout = TileLayout(S[(128, tmem_buf.shape[1]) : (1 @ TLane, 1 @ TCol)]).canonicalize()
-    tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), tmem_layout)
+    # TMEM-side: structurally classify the buffer's datapath (D or F) and
+    # reject incompatible pairings. The PTX is identical in either case (the
+    # warp partition rule and the atom's lane access pattern are baked into
+    # the hardware); the layout classification just keeps the buffer's
+    # logical row indexing in sync with the physical TMEM occupation.
+    datapath = _check_tmem_layout_for_atom(tmem_buf, "16x*b", frag_rows)
+
+    if datapath == "F":
+        # Layout F: buffer shape (64, W), scattered row→lane.
+        assert analyzer.can_prove_equal(tmem_buf.shape[0], 64), (
+            f".16x*b Layout F expects tmem_buf rows=64, got {tmem_buf.shape[0]}"
+        )
+        tmem_rows = 64
+    else:
+        # Layout D (or untagged legacy buffers): shape (128, W), identity.
+        # The legacy structural check below still fires for untagged buffers
+        # so we don't silently accept arbitrary layouts.
+        assert analyzer.can_prove_equal(tmem_buf.shape[0], 128), (
+            f".16x*b path expects tmem_buf rows=128, got {tmem_buf.shape[0]}"
+        )
+        if datapath is None:
+            tmem_layout = TileLayout(
+                S[(128, tmem_buf.shape[1]) : (1 @ TLane, 1 @ TCol)]
+            ).canonicalize()
+            tvm.ir.assert_structural_equal(tmem_buf.layout.canonicalize(), tmem_layout)
+        tmem_rows = 128
 
     tmem_st, tmem_extent = get_st_extent(tmem_region)
     local_st, local_extent = get_st_extent(local_region)
 
-    # Local slice must be the full (64, K_cols) view.
+    # Local slice must be the full (frag_rows, K_cols) view.
     assert analyzer.can_prove_equal(local_st[0], 0)
-    assert analyzer.can_prove_equal(local_extent[0], 64)
+    assert analyzer.can_prove_equal(local_extent[0], frag_rows)
     assert analyzer.can_prove_equal(local_extent[1], width_elems)
 
-    # TMEM slice must start at row 0 (warp 0 of the WG is at lane 0) and span
-    # 64 rows (collectively the 4 warps' first 16-lane chunks).
+    # TMEM slice must start at row 0 and span ``frag_rows`` rows. For Layout
+    # F the buffer is already (64, W) so frag_rows=64 covers the full slice;
+    # for Layout D + frag_rows=64 the slice reads the *first* half-slab and
+    # the rest of the buffer's 128 rows is invisible to this atom. For
+    # Layout D + frag_rows=128 the slice covers all 128 physical lanes via
+    # two PTX issues (row=0 + row=16).
     assert analyzer.can_prove_equal(tmem_st[0], 0)
-    assert analyzer.can_prove_equal(tmem_extent[0], 64)
+    assert analyzer.can_prove_equal(tmem_extent[0], frag_rows)
     assert analyzer.can_prove_equal(tmem_extent[1], width_elems)
+    del tmem_rows  # not read after the branch above; it only records the expected row count
 
     col_off = tmem_st[1]
     assert analyzer.can_prove_equal(tvm.tirx.floormod(col_off, elem_per_32b), 0)
@@ -282,11 +412,14 @@ def _emit_16xnb_path(
             # for the register-pointer arguments of the PTX builtin.
             local_storage = local_buf.view(per_thread_elems, layout=TileLayout(S[per_thread_elems]))  # noqa: E501
             local_32b = local_storage.view("uint32")
-            op(
-                tmem_buf.allocated_addr[0],
-                *[local_32b[local_col_off_elems // elem_per_32b + i] for i in range(regs_per_thread)],  # noqa: E501
-                shape=shape, num=num, row=0, col=col_off_32b,
-            )
+            local_reg_base = local_col_off_elems // elem_per_32b
+            for slab in range(n_slabs):
+                reg_base = slab * regs_per_thread_per_slab
+                op(
+                    tmem_buf.allocated_addr[0],
+                    *[local_32b[local_reg_base + reg_base + i] for i in range(regs_per_thread_per_slab)],  # noqa: E501
+                    shape=shape, num=num, row=slab * 16, col=col_off_32b,
+                )
     # fmt: on
     return impl
 
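For reference, the per-slab register split performed by the dispatch loop can be modelled in plain Python. This is a hypothetical sketch, not code from the patch: `plan_16xnb_issues` and `REGS_PER_SLAB_FACTOR` are illustrative names, and the per-shape factors are taken from the `.16x64b.xN -> N`, `.16x128b.xN -> 2N`, `.16x256b.xN -> 4N` comment in `_emit_16xnb_path`. It shows that an M=128 fragment issues the atom twice (row=0, then row=16) and hands each issue a contiguous block of the per-thread register vector, with the low slab in the lower indices.

```python
# Hypothetical model of the per-slab issue plan in _emit_16xnb_path.
# Factors mirror the comment in the patch (registers per thread per 16-row slab):
#   .16x64b.xN -> N, .16x128b.xN -> 2N, .16x256b.xN -> 4N
REGS_PER_SLAB_FACTOR = {"16x64b": 1, "16x128b": 2, "16x256b": 4}


def plan_16xnb_issues(shape, num, frag_rows, local_reg_base=0):
    """Return one (row_offset, register indices) pair per PTX issue."""
    if frag_rows not in (64, 128):
        raise ValueError(f"frag_rows must be 64 or 128, got {frag_rows}")
    regs_per_slab = REGS_PER_SLAB_FACTOR[shape] * num
    n_slabs = frag_rows // 64  # 1 for M=64, 2 for M=128
    issues = []
    for slab in range(n_slabs):
        # The low slab gets regs [0, regs_per_slab); the high slab gets the next block.
        start = local_reg_base + slab * regs_per_slab
        issues.append((slab * 16, list(range(start, start + regs_per_slab))))
    return issues


# M=64: a single issue at row 0 covering all per-thread registers.
print(plan_16xnb_issues("16x128b", 2, 64))   # -> [(0, [0, 1, 2, 3])]
# M=128: two issues, rows 0 and 16, splitting the registers contiguously.
print(plan_16xnb_issues("16x128b", 2, 128))  # -> [(0, [0, 1, 2, 3]), (16, [4, 5, 6, 7])]
```

Unlike the patch, which asserts `n_slabs in (1, 2)` after integer division, the sketch rejects any row count other than 64 or 128 up front, so a value such as 96 fails loudly instead of silently mapping to one slab.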
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
index 09a15f3bd1..0bea7587ce 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
@@ -42,7 +42,14 @@ import pytest
 import tvm
 import tvm.testing
 from tvm.script import tirx as Tx
-from tvm.tirx.layout import S, TCol, TileLayout, TLane, tcgen05_atom_layout
+from tvm.tirx.layout import (
+    S,
+    TCol,
+    TileLayout,
+    TLane,
+    tcgen05_atom_layout,
+    tmem_datapath_layout,
+)
 from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg
 
 
@@ -189,15 +196,58 @@ def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype):
     _run_roundtrip_16b(shape, rep, dtype)
 
 
-def _run_roundtrip_16b(shape: str, rep: int, dtype: str):
+# ``.16x*b`` atom can also span M=128 by emitting two issues per copy_async
+# (row=0 + row=16), covering the full 32-lane TMEM partition of each warp.
+# We only need to spot-check that the dispatch fires correctly and the per-
+# thread reg ↔ TMEM mapping round-trips bit-exactly — the M=64 sweep above
+# already covers the (lane, reg) decomposition, so a sparse rep set suffices.
+@pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
+@pytest.mark.parametrize("rep", [1, 2, 4])
+@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep, dtype):
+    if rep not in _SHAPE_REPS[shape]:
+        pytest.skip(f"rep {rep} not valid for {shape}")
+    _run_roundtrip_16b(shape, rep, dtype, frag_rows_override=128)
+
+
+# Layout F (M=64 non-``.ws``, scattered) round-trip: the buffer is declared
+# with the scatter-encoded TileLayout that ``tmem_datapath_layout("F", ...)``
+# produces. ``.16x*b`` M=64 PTX has the matching scatter built in, so the
+# round-trip is bit-exact in the same way as Layout D + M=64.
+@pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
+@pytest.mark.parametrize("rep", [1, 2, 4])
+@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
+def test_tcgen05_16xnb_roundtrip_16b_layout_F(shape, rep, dtype):
+    if rep not in _SHAPE_REPS[shape]:
+        pytest.skip(f"rep {rep} not valid for {shape}")
+    _run_roundtrip_16b(shape, rep, dtype, tmem_datapath="F")
+
+
+def _run_roundtrip_16b(
+    shape: str,
+    rep: int,
+    dtype: str,
+    *,
+    frag_rows_override=None,
+    tmem_datapath: str = "D",
+):
     bits = tvm.runtime.DataType(dtype).bits
     assert bits == 16
     elem_per_32b = 2
     K_cols_fp32 = _COL_FACTOR_FP32[shape] * rep
     K_cols_elem = K_cols_fp32 * elem_per_32b
     regs_per_thread = _REGS_FACTOR[shape] * rep
+    if frag_rows_override is not None:
+        # M=128 doubles per-thread registers (second 16-row slab per warp).
+        assert frag_rows_override == 128 and _FRAG_ROWS[shape] == 64
+        regs_per_thread *= 2
     per_thread_elems = regs_per_thread * elem_per_32b
-    frag_rows = _FRAG_ROWS[shape]
+    frag_rows = frag_rows_override if frag_rows_override is not None else _FRAG_ROWS[shape]
+    if tmem_datapath == "F":
+        # Layout F is only valid with M=64 (per the datapath table); M=128
+        # would need to read the high slab, which Layout F doesn't expose.
+        assert frag_rows == 64, "Layout F + M=128 is an invalid pairing"
+    tmem_rows = 64 if tmem_datapath == "F" else 128
 
     # The 16-bit round-trip writes and reads exclusively through .16x*b atoms,
     # so the TMEM column footprint is whatever ``K_cols_fp32`` says — no
@@ -205,6 +255,7 @@ def _run_roundtrip_16b(shape: str, rep: int, dtype: str):
     tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32))
     stage_width_elem = tmem_col_width_32b * elem_per_32b
     atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_elem), dtype)
+    tmem_layout = tmem_datapath_layout(tmem_datapath, tmem_rows, stage_width_elem)
 
     @Tx.prim_func
     def kernel(A_ptr: Tx.handle, B_ptr: Tx.handle) -> None:
@@ -237,11 +288,11 @@ def _run_roundtrip_16b(shape: str, rep: int, dtype: str):
                 Tx.tvm_storage_sync("shared")
 
                 tmem = Tx.decl_buffer(
-                    (128, stage_width_elem),
+                    (tmem_rows, stage_width_elem),
                     dtype,
                     scope="tmem",
                     allocated_addr=tmem_addr[0],
-                    layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @ TCol)]),
+                    layout=tmem_layout,
                 )
 
                 # Load per-thread A → reg_in
@@ -298,6 +349,142 @@ def _next_pow2(x: int) -> int:
     return 1 << (x - 1).bit_length()
 
 
+# Unit test: pin down the (row, col) → (TLane, TCol) mapping that the
+# ``tmem_datapath_layout`` factory encodes. A self-consistent round-trip
+# (write + read with the same factory output) can't catch a layout that
+# encodes a *wrong* scatter — the labels would still match structurally
+# even if the row→lane formula doesn't match PTX's actual behavior. This
+# test bypasses compilation and checks the layout's ``apply`` method
+# directly against ``_frag_row_to_tmem_lane`` for every M=64 logical row.
+def test_tmem_datapath_layout_F_row_to_lane_mapping():
+    """Layout F: every logical row r ∈ [0, 64) must land at physical TMEM
+    lane ``(r // 16) * 32 + (r % 16)`` — the canonical scatter that the
+    ``.16x*b`` M=64 PTX accesses (warp i on lanes ``i * 32 .. i * 32 + 15``).
+    """
+    cols = 32
+    layout = tmem_datapath_layout("F", 64, cols)
+    for r in range(64):
+        for c in [0, 1, 7, 16, 31]:
+            # Use ``apply(coord, shape=[64, cols])`` so (r, c) gets flattened
+            # row-major before SplitCoord into the shard iters.
+            axis_values = layout.apply(r, c, shape=[64, cols])
+            expected_lane = (r // 16) * 32 + (r % 16)
+            assert int(axis_values["TLane"]) == expected_lane, (
+                f"(r={r}, c={c}) mapped to TLane {int(axis_values['TLane'])}, "
+                f"expected {expected_lane} (= (r//16)*32 + (r%16))"
+            )
+            assert int(axis_values["TCol"]) == c, (
+                f"(r={r}, c={c}) mapped to TCol {int(axis_values['TCol'])}, expected {c}"
+            )
+
+
+@pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
+@pytest.mark.parametrize("rep", [1, 2, 4])
+def test_tcgen05_atom_layout_apply_matches_decompose_fp32(shape, rep):
+    """``tcgen05_atom_layout`` is supposed to be the inverse of
+    ``_decompose_fp32`` — i.e. for every (row, col) in the M=64 fragment,
+    ``layout.apply(row, col)`` must return the (laneid, wid_in_wg, m)
+    tuple that PTX puts at frag element ``(row, col)``.
+
+    The factory's per-shape iter lists are written low-to-high (natural
+    decomposition); the reversal applied inside the factory is what aligns the
+    resulting TileLayout with ``SplitCoord`` (high-to-low). Without the reversal the
+    factory used to silently produce a layout that disagreed with PTX —
+    the round-trip tests didn't catch it because the dispatch ignores the
+    layout label and emits raw PTX. This sweep is the structural fence.
+    """
+    if rep not in _SHAPE_REPS[shape]:
+        pytest.skip(f"rep {rep} not valid for {shape}")
+    cols = _COL_FACTOR_FP32[shape] * rep  # K_cols_fp32
+    layout = tcgen05_atom_layout(shape, (64, cols), "float32")
+    for thread in range(128):
+        laneid = thread & 31
+        wid_in_wg = thread >> 5
+        regs_per_thread = _REGS_FACTOR[shape] * rep
+        for reg in range(regs_per_thread):
+            row, col = _decompose_fp32(shape, thread, reg)
+            axis_values = layout.apply(row, col, shape=[64, cols])
+            assert int(axis_values.get("laneid", 0)) == laneid, (
+                f"shape={shape} rep={rep}: (row={row}, col={col}) "
+                f"mapped to laneid {int(axis_values.get('laneid', 0))}, expected {laneid}"
+            )
+            assert int(axis_values.get("wid_in_wg", 0)) == wid_in_wg, (
+                f"shape={shape} rep={rep}: (row={row}, col={col}) "
+                f"mapped to wid_in_wg {int(axis_values.get('wid_in_wg', 0))}, expected {wid_in_wg}"
+            )
+            assert int(axis_values.get("m", 0)) == reg, (
+                f"shape={shape} rep={rep}: (row={row}, col={col}) "
+                f"mapped to m {int(axis_values.get('m', 0))}, expected {reg}"
+            )
+
+
+def test_tmem_datapath_layout_D_row_to_lane_mapping():
+    """Layout D: identity row→lane (no scatter)."""
+    cols = 32
+    layout = tmem_datapath_layout("D", 128, cols)
+    for r in [0, 1, 15, 16, 31, 32, 63, 64, 127]:
+        axis_values = layout.apply(r, 0, shape=[128, cols])
+        assert int(axis_values["TLane"]) == r, (
+            f"r={r} mapped to TLane {int(axis_values['TLane'])}, expected {r}"
+        )
+
+
+# Negative tests: the datapath/atom pairing matrix in ``tcgen05_ldst.py``
+# must reject mismatched combinations. We construct a Layout F TMEM buffer
+# (64 rows, scattered) and try to read it with either a ``.16x*b`` M=128
+# atom or a ``.32x32b`` atom. Both would treat the second slab (lanes 16..31
+# of each warp's partition) as meaningful data, but Layout F leaves that slab
+# undefined. Compilation must raise a clear error, not silently emit a broken
+# kernel.
+@pytest.mark.parametrize("atom_kind,frag_rows", [("16x*b", 128), ("32x32b", 128)])
+def test_layout_F_rejects_incompatible_atoms(atom_kind, frag_rows):
+    """Layout F + (.16x*b M=128 or .32x32b) must raise at compile time."""
+    if atom_kind == "16x*b":
+        shape = "16x256b"
+        rep = 1
+        # Local fragment shape for M=128 .16x256b rep=1 = (128, 8) fp32.
+        atom_view = tcgen05_atom_layout(shape, (128, 8), "float32")
+        local_extent_rows = 128
+        local_cols = 8
+    else:  # .32x32b path: local (128, 32) fp32
+        atom_view = TileLayout(S[(128, 32) : (1 @ axis_tid_in_wg, 1)])
+        local_extent_rows = 128
+        local_cols = 32
+
+    tmem_layout = tmem_datapath_layout("F", 64, max(32, local_cols))
+    tmem_rows = 64
+    stage_width_elem = max(32, local_cols)
+
+    @Tx.prim_func
+    def kernel() -> None:
+        Tx.device_entry()
+        Tx.warp_id([128 // 32])
+        Tx.cta_id([2])
+        wg_id = Tx.warpgroup_id([1])
+        Tx.warp_id_in_wg([4])
+        Tx.lane_id([32])
+        Tx.thread_id([128])
+        tmem_addr = Tx.alloc_shared([1], "uint32")
+        if wg_id == 0:
+            with Tx.warpgroup():
+                Tx.tvm_storage_sync("shared")
+                tmem = Tx.decl_buffer(
+                    (tmem_rows, stage_width_elem),
+                    "float32",
+                    scope="tmem",
+                    allocated_addr=tmem_addr[0],
+                    layout=tmem_layout,
+                )
+                frag = Tx.alloc_local((local_extent_rows * local_cols // 128,), "float32")
+                frag_view = frag.view(local_extent_rows, local_cols, layout=atom_view)
+                Tx.copy_async(frag_view[:, :], tmem[0:local_extent_rows, 0:local_cols])
+
+    target = tvm.target.Target("cuda")
+    with target:
+        mod = tvm.IRModule({"main": kernel})
+        with pytest.raises((ValueError, RuntimeError), match="datapath"):
+            tvm.compile(mod, target=target, tir_pipeline="tirx")
+
+
 def _run_load_test(shape: str, rep: int, dtype: str):
     """Stage A into TMEM via .32x32b, then read it back as the fragment via
     .<shape>.x<rep> (through ``Tx.alloc_tcgen05_frag``), and compare each

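The lane-coverage argument behind both the Layout F row formula and the negative test can be checked with a few lines of plain Python. This is an illustrative sketch, not part of the patch; `lanes_touched` and `layout_f_lane` are hypothetical helpers. It assumes, as the test docstrings state, that warp i of the warpgroup owns TMEM lanes i*32..i*32+31 and that a `.16x*b` issue with `row_offset=r0` touches lanes r0..r0+15 of that partition.

```python
# Illustrative lane-coverage check for .16x*b copies (not code from the patch).
# Assumption: warp i of a warpgroup owns TMEM lanes i*32 .. i*32+31, and an issue
# with row_offset r0 touches lanes r0 .. r0+15 of that 32-lane partition.


def lanes_touched(frag_rows):
    """Sorted TMEM lanes a warpgroup touches for a .16x*b copy of frag_rows rows."""
    n_slabs = frag_rows // 64  # one issue per 16-row slab per warp
    lanes = []
    for warp in range(4):
        for slab in range(n_slabs):
            base = warp * 32 + slab * 16
            lanes.extend(range(base, base + 16))
    return sorted(lanes)


def layout_f_lane(r):
    """Layout F row -> lane scatter, as checked by the Layout F unit test."""
    return (r // 16) * 32 + (r % 16)


# M=128 (Layout D): every one of the 128 lanes is touched exactly once.
full = lanes_touched(128)
print(full == list(range(128)))  # True

# M=64: the touched lanes are exactly the image of the Layout F scatter...
half = lanes_touched(64)
print(half == sorted(layout_f_lane(r) for r in range(64)))  # True

# ...so an M=128 access reaches 64 lanes that a Layout F buffer never maps.
print(len(set(full) - set(half)))  # 64
```

The last check is the reason Layout F pairs only with the M=64 atom: the high slab of each warp partition lies outside the buffer's row-to-lane image.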