This is an automated email from the ASF dual-hosted git repository. spectrometerHBH pushed a commit to branch tir-bench in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 6053f82fcc3efef8ee355d1b0f794c06db6decb0 Author: Bohan Hou <[email protected]> AuthorDate: Mon May 18 17:48:03 2026 -0400 feat(op-dispatch): add warp ldmatrix/stmatrix dispatch for Tx.copy (#630) * feat(op-dispatch): add warp ldmatrix/stmatrix dispatch for Tx.copy Adds warp-cooperative variants for the Tx.copy primitive that lower to a single PTX ldmatrix.sync.aligned.m8n8.x{1,2,4}[.trans].b16 or stmatrix. The dispatcher picks num and (row, col) axes from the SMEM region's buffer-layout strides so it is invariant under consistent permutation of the SMEM tensor's shape/strides, and supports horizontal, vertical, and 2x2-grid arrangements. SwizzleLayout XOR is honored automatically via buf.ptr_to(...). 24 unit tests cover all (num × trans × direction) configs, swizzle equivalence, and permutation invariance (byte-for-byte identical addresses across permuted layouts). * feat(op-dispatch): add wg-scope warp_stmatrix/ldmatrix + active-set check Extends the warp-cooperative ldmatrix/stmatrix dispatch to warpgroup scope: each Tx.copy at warpgroup scope emits 4 stmatrix instructions (one per warp) covering 4 per-warp tiles inside the SMEM region. The per-warp distribution is read from the LOCAL fragment's TileLayout shard (its ``wid_in_wg`` iter's position + extents of subsequent iters in the same dim), so it's invariant under the layout's internal structure choice. Also adds an active-thread-set check at both warp and wg scope: PTX ld/stmatrix require every lane of the participating warp(s) to be active, so an enclosing ``if Tx.filter(...)`` that narrows ``sctx.intra['laneid']`` or ``sctx.intra['wid_in_wg']`` is rejected (laneid must be (32, 0); for wg scope wid_in_wg must also be (4, 0)). Tests: 28 cases — original warp coverage, new wg dispatch, layout-required rejection, and active-set narrowing rejection for both warp and wg. 
* docs(test): trim wg layout helper docstring * refactor(op-dispatch): require warp/wg-wide local view with matching extents Tx.copy semantics demand LHS/RHS region extents to match. Previously the warp/wg dispatcher accepted a per-thread local fragment (e.g. ``regs[0:8]``) against a warp-wide SMEM tile (e.g. ``D[0:8, 0:32]``) — the byte addresses happened to come out right but the extents did not match, so the call was not a well-formed Tx.copy. The dispatcher now requires: * Warp scope: local is a warp-wide VIEW whose shape equals the SMEM region extents (modulo unit dims), with laneid iters in its layout shard whose extents multiply to 32 (full warp). * Warpgroup scope: same but ALSO with a wid_in_wg iter (extent 4). Layouts are written bottom-up with ``.tile().tile()``: pure_m → tile laneid → (tile wid_in_wg) Helpers in the test file cover x4 horizontal / vertical / 2×2 arrangements. The impl decomposes the local view via ``local_buf.local()`` inside ``with Tx.warp():`` to get a per-thread fragment for the PTX intrinsics. Coverage: 18 tests for x4 (warp/wg, st/ld, swizzle, permutation invariance, all 3 arrangements, rejection paths). x1/x2 dropped pending their more complex lane→matrix mapping; PTX intrinsics still available. 
* docs(test): inline x4 layout helpers as direct TileLayout shards --- .../operator/tile_primitive/cuda/copy/__init__.py | 1 + .../tile_primitive/cuda/copy/warp_matrix.py | 717 +++++++++++++++++++++ .../tile_primitive/cuda/test_ldstmatrix.py | 564 ++++++++++++++++ 3 files changed, 1282 insertions(+) diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py index b1b1cc4591..0c236f3a0c 100644 --- a/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/__init__.py @@ -25,3 +25,4 @@ from .utils import ( copy_default_impl, ) from .vectorized import * +from .warp_matrix import * diff --git a/python/tvm/tirx/operator/tile_primitive/cuda/copy/warp_matrix.py b/python/tvm/tirx/operator/tile_primitive/cuda/copy/warp_matrix.py new file mode 100644 index 0000000000..aace360aca --- /dev/null +++ b/python/tvm/tirx/operator/tile_primitive/cuda/copy/warp_matrix.py @@ -0,0 +1,717 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. See the License for the +# specific language governing permissions and limitations +# under the License. + +"""CUDA copy dispatch: warp-cooperative ldmatrix / stmatrix (PTX m8n8 b16). + +Registered ops: copy (variant=warp_ldmatrix, warp_stmatrix; priority=25). 
+ +Each ``Tx.copy(SMEM[region], LOCAL[region])`` (or the reverse) emits a single +``ldmatrix.sync.aligned.m8n8.x{1,2,4}[.trans].b16`` / ``stmatrix`` PTX +instruction — the user does the outer iteration; one ``Tx.copy(...)`` call +lowers to one PTX instruction. + +Dispatcher contract (called once per ``Tx.copy``): + + * direction: ``local`` ↔ ``shared::cta`` determines st vs ld. + * ``num`` (1/2/4) is inferred from the SMEM region's last two non-unit dims: + ``(8, 8)`` → x1, ``(8, 16)`` → x2, ``(8, 32)`` → x4 (horizontal stack only + for now; vertical / 2×2 stacks are future work). + * ``trans`` is read from the kwarg ``trans=True/False`` on ``Tx.copy``. + (Auto-inferring trans from the local fragment's TileLayout when the user + encodes it as a thread-distributed layout is future work — for now the + user passes the flag explicitly, matching the existing fp8 callsite.) + * The local fragment is per-thread (each lane holds ``num`` b32 register + slots = ``num*32`` bits). The dtype can be any width that evenly divides + 32 (bfloat16/float16 → ``num*2`` elements/lane; uint32 → ``num``). + * Per-lane SMEM address: lane k provides row ``k % 8`` of matrix + ``k // 8 if num>1 else 0``. The address is computed via + ``smem_buf.ptr_to([...])`` so any ``SwizzleLayout`` / ``ComposeLayout`` on + the SMEM buffer is honored automatically — exact equivalence with the + hand-written XOR-swizzled form in fp8_blockwise_gemm depends on this. + +Bottom-line invariants: + + * On the fp8 ``stmatrix.x4.trans`` callsite, the generated PTX is identical + to the hand-written form (same per-lane addr expression, same swizzle XOR + via the SMEM buffer's layout, same num/trans). + * Any input that the dispatcher cannot prove correctly maps to one of the + PTX-defined fragment shapes is rejected — falls through to scalar/vec + variants, never silently mis-emitted. 
+""" + +from __future__ import annotations + +import math +import re +from dataclasses import dataclass +from typing import Optional + +from tvm.runtime import DataType +from tvm.script import tirx as Tx +from tvm.tirx import Buffer, BufferRegion, IntImm, PrimFunc +from tvm.tirx.layout import Axis, ComposeLayout, TileLayout +from tvm.tirx.operator.tile_primitive import DispatchContext, fail, predicate, register_dispatch +from tvm.tirx.stmt import TilePrimitiveCall + +from ..exec_scope_utils import exec_scope_ok +from .utils import _scope_allowed + + +# ---------- helpers --------------------------------------------------------- + + +def _as_int(x): + """Best-effort coercion to a Python int; returns None for non-static.""" + if isinstance(x, int): + return x + if isinstance(x, IntImm): + return int(x.value) + if hasattr(x, "value") and isinstance(x.value, int): + return int(x.value) + return None + + +def _region_st_ext(region: BufferRegion): + return [r.min for r in region.region], [r.extent for r in region.region] + + +def _direction(op_call: TilePrimitiveCall) -> Optional[str]: + """Return ``'st'`` (local→shared), ``'ld'`` (shared→local), or None.""" + op_call = TilePrimitiveCall.downcast(op_call) + s = op_call.src.buffer.scope() + d = op_call.dst.buffer.scope() + if s == "local" and d.startswith("shared"): + return "st" + if s.startswith("shared") and d == "local": + return "ld" + return None + + +def _buffer_per_dim_stride(buf: Buffer) -> Optional[list[int]]: + """For each buffer dim, return the per-unit element stride. + + Walks ``buf.layout`` (TileLayout or ComposeLayout(Swizzle, TileLayout)), + groups its shard by the buffer shape, and takes the **smallest** shard + stride within each group — that's the stride incurred by incrementing + that buffer dim by 1. Returns None if any stride is non-static or the + layout isn't supported. 
+ + The SwizzleLayout XORs bits of the resulting linear offset; it does NOT + change which buffer dim has the bigger macro-stride. So using the inner + TileLayout's strides is sound for our row-vs-col identification. + """ + layout = buf.layout + if layout is None: + return None + tile_layout = layout + if isinstance(layout, ComposeLayout): + tile_layout = layout.tile_layout + if not isinstance(tile_layout, TileLayout): + return None + shard = getattr(tile_layout, "shard", None) + if not shard: + return None + try: + grouped, seps = tile_layout.group(list(buf.shape)) + except Exception: # noqa: BLE001 + return None + strides: list[int] = [] + for d in range(len(buf.shape)): + start, end = int(seps[d]), int(seps[d + 1]) + if end == start: + strides.append(0) + continue + group_strides = [] + for i in range(start, end): + s = _as_int(grouped.shard[i].stride) + if s is None: + return None + group_strides.append(s) + # Smallest stride in the group = stride incurred by buffer-dim+=1 + # (we walk the inner-most shard first). + strides.append(min(group_strides)) + return strides + + +# Arrangement constants: how the ``num`` 8×8 matrices are laid out in the +# 2D SMEM region (per-warp tile). 
+_HORIZONTAL = "horizontal" # 8 × (num*8): matrices side-by-side along col_dim +_VERTICAL = "vertical" # (num*8) × 8: matrices stacked along row_dim +_GRID_2X2 = "2x2" # 16 × 16: 4 matrices in a 2×2 grid (num=4 only) + + +def _has_full_laneid_iters(local_buf: Buffer) -> bool: + """Check the local layout's shard has laneid iters whose extents multiply + to 32 (full warp coverage).""" + if local_buf.layout is None: + return False + shard = getattr(local_buf.layout, "shard", None) + if not shard: + return False + product = 1 + for it in shard: + if it.axis.name == "laneid": + e = _as_int(it.extent) + if e is None: + return False + product *= e + return product == 32 + + +def _wg_distribution_from_layout( + local_buf: Buffer, smem_ext_i: list[int] +) -> Optional[tuple[int, int, list[int]]]: + """Read the warp distribution from the local fragment's layout. + + The local view must be wg-wide (its shape matches the SMEM region + extents) and carry a ``wid_in_wg`` iter in the shard. Returns + ``(wg_axis_dim, wg_step, per_warp_extents)``: + + * ``wg_axis_dim`` — the local/SMEM shape dim in which the ``wid_in_wg`` + iter lives (via ``layout.group(buf.shape)``). + * ``wg_step`` — the **shape-coord step** corresponding to ``warp_id += + 1`` along that dim. In TileLayout shard, iters within a dim are + ordered slowest-to-fastest in mixed-radix; the step is the product + of subsequent (faster) iters' extents in the same dim. + * ``per_warp_extents`` — ``smem_ext_i`` with ``wg_axis_dim`` reduced + from ``ext`` to ``ext // 4`` (the per-warp tile). + + Note: this is shape-coord units, not linear stride. The dispatcher adds + ``warp_id * wg_step`` to ``smem_idx[wg_axis_dim]`` directly. 
+ """ + if local_buf.layout is None: + return None + shard = getattr(local_buf.layout, "shard", None) + if not shard: + return None + wid_pos = None + wid_iter = None + for i, it in enumerate(shard): + if it.axis.name == "wid_in_wg": + if wid_iter is not None: + return None # multiple wid_in_wg iters → too complex for now + wid_pos = i + wid_iter = it + if wid_iter is None: + return None + if _as_int(wid_iter.extent) != 4: + return None + + # Find which local-buffer dim the wid iter belongs to. + try: + grouped, seps = local_buf.layout.group(list(local_buf.shape)) + except Exception: # noqa: BLE001 + return None + wid_local_dim = None + for d in range(len(local_buf.shape)): + if int(seps[d]) <= wid_pos < int(seps[d + 1]): + wid_local_dim = d + break + if wid_local_dim is None: + return None + + # Per-warp shape step = product of extents of subsequent iters in the + # same dim (faster-changing axes), all in the SAME shape-dim segment. + dim_end = int(seps[wid_local_dim + 1]) + wg_step = 1 + for i in range(wid_pos + 1, dim_end): + e = _as_int(grouped.shard[i].extent) + if e is None: + return None + wg_step *= e + + # Map local dim → SMEM dim by aligning non-unit dims one-to-one. The + # local view's shape must equal the SMEM region's non-unit extents. + local_shape_i = [_as_int(s) for s in local_buf.shape] + if None in local_shape_i: + return None + smem_non_unit = [(i, e) for i, e in enumerate(smem_ext_i) if e != 1] + if [e for _, e in smem_non_unit] != local_shape_i: + return None + wid_smem_dim = smem_non_unit[wid_local_dim][0] + + # 4 warps × wg_step must fit the SMEM dim's extent. + if smem_ext_i[wid_smem_dim] != wg_step * 4: + return None + + per_warp = list(smem_ext_i) + per_warp[wid_smem_dim] = wg_step + return wid_smem_dim, wg_step, per_warp + + +def _infer_arrangement( + smem_ext_i: list[int], smem_strides: list[int] +) -> Optional[tuple[int, int, int, str]]: + """Identify the m8n8.x{1,2,4} arrangement via the buffer's per-dim strides. 
+ + Among the slice's non-unit dims, the dim with the LARGER buffer stride is + the "row direction" (where matrix rows live), the SMALLER is the "col + direction". Returns ``(num, row_dim, col_dim, arrangement)`` where + ``arrangement`` is one of: + + * ``"horizontal"`` — row=8, col=num*8 ∈ {8,16,32}: matrices side-by-side + along col_dim. Lane k: row_dim += k%8; col_dim += matrix_id*8. + * ``"vertical"`` — row=num*8 ∈ {16,32}, col=8: matrices stacked along + row_dim. Lane k: row_dim += matrix_id*8 + k%8; col_dim += 0. + * ``"2x2"`` — row=16, col=16 (num=4 only): four matrices in a 2×2 grid. + Lane k: row_dim += (matrix_id//2)*8 + k%8; col_dim += (matrix_id%2)*8. + + Returns None if no pattern matches or if both stride and extent ties make + the row/col choice genuinely ambiguous (degenerate square with equal + strides — pathological, can be addressed if a use case appears). + """ + non_unit_idxs = [i for i, e in enumerate(smem_ext_i) if e != 1] + if len(non_unit_idxs) != 2: + return None + i0, i1 = non_unit_idxs + s0, s1 = smem_strides[i0], smem_strides[i1] + e0, e1 = smem_ext_i[i0], smem_ext_i[i1] + if s0 > s1: + row_dim, col_dim = i0, i1 + elif s1 > s0: + row_dim, col_dim = i1, i0 + elif e0 != e1: + # Strides tied but extents differ — the dim with the smaller extent + # is conventionally the "row" (8 rows per matrix in PTX m8n8). + if e0 < e1: + row_dim, col_dim = i0, i1 + else: + row_dim, col_dim = i1, i0 + else: + # Both strides AND extents equal — genuinely ambiguous (degenerate + # square). Caller can resolve by reshaping or by choosing a + # non-square slice. 
+ return None + e_row = smem_ext_i[row_dim] + e_col = smem_ext_i[col_dim] + + if e_row == 8 and e_col in (8, 16, 32): + return e_col // 8, row_dim, col_dim, _HORIZONTAL + if e_row in (16, 32) and e_col == 8: + return e_row // 8, row_dim, col_dim, _VERTICAL + if e_row == 16 and e_col == 16: + return 4, row_dim, col_dim, _GRID_2X2 + return None + + +@dataclass +class _Bound: + num: int + trans: bool + direction: str # "st" or "ld" + smem_region: BufferRegion + local_region: BufferRegion + row_dim: int # SMEM buffer dim with LARGER stride (or smaller extent on tie) + col_dim: int # SMEM buffer dim with SMALLER stride + arrangement: str # one of _HORIZONTAL / _VERTICAL / _GRID_2X2 + local_elements_per_b32: int + # Warpgroup-scope fields. ``wg`` is False for warp-scope binds. + wg: bool = False + # The SMEM dim along which the 4 warps walk (each warp adds + # ``warp_id * wg_step`` to this dim on top of the per-stamp offset). + wg_axis_dim: int = -1 + wg_step: int = 0 + + +def _try_bind(op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str): + """Validate and bind dispatcher state for **warp** scope. Returns + ``_Bound`` on success or a short error string on rejection.""" + if not sctx.is_warp: + return f"exec_scope is {sctx.scope_kind!r}, not 'warp'" + err = _check_full_active_set(sctx, is_wg=False) + if err is not None: + return err + return _bind_common(op_call, want_direction, is_wg=False) + + +def _try_bind_wg(op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str): + """Validate and bind for **warpgroup** scope. 
Returns ``_Bound`` (with + ``wg=True`` and warp-walk fields populated) or an error string.""" + if not sctx.is_warpgroup: + return f"exec_scope is {sctx.scope_kind!r}, not 'warpgroup'" + err = _check_full_active_set(sctx, is_wg=True) + if err is not None: + return err + return _bind_common(op_call, want_direction, is_wg=True) + + +def _check_full_active_set(sctx: DispatchContext, *, is_wg: bool) -> Optional[str]: + """Verify the active thread set is the FULL warp/warpgroup. + + PTX ldmatrix/stmatrix requires every lane of the participating warp to be + active (32-lane sync). If an enclosing ``if Tx.filter(...)`` narrowed the + active set, ``sctx.intra`` reports the reduced extent — we reject those + cases here. + + For warp scope: laneid must be (32, 0). + For warpgroup scope: laneid (32, 0) AND wid_in_wg (4, 0). + """ + required = {"laneid": 32} + if is_wg: + required["wid_in_wg"] = 4 + for axis_name, expected in required.items(): + if axis_name not in sctx.intra: + return f"sctx.intra missing {axis_name!r} (scope_kind={sctx.scope_kind!r})" + extent_raw, offset_raw = sctx.intra[axis_name] + extent = _as_int(extent_raw) + offset = _as_int(offset_raw) + if extent is None or offset is None: + return f"non-static active range for {axis_name}: ({extent_raw}, {offset_raw})" + if extent != expected or offset != 0: + return ( + f"active {axis_name} range is [{offset}, {offset + extent}); " + f"ldmatrix/stmatrix needs the full [0, {expected}) — an enclosing " + "if/filter has narrowed the warp" + ) + return None + + +def _bind_common(op_call: TilePrimitiveCall, want_direction: str, *, is_wg: bool): + direction = _direction(op_call) + if direction != want_direction: + return f"direction {direction} != {want_direction}" + + op_call = TilePrimitiveCall.downcast(op_call) + smem_region = op_call.dst if direction == "st" else op_call.src + local_region = op_call.src if direction == "st" else op_call.dst + + smem_buf: Buffer = smem_region.buffer + local_buf: Buffer = 
local_region.buffer + + # B1: SMEM dtype 16-bit (PTX .b16). Local dtype any width that divides 32. + smem_bits = DataType(smem_buf.dtype).bits + if smem_bits != 16: + return f"SMEM dtype must be 16-bit (b16), got {smem_buf.dtype}" + local_bits = DataType(local_buf.dtype).bits + if 32 % local_bits != 0: + return f"local dtype bits {local_bits} must evenly divide 32 (b32 reg unit)" + elements_per_b32 = 32 // local_bits + + # B2: SMEM region extents + buffer strides. + _, smem_ext = _region_st_ext(smem_region) + smem_ext_i = [_as_int(e) for e in smem_ext] + if None in smem_ext_i: + return f"SMEM extents must be compile-time integers, got {smem_ext}" + smem_strides = _buffer_per_dim_stride(smem_buf) + if smem_strides is None: + return f"could not determine static per-dim strides from SMEM layout {smem_buf.layout}" + + if is_wg: + # WG: local is a wg-wide view; layout carries laneid (extent 32 full) + # AND a wid_in_wg iter. Per-warp SMEM dim/step come from the + # wid_in_wg iter's position in the shard. + wg_info = _wg_distribution_from_layout(local_buf, smem_ext_i) + if wg_info is None: + return ( + f"warpgroup local fragment must be a wg-wide view (shape matching " + f"SMEM region) with a wid_in_wg iter in its layout shard; " + f"got shape={list(local_buf.shape)} layout={local_buf.layout}" + ) + wg_axis_dim, wg_step, per_warp_ext = wg_info + inferred = _infer_arrangement(per_warp_ext, smem_strides) + else: + # Warp: local is a warp-wide view; layout carries laneid iters whose + # extents multiply to 32 (full warp). Whole SMEM region is the per-warp + # tile (no warp-walk). 
+ if not _has_full_laneid_iters(local_buf): + return ( + f"warp local fragment must be a warp-wide view (shape matching " + f"SMEM region) with laneid iters totaling extent 32 in its " + f"layout shard; got shape={list(local_buf.shape)} layout={local_buf.layout}" + ) + per_warp_ext = smem_ext_i + wg_axis_dim = -1 + wg_step = 0 + inferred = _infer_arrangement(smem_ext_i, smem_strides) + + if inferred is None: + return ( + f"per-warp tile {per_warp_ext} (strides {smem_strides}) doesn't match any " + "m8n8.x{1,2,4} arrangement" + ) + num, row_dim, col_dim, arrangement = inferred + + # B3: local fragment is a warp- or wg-wide VIEW. Its logical extents must + # equal the SMEM region extents (matching ``Tx.copy`` semantics — both + # sides describe the same region size). + _, local_ext = _region_st_ext(local_region) + local_ext_i = [_as_int(e) for e in local_ext] + if None in local_ext_i: + return f"local extents must be compile-time integers, got {local_ext}" + smem_non_unit = sorted([e for e in smem_ext_i if e != 1]) + local_non_unit = sorted([e for e in local_ext_i if e != 1]) + if smem_non_unit != local_non_unit: + return ( + f"local region {local_ext_i} non-unit extents must match SMEM " + f"region {smem_ext_i} (got {local_non_unit} vs {smem_non_unit})" + ) + + cfg = op_call.config or {} + trans = bool(cfg.get("trans", False)) + + return _Bound( + num=num, + trans=trans, + direction=direction, + smem_region=smem_region, + local_region=local_region, + row_dim=row_dim, + col_dim=col_dim, + arrangement=arrangement, + local_elements_per_b32=elements_per_b32, + wg=is_wg, + wg_axis_dim=wg_axis_dim, + wg_step=wg_step, + ) + + +def _sm_version(sctx: DispatchContext) -> int: + arch = getattr(sctx.target, "arch", "") or "" + m = re.match(r"sm_(\d+)", arch) + return int(m.group(1)) if m else 0 + + +def _make_predicate(want_direction: str, min_sm: int, *, wg: bool = False): + bind = _try_bind_wg if wg else _try_bind + + def _pred(op_call, sctx): + res = bind(op_call, sctx, 
want_direction) + if isinstance(res, str): + return False, res + sm = _sm_version(sctx) + if sm < min_sm: + name = "stmatrix" if want_direction == "st" else "ldmatrix" + return False, f"{name} requires sm_{min_sm}+, got sm_{sm}" + return True, None + return _pred + + +# ---------- impl ------------------------------------------------------------ + + +def _impl(op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str) -> PrimFunc: + return _emit(op_call, sctx, want_direction, is_wg=False) + + +def _impl_wg(op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str) -> PrimFunc: + return _emit(op_call, sctx, want_direction, is_wg=True) + + +def _emit( + op_call: TilePrimitiveCall, sctx: DispatchContext, want_direction: str, *, is_wg: bool +) -> PrimFunc: + res = (_try_bind_wg if is_wg else _try_bind)(op_call, sctx, want_direction) + if isinstance(res, str): + fail(res) + b: _Bound = res + + smem_buf = b.smem_region.buffer + local_buf = b.local_region.buffer + smem_st, _ = _region_st_ext(b.smem_region) + local_st, _ = _region_st_ext(b.local_region) + + tid_x = sctx.launch_params["threadIdx.x"] + num = b.num + trans = b.trans + row_dim = b.row_dim + col_dim = b.col_dim + arrangement = b.arrangement + wg_axis_dim = b.wg_axis_dim + wg_step = b.wg_step + + # Python-level closures: build PrimExpr index lists at parse time. Index + # mutation must live outside the prim_func body — TVM Script treats + # ``list[i] = ...`` inside a func as a BufferStore. + # + # Per-arrangement lane → SMEM (row_dim, col_dim) offsets on the + # PER-WARP tile: + # horizontal (8 × num*8): row += k%8; col += matrix_id*8. + # vertical (num*8 × 8): row += matrix_id*8 + k%8; col += 0. + # 2x2 (16 × 16): row += (matrix_id//2)*8 + k%8; + # col += (matrix_id%2)*8. + # + # At warpgroup scope (``is_wg=True``), an additional warp-walk offset + # ``warp_id * wg_step`` is layered onto ``wg_axis_dim`` (the dim along + # which the 4 warps line up their per-warp tiles). 
+ def _make_smem_idx(row_in_matrix, matrix_id, warp_id): + idx = list(smem_st) + if arrangement == _HORIZONTAL: + idx[row_dim] = smem_st[row_dim] + row_in_matrix + idx[col_dim] = smem_st[col_dim] + matrix_id * 8 + elif arrangement == _VERTICAL: + idx[row_dim] = smem_st[row_dim] + matrix_id * 8 + row_in_matrix + else: # _GRID_2X2 (num=4 only) + idx[row_dim] = smem_st[row_dim] + (matrix_id // 2) * 8 + row_in_matrix + idx[col_dim] = smem_st[col_dim] + (matrix_id % 2) * 8 + if is_wg: + idx[wg_axis_dim] = idx[wg_axis_dim] + warp_id * wg_step + return idx + + elements_per_b32 = b.local_elements_per_b32 + + def _make_ld_handles(local_per_thread): + return tuple( + local_per_thread.ptr_to([r * elements_per_b32]) for r in range(num) + ) + + if b.direction == "st": + if is_wg: + # fmt: off + @Tx.prim_func(check_well_formed=False) + def impl(): + with Tx.warp(): + warp_id = Tx.meta_var((tid_x // 32) % 4) + lane_id = Tx.meta_var(tid_x % 32) + row_in_matrix = Tx.meta_var(lane_id % 8) + matrix_id = Tx.meta_var(lane_id // 8 if num > 1 else 0) + local_per_thread = local_buf.local() + Tx.ptx.stmatrix( + smem_buf.ptr_to(_make_smem_idx(row_in_matrix, matrix_id, warp_id)), + local_per_thread.ptr_to([0]), + num=num, + trans=trans, + ) + # fmt: on + return impl + + # fmt: off + @Tx.prim_func(check_well_formed=False) + def impl(): + lane_id = Tx.meta_var(tid_x % 32) + row_in_matrix = Tx.meta_var(lane_id % 8) + matrix_id = Tx.meta_var(lane_id // 8 if num > 1 else 0) + local_per_thread = local_buf.local() + Tx.ptx.stmatrix( + smem_buf.ptr_to(_make_smem_idx(row_in_matrix, matrix_id, 0)), + local_per_thread.ptr_to([0]), + num=num, + trans=trans, + ) + # fmt: on + return impl + + if is_wg: + # fmt: off + @Tx.prim_func(check_well_formed=False) + def impl(): + with Tx.warp(): + warp_id = Tx.meta_var((tid_x // 32) % 4) + lane_id = Tx.meta_var(tid_x % 32) + row_in_matrix = Tx.meta_var(lane_id % 8) + matrix_id = Tx.meta_var(lane_id // 8 if num > 1 else 0) + local_per_thread = local_buf.local() + 
Tx.ptx.ldmatrix( + trans, num, "b16", + smem_buf.ptr_to(_make_smem_idx(row_in_matrix, matrix_id, warp_id)), + *_make_ld_handles(local_per_thread), + ) + # fmt: on + return impl + + # fmt: off + @Tx.prim_func(check_well_formed=False) + def impl(): + lane_id = Tx.meta_var(tid_x % 32) + row_in_matrix = Tx.meta_var(lane_id % 8) + matrix_id = Tx.meta_var(lane_id // 8 if num > 1 else 0) + local_per_thread = local_buf.local() + Tx.ptx.ldmatrix( + trans, num, "b16", + smem_buf.ptr_to(_make_smem_idx(row_in_matrix, matrix_id, 0)), + *_make_ld_handles(local_per_thread), + ) + # fmt: on + return impl + + +# ---------- registration ---------------------------------------------------- + + +_STMATRIX_PAIRS = [("local", "shared*"), ("local", "shared::cta")] +_LDMATRIX_PAIRS = [("shared*", "local"), ("shared::cta", "local")] + + +@register_dispatch( + "copy", + "cuda", + variant="warp_stmatrix", + priority=25, + when=[ + predicate("storage_scope", _scope_allowed, allowed_pairs=_STMATRIX_PAIRS), + predicate("exec_scope", exec_scope_ok, expected_scopes=["warp"]), + predicate("stmatrix_compat", _make_predicate("st", min_sm=90)), + ], +) +def copy_schedule_warp_stmatrix(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc: + return _impl(op_call, sctx, want_direction="st") + + +@register_dispatch( + "copy", + "cuda", + variant="warp_ldmatrix", + priority=25, + when=[ + predicate("storage_scope", _scope_allowed, allowed_pairs=_LDMATRIX_PAIRS), + predicate("exec_scope", exec_scope_ok, expected_scopes=["warp"]), + predicate("ldmatrix_compat", _make_predicate("ld", min_sm=75)), + ], +) +def copy_schedule_warp_ldmatrix(op_call: TilePrimitiveCall, sctx: DispatchContext) -> PrimFunc: + return _impl(op_call, sctx, want_direction="ld") + + +@register_dispatch( + "copy", + "cuda", + variant="warpgroup_stmatrix", + priority=25, + when=[ + predicate("storage_scope", _scope_allowed, allowed_pairs=_STMATRIX_PAIRS), + predicate("exec_scope", exec_scope_ok, expected_scopes=["warpgroup"]), + 
predicate("wg_stmatrix_compat", _make_predicate("st", min_sm=90, wg=True)), + ], +) +def copy_schedule_warpgroup_stmatrix( + op_call: TilePrimitiveCall, sctx: DispatchContext +) -> PrimFunc: + return _impl_wg(op_call, sctx, want_direction="st") + + +@register_dispatch( + "copy", + "cuda", + variant="warpgroup_ldmatrix", + priority=25, + when=[ + predicate("storage_scope", _scope_allowed, allowed_pairs=_LDMATRIX_PAIRS), + predicate("exec_scope", exec_scope_ok, expected_scopes=["warpgroup"]), + predicate("wg_ldmatrix_compat", _make_predicate("ld", min_sm=75, wg=True)), + ], +) +def copy_schedule_warpgroup_ldmatrix( + op_call: TilePrimitiveCall, sctx: DispatchContext +) -> PrimFunc: + return _impl_wg(op_call, sctx, want_direction="ld") + + +__all__ = [ + "copy_schedule_warp_ldmatrix", + "copy_schedule_warp_stmatrix", + "copy_schedule_warpgroup_ldmatrix", + "copy_schedule_warpgroup_stmatrix", +] diff --git a/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py b/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py new file mode 100644 index 0000000000..1d0816bf9e --- /dev/null +++ b/tests/python/tirx/operator/tile_primitive/cuda/test_ldstmatrix.py @@ -0,0 +1,564 @@ +# Licensed to the Apache Software Foundation (ASF) under one +# or more contributor license agreements. See the NOTICE file +# distributed with this work for additional information +# regarding copyright ownership. The ASF licenses this file +# to you under the Apache License, Version 2.0 (the +# "License"); you may not use this file except in compliance +# with the License. You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, +# software distributed under the License is distributed on an +# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +# KIND, either express or implied. 
See the License for the +# specific language governing permissions and limitations +# under the License. +# pylint: disable=missing-function-docstring + +"""Tests for the warp/warpgroup ldmatrix/stmatrix dispatcher under +``Tx.copy``. Covers x4 in horizontal / vertical / 2×2 arrangements at warp +and warpgroup scope. Each Tx.copy must have matching LHS/RHS region +extents — the local fragment is a warp- or wg-wide VIEW with thread +distribution encoded in its TileLayout (laneid / wid_in_wg iters). +x1/x2 variants are TODO (their lane→matrix mapping is more involved). +""" + +from __future__ import annotations + +import re + +import pytest + +import tvm +from tvm.script import tirx as Tx +from tvm.tirx.layout import Axis, Iter, S, TileLayout +from tvm.tirx.operator.tile_primitive.cuda.tma_utils import SwizzleMode, mma_shared_layout + + +# --------------------------------------------------------------------------- +# Layout helpers: pure-m → tile laneid (→ tile wid_in_wg) +# --------------------------------------------------------------------------- + + +def _x4_h_warp_layout(): + """x4 horizontal warp-wide view (8, 32): + row = laneid % 8; col = (laneid // 8) * 8 + per-thread 0..7""" + return TileLayout(S[(8, 4, 8) : ( + 1 @ Axis.laneid, + 8 @ Axis.laneid, + 1, + )]) + + +def _x4_h_wg_layout(): + """x4 horizontal wg-wide view (8, 128): + row = laneid % 8; col = wid_in_wg*32 + (laneid//8)*8 + per-thread 0..7""" + return TileLayout(S[(8, 4, 4, 8) : ( + 1 @ Axis.laneid, + 1 @ Axis.wid_in_wg, + 8 @ Axis.laneid, + 1, + )]) + + +def _x4_v_warp_layout(): + """x4 vertical warp-wide view (32, 8): row = laneid; col = per-thread 0..7""" + return TileLayout(S[(32, 8) : (1 @ Axis.laneid, 1)]) + + +def _x4_2x2_warp_layout(): + """x4 2×2 grid warp-wide view (16, 16): + row = (laneid//16)*8 + laneid%8; col = ((laneid//8)%2)*8 + per-thread 0..7""" + return TileLayout.from_iters([ + Iter(8, 1, Axis.laneid), # lane_low → row stride 1 + Iter(2, 16, Axis.laneid), # row-block (lane bit 
4) → row stride 8 + Iter(2, 8, Axis.laneid), # col-block (lane bit 3) → col stride 8 + Iter(8, 1, Axis.m), # per-thread → col stride 1 + ]) + + +_SM100A = tvm.target.Target({"kind": "cuda", "arch": "sm_100a"}) + + +def _compile_get_cuda(prim_func) -> str: + with _SM100A: + mod = tvm.compile( + tvm.IRModule({"main": prim_func}), target=_SM100A, tir_pipeline="tirx" + ) + return mod.mod.imports[0].inspect_source() + + +# --------------------------------------------------------------------------- +# warp scope x4: stmatrix / ldmatrix × {non-trans, trans} +# --------------------------------------------------------------------------- + + +@pytest.mark.parametrize("trans", [False, True]) +def test_warp_stmatrix_x4(trans): + layout = _x4_h_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 32), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 32) : (32, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(8, 32, layout=layout) + Tx.copy(D[0:8, 0:32], regs_warp[0:8, 0:32], trans=trans) + + src = _compile_get_cuda(f) + expected = f"stmatrix.sync.aligned.m8n8.x4{'.trans' if trans else ''}.shared.b16" + assert expected in src + assert "& 7" in src and ">> 3" in src + + +@pytest.mark.parametrize("trans", [False, True]) +def test_warp_ldmatrix_x4(trans): + layout = _x4_h_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 32), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 32) : (32, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(8, 32, layout=layout) + Tx.copy(regs_warp[0:8, 0:32], D[0:8, 0:32], trans=trans) + + src = _compile_get_cuda(f) + expected = f"ldmatrix.sync.aligned.m8n8.x4{'.trans' if trans else ''}.shared.b16" + assert expected in src + assert "& 7" in src and ">> 
3" in src + + +# --------------------------------------------------------------------------- +# Swizzle: 128B SwizzleLayout XOR honored +# --------------------------------------------------------------------------- + + +def test_warp_stmatrix_swizzle_128b(): + shape = (1, 8, 128) + sw_layout = mma_shared_layout("bfloat16", SwizzleMode.SWIZZLE_128B_ATOM, shape) + layout = _x4_h_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer(shape, "bfloat16", scope="shared", layout=sw_layout) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(8, 32, layout=layout) + Tx.copy(D[0, 0:8, 0:32], regs_warp[0:8, 0:32], trans=True) + + src = _compile_get_cuda(f) + assert "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16" in src + assert bool(re.search(r"\^.*threadIdx|threadIdx.*\^", src)) + + +def test_warp_ldmatrix_swizzle_128b(): + shape = (1, 8, 128) + sw_layout = mma_shared_layout("bfloat16", SwizzleMode.SWIZZLE_128B_ATOM, shape) + layout = _x4_h_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer(shape, "bfloat16", scope="shared", layout=sw_layout) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(8, 32, layout=layout) + Tx.copy(regs_warp[0:8, 0:32], D[0, 0:8, 0:32], trans=False) + + src = _compile_get_cuda(f) + assert "ldmatrix.sync.aligned.m8n8.x4.shared.b16" in src + assert bool(re.search(r"\^.*threadIdx|threadIdx.*\^", src)) + + +# --------------------------------------------------------------------------- +# Permutation invariance: rebuilding the SMEM with permuted shape/strides +# gives identical per-lane addresses (3 arrangements × byte-equal check) +# --------------------------------------------------------------------------- + + +def _stmatrix_line(src): + for line in src.split("\n"): + if 
"ptx_stmatrix_m8n8" in line and "D_ptr[" in line: + return line.strip() + return None + + +def _assert_permute_same_addr(f_ref, f_perm, expected_inst): + src_ref = _compile_get_cuda(f_ref) + src_perm = _compile_get_cuda(f_perm) + assert expected_inst in src_ref + assert expected_inst in src_perm + a_ref = _stmatrix_line(src_ref) + a_perm = _stmatrix_line(src_perm) + assert a_ref is not None and a_perm is not None + assert a_ref == a_perm, f"\n ref: {a_ref}\n perm: {a_perm}" + + +def test_permutation_invariance_horizontal(): + layout = _x4_h_warp_layout() + + @Tx.prim_func + def f_ref(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (2, 8, 32), "bfloat16", scope="shared", + layout=TileLayout(S[(2, 8, 32) : (256, 32, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(8, 32, layout=layout) + Tx.copy(D[0, 0:8, 0:32], regs_warp[0:8, 0:32], trans=True) + + @Tx.prim_func + def f_perm(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (2, 32, 8), "bfloat16", scope="shared", + layout=TileLayout(S[(2, 32, 8) : (256, 1, 32)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(8, 32, layout=layout) + Tx.copy(D[0, 0:32, 0:8], regs_warp[0:8, 0:32], trans=True) + + _assert_permute_same_addr( + f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16" + ) + + +def test_permutation_invariance_vertical(): + layout = _x4_v_warp_layout() + + @Tx.prim_func + def f_ref(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (32, 8), "bfloat16", scope="shared", + layout=TileLayout(S[(32, 8) : (8, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(32, 8, layout=layout) + Tx.copy(D[0:32, 0:8], regs_warp[0:32, 0:8], trans=False) + + @Tx.prim_func + def f_perm(): + 
with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 32), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 32) : (1, 8)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(32, 8, layout=layout) + Tx.copy(D[0:8, 0:32], regs_warp[0:32, 0:8], trans=False) + + _assert_permute_same_addr( + f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.shared.b16" + ) + + +def test_permutation_invariance_2x2(): + layout = _x4_2x2_warp_layout() + + @Tx.prim_func + def f_ref(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (16, 16), "bfloat16", scope="shared", + layout=TileLayout(S[(16, 16) : (16, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(16, 16, layout=layout) + Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False) + + @Tx.prim_func + def f_perm(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (16, 16), "bfloat16", scope="shared", + layout=TileLayout(S[(16, 16) : (1, 16)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(16, 16, layout=layout) + Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False) + + _assert_permute_same_addr( + f_ref, f_perm, "stmatrix.sync.aligned.m8n8.x4.shared.b16" + ) + + +# --------------------------------------------------------------------------- +# Arrangement coverage (vertical / 2×2 dispatch reaches PTX emit) +# --------------------------------------------------------------------------- + + +def test_warp_vertical(): + layout = _x4_v_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (32, 8), "bfloat16", scope="shared", + layout=TileLayout(S[(32, 8) : (8, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", 
scope="local") + regs_warp = regs.view(32, 8, layout=layout) + Tx.copy(D[0:32, 0:8], regs_warp[0:32, 0:8], trans=False) + + src = _compile_get_cuda(f) + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in src + assert "threadIdx.x" in src + + +def test_warp_2x2(): + layout = _x4_2x2_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (16, 16), "bfloat16", scope="shared", + layout=TileLayout(S[(16, 16) : (16, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(16, 16, layout=layout) + Tx.copy(D[0:16, 0:16], regs_warp[0:16, 0:16], trans=False) + + src = _compile_get_cuda(f) + assert "stmatrix.sync.aligned.m8n8.x4.shared.b16" in src + assert "threadIdx.x" in src + + +# --------------------------------------------------------------------------- +# Warpgroup-scope x4 +# --------------------------------------------------------------------------- + + +def test_wg_stmatrix_x4_trans(): + wg_layout = _x4_h_wg_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([128]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 128), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 128) : (128, 1)]), + ) + with Tx.warpgroup(): + regs = Tx.alloc_buffer((4,), "uint32", scope="local") + regs_wg = regs.view("bfloat16").view(8, 128, layout=wg_layout) + Tx.copy(D[0:8, 0:128], regs_wg[0:8, 0:128], trans=True) + + src = _compile_get_cuda(f) + assert "stmatrix.sync.aligned.m8n8.x4.trans.shared.b16" in src + + +# --------------------------------------------------------------------------- +# Rejection cases +# --------------------------------------------------------------------------- + + +def test_reject_extent_mismatch(): + """Local region extents don't match SMEM region — Tx.copy semantically + invalid, dispatcher rejects.""" + layout = _x4_h_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + 
Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 32), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 32) : (32, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + # Raw per-thread fragment, no warp-wide view. + Tx.copy(D[0:8, 0:32], regs[0:8], trans=True) + + with pytest.raises(Exception) as excinfo: + _compile_get_cuda(f) + assert "warp_stmatrix" in str(excinfo.value) + + +def test_reject_non_b16_smem(): + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 32), "float32", scope="shared", + layout=TileLayout(S[(8, 32) : (32, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((8,), "float32", scope="local") + Tx.copy(D[0:8, 0:32], regs[0:8], trans=True) + + with pytest.raises(Exception) as excinfo: + _compile_get_cuda(f) + s = str(excinfo.value) + assert "warp_stmatrix" in s and "b16" in s + + +def test_reject_wrong_smem_shape(): + """8×40 doesn't decompose into any m8n8.x{1,2,4} arrangement.""" + layout = _x4_h_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 40), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 40) : (40, 1)]), + ) + with Tx.warp(): + regs = Tx.alloc_buffer((10,), "bfloat16", scope="local") + regs_warp = regs.view(8, 40, layout=TileLayout(S[(8, 40) : (40, 1)])) + Tx.copy(D[0:8, 0:40], regs_warp[0:8, 0:40], trans=True) + + with pytest.raises(Exception) as excinfo: + _compile_get_cuda(f) + assert "warp_stmatrix" in str(excinfo.value) + + +def test_reject_warp_filtered_lanes(): + """``if Tx.filter(lane, 0, 16)`` narrows the active set — stmatrix + requires all 32 lanes.""" + layout = _x4_h_warp_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 32), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 32) 
: (32, 1)]), + ) + with Tx.warp(): + lane_id = Tx.lane_id([32]) + if Tx.filter(lane_id, 0, 16): + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + regs_warp = regs.view(8, 32, layout=layout) + Tx.copy(D[0:8, 0:32], regs_warp[0:8, 0:32], trans=True) + + with pytest.raises(Exception) as excinfo: + _compile_get_cuda(f) + s = str(excinfo.value) + assert "warp_stmatrix" in s and "laneid" in s and "narrow" in s + + +def test_reject_wg_filtered_warps(): + """``if Tx.filter(warp_id, 0, 2)`` at wg scope narrows to 2 warps — + stmatrix wg dispatcher needs all 4.""" + wg_layout = _x4_h_wg_layout() + + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([128]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 128), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 128) : (128, 1)]), + ) + with Tx.warpgroup(): + warp_id = Tx.warp_id_in_wg([4]) + if Tx.filter(warp_id, 0, 2): + regs = Tx.alloc_buffer((4,), "uint32", scope="local") + regs_wg = regs.view("bfloat16").view(8, 128, layout=wg_layout) + Tx.copy(D[0:8, 0:128], regs_wg[0:8, 0:128], trans=True) + + with pytest.raises(Exception) as excinfo: + _compile_get_cuda(f) + s = str(excinfo.value) + assert "warpgroup_stmatrix" in s and "wid_in_wg" in s and "narrow" in s + + +def test_reject_non_warp_scope(): + """Tx.copy at cta scope (no warp/wg wrap) — warp_stmatrix dispatcher must + not fire. The dispatch error log must list warp_stmatrix as rejected.""" + @Tx.prim_func + def f(): + with Tx.kernel(): + Tx.cta_id([1]) + Tx.thread_id([32]) + with Tx.cta(): + D = Tx.alloc_buffer( + (8, 32), "bfloat16", scope="shared", + layout=TileLayout(S[(8, 32) : (32, 1)]), + ) + regs = Tx.alloc_buffer((8,), "bfloat16", scope="local") + Tx.copy(D[0:8, 0:32], regs[0:8], trans=True) + + try: + src = _compile_get_cuda(f) + assert "stmatrix" not in src + except Exception as e: + s = str(e) + assert "warp_stmatrix" in s
