This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit 7ea22041de1a0f59eab44908123c69b4da121418
Author: Bohan Hou <[email protected]>
AuthorDate: Mon May 18 18:59:12 2026 -0400

    test(op-dispatch): add wg-scope both-sides-permuted invariance test (#632) (#633)
    
    * test(op-dispatch): add wg-scope both-sides-permuted invariance test
    
    Confirms the key invariant: when SMEM AND the local view are permuted
    the SAME way (axes swapped, strides + layout shard reordered
    accordingly), the dispatcher emits identical per-lane SMEM byte
    addresses. The existing permutation tests permuted only the SMEM side;
    this test extends coverage to the full both-sides case, which is
    exactly what the wg dispatcher's algorithm — locating wid_in_wg in the
    local shard and deriving row/col from the SMEM strides — is designed to
    handle.
    
    * fix(op-dispatch): require Tx.copy LHS/RHS dims to match in order
    
    The local-vs-SMEM extent check was using ``sorted`` (multiset equality),
    which let mismatched-dim-order Tx.copy through — e.g. SMEM region (32, 8)
    paired with a local view (8, 32) would be accepted on the multiset
    [8, 32] == [8, 32], even though dim 0 of LHS and dim 0 of RHS describe
    different axes. Tx.copy requires per-dim correspondence; tighten to
    order-preserving comparison.
    
    The 3 existing warp-scope permutation invariance tests kept the local
    layout fixed while permuting only the SMEM side. For the horizontal and
    vertical tests this left the LHS/RHS dim orders mismatched ((32, 8)
    against (8, 32)), so they relied on the permissive check; the 2x2 test
    is square (16, 16) and passes either way. Updated all three to also
    permute their local view so LHS and RHS describe the SAME logical
    axes — that's the property the tests are actually about.
---
 .../tile_primitive/cuda/copy/ld_stmatrix.py        | 10 ++-
 .../tile_primitive/cuda/test_ldstmatrix.py         | 86 +++++++++++++++++++---
 2 files changed, 84 insertions(+), 12 deletions(-)

diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py
index aace360aca..3c03f2f37a 100644
--- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py
+++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py
@@ -452,12 +452,16 @@ def _bind_common(op_call: TilePrimitiveCall, want_direction: str, *, is_wg: bool
     local_ext_i = [_as_int(e) for e in local_ext]
     if None in local_ext_i:
         return f"local extents must be compile-time integers, got {local_ext}"
-    smem_non_unit = sorted([e for e in smem_ext_i if e != 1])
-    local_non_unit = sorted([e for e in local_ext_i if e != 1])
+    # Strict order-preserving dim correspondence after dropping unit dims:
+    # Tx.copy requires LHS dim i and RHS dim i to describe the same logical
+    # iteration extent — not just the same multiset.
+    smem_non_unit = [e for e in smem_ext_i if e != 1]
+    local_non_unit = [e for e in local_ext_i if e != 1]
     if smem_non_unit != local_non_unit:
         return (
             f"local region {local_ext_i} non-unit extents must match SMEM "
-            f"region {smem_ext_i} (got {local_non_unit} vs {smem_non_unit})"
+            f"region {smem_ext_i} dim-for-dim "
+            f"(got {local_non_unit} vs {smem_non_unit})"
         )
 
     cfg = op_call.config or {}
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py b/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py
index 1d0816bf9e..a5befb7760 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py
@@ -218,7 +218,16 @@ def _assert_permute_same_addr(f_ref, f_perm, expected_inst):
 
 
 def test_permutation_invariance_horizontal():
-    layout = _x4_h_warp_layout()
+    """Both SMEM and local view permuted (row/col axes swapped) — must
+    produce identical per-lane SMEM addresses."""
+    ref_layout = _x4_h_warp_layout()  # (8, 32)
+    # Permuted local (32, 8): lane_high+m → dim 0 (col); laneid_low → dim 1 
(row).
+    perm_layout = TileLayout(S[(4, 4, 2, 8) : (
+        8 @ Axis.laneid,
+        2,
+        1,
+        1 @ Axis.laneid,
+    )])
 
     @Tx.prim_func
     def f_ref():
@@ -232,7 +241,7 @@ def test_permutation_invariance_horizontal():
                 )
                 with Tx.warp():
                     regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
-                    regs_warp = regs.view(8, 32, layout=layout)
+                    regs_warp = regs.view(8, 32, layout=ref_layout)
                     Tx.copy(D[0, 0:8, 0:32], regs_warp[0:8, 0:32], trans=True)
 
     @Tx.prim_func
@@ -247,8 +256,8 @@ def test_permutation_invariance_horizontal():
                 )
                 with Tx.warp():
                     regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
-                    regs_warp = regs.view(8, 32, layout=layout)
-                    Tx.copy(D[0, 0:32, 0:8], regs_warp[0:8, 0:32], trans=True)
+                    regs_warp = regs.view(32, 8, layout=perm_layout)
+                    Tx.copy(D[0, 0:32, 0:8], regs_warp[0:32, 0:8], trans=True)
 
     _assert_permute_same_addr(
         f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
@@ -256,7 +265,10 @@ def test_permutation_invariance_horizontal():
 
 
 def test_permutation_invariance_vertical():
-    layout = _x4_v_warp_layout()
+    """Vertical x4 with both SMEM and local view axes swapped."""
+    ref_layout = _x4_v_warp_layout()  # (32, 8)
+    # Permuted (8, 32): per-thread m → dim 0 (col); all laneid → dim 1 (row).
+    perm_layout = TileLayout(S[(8, 32) : (1, 1 @ Axis.laneid)])
 
     @Tx.prim_func
     def f_ref():
@@ -270,7 +282,7 @@ def test_permutation_invariance_vertical():
                 )
                 with Tx.warp():
                     regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
-                    regs_warp = regs.view(32, 8, layout=layout)
+                    regs_warp = regs.view(32, 8, layout=ref_layout)
                     Tx.copy(D[0:32, 0:8], regs_warp[0:32, 0:8], trans=False)
 
     @Tx.prim_func
@@ -285,8 +297,8 @@ def test_permutation_invariance_vertical():
                 )
                 with Tx.warp():
                     regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
-                    regs_warp = regs.view(32, 8, layout=layout)
-                    Tx.copy(D[0:8, 0:32], regs_warp[0:32, 0:8], trans=False)
+                    regs_warp = regs.view(8, 32, layout=perm_layout)
+                    Tx.copy(D[0:8, 0:32], regs_warp[0:8, 0:32], trans=False)
 
     _assert_permute_same_addr(
         f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.shared.b16"
@@ -311,6 +323,15 @@ def test_permutation_invariance_2x2():
                     regs_warp = regs.view(16, 16, layout=layout)
                     Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False)
 
+    # Permuted: row/col axes swap → row_block iter moves to dim 1, col_block
+    # to dim 0 (the dim with larger stride is row).
+    perm_layout = TileLayout.from_iters([
+        Iter(2, 8, Axis.laneid),    # col_block (was on dim 1) → dim 0 stride 8
+        Iter(8, 1, Axis.m),         # per-thread → dim 0
+        Iter(8, 1, Axis.laneid),    # lane_low → dim 1 row stride 1
+        Iter(2, 16, Axis.laneid),   # row_block → dim 1 row stride 8
+    ])
+
     @Tx.prim_func
     def f_perm():
         with Tx.kernel():
@@ -323,7 +344,7 @@ def test_permutation_invariance_2x2():
                 )
                 with Tx.warp():
                     regs = Tx.alloc_buffer((8,), "bfloat16", scope="local")
-                    regs_warp = regs.view(16, 16, layout=layout)
+                    regs_warp = regs.view(16, 16, layout=perm_layout)
                     Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False)
 
     _assert_permute_same_addr(
@@ -387,6 +408,53 @@ def test_warp_2x2():
 # ---------------------------------------------------------------------------
 
 
+def test_wg_permutation_invariance_both_sides():
+    """**Both** SMEM and local view permuted the SAME way (axes swapped):
+    the per-lane SMEM byte address must come out identical. The dispatcher
+    reads ``wid_in_wg``'s position in the local shard + stride-based row/col
+    on SMEM, so a consistent axis permutation just reshuffles the lookups —
+    final address is unchanged."""
+    # Reference: (8, 128) row-major + standard layout.
+    ref_local = TileLayout(S[(8, 4, 4, 8) : (
+        1 @ Axis.laneid,
+        1 @ Axis.wid_in_wg,
+        8 @ Axis.laneid,
+        1,
+    )])
+    # Permuted: (128, 8) col-fast + layout with the laneid_low iter moved to
+    # the new dim 1, everything else in the new dim 0.
+    perm_local = TileLayout(S[(4, 4, 8, 8) : (
+        1 @ Axis.wid_in_wg,
+        8 @ Axis.laneid,
+        1,
+        1 @ Axis.laneid,
+    )])
+
+    def make(shape, strides, layout, sl):
+        @Tx.prim_func
+        def f():
+            with Tx.kernel():
+                Tx.cta_id([1])
+                Tx.thread_id([128])
+                with Tx.cta():
+                    D = Tx.alloc_buffer(
+                        shape, "bfloat16", scope="shared",
+                        layout=TileLayout(S[shape : strides]),
+                    )
+                    with Tx.warpgroup():
+                        regs = Tx.alloc_buffer((4,), "uint32", scope="local")
+                        regs_wg = regs.view("bfloat16").view(*shape, 
layout=layout)
+                        Tx.copy(D[sl], regs_wg[sl], trans=True)
+        return f
+
+    f_ref = make((8, 128), (128, 1), ref_local, (slice(0, 8), slice(0, 128)))
+    f_perm = make((128, 8), (1, 128), perm_local, (slice(0, 128), slice(0, 8)))
+
+    _assert_permute_same_addr(
+        f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16"
+    )
+
+
 def test_wg_stmatrix_x4_trans():
     wg_layout = _x4_h_wg_layout()
 

Reply via email to