This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 7ea22041de1a0f59eab44908123c69b4da121418 Author: Bohan Hou <[email protected]> AuthorDate: Mon May 18 18:59:12 2026 -0400 test(op-dispatch): add wg-scope both-sides-permuted invariance test (#632) (#633) * test(op-dispatch): add wg-scope both-sides-permuted invariance test Confirms the key invariant: when SMEM AND the local view are permuted the SAME way (axes swapped, with strides and the layout shard reordered accordingly), the dispatcher emits identical per-lane SMEM byte addresses. The existing permutation tests permuted only the SMEM side. This test extends coverage to the full both-sides case, which the wg dispatcher's algorithm is designed to handle: it locates wid_in_wg's position in the local shard and derives row/col from the SMEM strides. * fix(op-dispatch): require Tx.copy LHS/RHS dims to match in order The local-vs-SMEM extent check compared the ``sorted`` non-unit extents (multiset equality), which let a Tx.copy with mismatched dim order through. For example, an SMEM region (32, 8) paired with a local view (8, 32) was accepted because the multisets matched ([8, 32] == [8, 32]), even though dim 0 of the LHS and dim 0 of the RHS describe different axes. Tx.copy requires per-dim correspondence, so tighten the check to an order-preserving comparison. The three existing warp-scope permutation invariance tests relied on the permissive check: they kept the local layout fixed while permuting only the SMEM side, which left the LHS/RHS dim orders mismatched. Each test now also permutes its local view so that LHS and RHS describe the SAME logical axes, which is the property the tests are actually meant to check. 
--- .../tile_primitive/cuda/copy/ld_stmatrix.py | 10 ++- .../tile_primitive/cuda/test_ldstmatrix.py | 86 +++++++++++++++++++--- 2 files changed, 84 insertions(+), 12 deletions(-) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py index aace360aca..3c03f2f37a 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/ld_stmatrix.py @@ -452,12 +452,16 @@ def _bind_common(op_call: TilePrimitiveCall, want_direction: str, *, is_wg: bool local_ext_i = [_as_int(e) for e in local_ext] if None in local_ext_i: return f"local extents must be compile-time integers, got {local_ext}" - smem_non_unit = sorted([e for e in smem_ext_i if e != 1]) - local_non_unit = sorted([e for e in local_ext_i if e != 1]) + # Strict order-preserving dim correspondence after dropping unit dims: + # Tx.copy requires LHS dim i and RHS dim i to describe the same logical + # iteration extent — not just the same multiset. 
+ smem_non_unit = [e for e in smem_ext_i if e != 1] + local_non_unit = [e for e in local_ext_i if e != 1] if smem_non_unit != local_non_unit: return ( f"local region {local_ext_i} non-unit extents must match SMEM " - f"region {smem_ext_i} (got {local_non_unit} vs {smem_non_unit})" + f"region {smem_ext_i} dim-for-dim " + f"(got {local_non_unit} vs {smem_non_unit})" ) cfg = op_call.config or {} diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py b/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py index 1d0816bf9e..a5befb7760 100644 --- a/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py +++ b/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py @@ -218,7 +218,16 @@ def _assert_permute_same_addr(f_ref, f_perm, expected_inst): def test_permutation_invariance_horizontal(): - layout = _x4_h_warp_layout() + """Both SMEM and local view permuted (row/col axes swapped) — must + produce identical per-lane SMEM addresses.""" + ref_layout = _x4_h_warp_layout() # (8, 32) + # Permuted local (32, 8): lane_high+m → dim 0 (col); laneid_low → dim 1 (row). 
+ perm_layout = TileLayout(S[(4, 4, 2, 8) : ( + 8 @ Axis.laneid, + 2, + 1, + 1 @ Axis.laneid, + )]) @Tx.prim_func def f_ref(): @@ -232,7 +241,7 @@ def test_permutation_invariance_horizontal(): ) with Tx.warp(): regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") - regs_warp = regs.view(8, 32, layout=layout) + regs_warp = regs.view(8, 32, layout=ref_layout) Tx.copy(D[0, 0:8, 0:32], regs_warp[0:8, 0:32], trans=True) @Tx.prim_func @@ -247,8 +256,8 @@ def test_permutation_invariance_horizontal(): ) with Tx.warp(): regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") - regs_warp = regs.view(8, 32, layout=layout) - Tx.copy(D[0, 0:32, 0:8], regs_warp[0:8, 0:32], trans=True) + regs_warp = regs.view(32, 8, layout=perm_layout) + Tx.copy(D[0, 0:32, 0:8], regs_warp[0:32, 0:8], trans=True) _assert_permute_same_addr( f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16" @@ -256,7 +265,10 @@ def test_permutation_invariance_horizontal(): def test_permutation_invariance_vertical(): - layout = _x4_v_warp_layout() + """Vertical x4 with both SMEM and local view axes swapped.""" + ref_layout = _x4_v_warp_layout() # (32, 8) + # Permuted (8, 32): per-thread m → dim 0 (col); all laneid → dim 1 (row). 
+ perm_layout = TileLayout(S[(8, 32) : (1, 1 @ Axis.laneid)]) @Tx.prim_func def f_ref(): @@ -270,7 +282,7 @@ def test_permutation_invariance_vertical(): ) with Tx.warp(): regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") - regs_warp = regs.view(32, 8, layout=layout) + regs_warp = regs.view(32, 8, layout=ref_layout) Tx.copy(D[0:32, 0:8], regs_warp[0:32, 0:8], trans=False) @Tx.prim_func @@ -285,8 +297,8 @@ def test_permutation_invariance_vertical(): ) with Tx.warp(): regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") - regs_warp = regs.view(32, 8, layout=layout) - Tx.copy(D[0:8, 0:32], regs_warp[0:32, 0:8], trans=False) + regs_warp = regs.view(8, 32, layout=perm_layout) + Tx.copy(D[0:8, 0:32], regs_warp[0:8, 0:32], trans=False) _assert_permute_same_addr( f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.shared.b16" @@ -311,6 +323,15 @@ def test_permutation_invariance_2x2(): regs_warp = regs.view(16, 16, layout=layout) Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False) + # Permuted: row/col axes swap → row_block iter moves to dim 1, col_block + # to dim 0 (the dim with larger stride is row). 
+ perm_layout = TileLayout.from_iters([ + Iter(2, 8, Axis.laneid), # col_block (was on dim 1) → dim 0 stride 8 + Iter(8, 1, Axis.m), # per-thread → dim 0 + Iter(8, 1, Axis.laneid), # lane_low → dim 1 row stride 1 + Iter(2, 16, Axis.laneid), # row_block → dim 1 row stride 8 + ]) + @Tx.prim_func def f_perm(): with Tx.kernel(): @@ -323,7 +344,7 @@ def test_permutation_invariance_2x2(): ) with Tx.warp(): regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") - regs_warp = regs.view(16, 16, layout=layout) + regs_warp = regs.view(16, 16, layout=perm_layout) Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False) _assert_permute_same_addr( @@ -387,6 +408,53 @@ def test_warp_2x2(): # --------------------------------------------------------------------------- +def test_wg_permutation_invariance_both_sides(): + """**Both** SMEM and local view permuted the SAME way (axes swapped): + the per-lane SMEM byte address must come out identical. The dispatcher + reads ``wid_in_wg``'s position in the local shard + stride-based row/col + on SMEM, so a consistent axis permutation just reshuffles the lookups — + final address is unchanged.""" + # Reference: (8, 128) row-major + standard layout. + ref_local = TileLayout(S[(8, 4, 4, 8) : ( + 1 @ Axis.laneid, + 1 @ Axis.wid_in_wg, + 8 @ Axis.laneid, + 1, + )]) + # Permuted: (128, 8) col-fast + layout with the laneid_low iter moved to + # the new dim 1, everything else in the new dim 0. 
+ perm_local = TileLayout(S[(4, 4, 8, 8) : ( + 1 @ Axis.wid_in_wg, + 8 @ Axis.laneid, + 1, + 1 @ Axis.laneid, + )]) + + def make(shape, strides, layout, sl): + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([128]) + with Tx.cta(): + D = Tx.alloc_buffer( + shape, "bfloat16", scope="shared", + layout=TileLayout(S[shape : strides]), + ) + with Tx.warpgroup(): + regs = Tx.alloc_buffer((4,), "uint32", scope="local") + regs_wg = regs.view("bfloat16").view(*shape, layout=layout) + Tx.copy(D[sl], regs_wg[sl], trans=True) + return f + + f_ref = make((8, 128), (128, 1), ref_local, (slice(0, 8), slice(0, 128))) + f_perm = make((128, 8), (1, 128), perm_local, (slice(0, 128), slice(0, 8))) + + _assert_permute_same_addr( + f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + ) + + def test_wg_stmatrix_x4_trans(): wg_layout = _x4_h_wg_layout()
