This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 16d0a7edae [TIRX][CUDA] Framework support for FA4, CLC intrinsics, and
nvfp4 tcgen05 GEMM (#19785)
16d0a7edae is described below
commit 16d0a7edae0d52ccf0b947656310359d965e79d9
Author: Bohan Hou <[email protected]>
AuthorDate: Tue Jun 16 03:52:14 2026 -0700
[TIRX][CUDA] Framework support for FA4, CLC intrinsics, and nvfp4 tcgen05
GEMM (#19785)
---
python/tvm/backend/cuda/lang/pipeline.py | 11 +-
python/tvm/backend/cuda/lang/tile_scheduler.py | 135 ++++++++++++++++-
python/tvm/backend/cuda/op.py | 79 +++++++++-
.../tvm/backend/cuda/operator/intrinsics/sync.py | 100 +++++++++++-
.../tile_primitive/copy_async/tcgen05_ldst.py | 35 +++--
.../operator/tile_primitive/elementwise/reg.py | 67 ++++++++
python/tvm/backend/cuda/script.py | 6 +
python/tvm/support/nvcc.py | 76 ++++++++--
python/tvm/tirx/script/builder/external_kernel.py | 2 +-
src/backend/cuda/op/target_builtin.cc | 6 +
src/target/llvm/codegen_llvm.cc | 17 +++
src/target/llvm/codegen_llvm.h | 3 +
src/tirx/ir/layout/tile_slice.cc | 6 +-
tests/python/codegen/test_target_codegen_llvm.py | 39 +++++
tests/python/tirx/codegen/test_codegen_cuda.py | 11 ++
tests/python/tirx/codegen/test_codegen_nvshmem.py | 3 +
tests/python/tirx/codegen/test_cuda_copy.py | 11 ++
tests/python/tirx/codegen/test_cuda_cta_reduce.py | 13 ++
tests/python/tirx/codegen/test_cuda_warp_reduce.py | 13 ++
tests/python/tirx/conftest.py | 40 +++++
.../tile_primitive/cuda/copy/test_fallback.py | 5 +
.../tile_primitive/cuda/copy/test_gmem_smem.py | 4 +
.../operator/tile_primitive/cuda/copy/test_reg.py | 5 +
.../tile_primitive/cuda/copy_async/test_ldgsts.py | 3 +
.../tile_primitive/cuda/copy_async/test_tmem.py | 7 +
.../cuda/copy_async/test_tmem_16xnb.py | 144 ++++++++++++++++++
.../tile_primitive/cuda/elementwise/test_binary.py | 13 ++
.../tile_primitive/cuda/elementwise/test_fma.py | 15 ++
.../tile_primitive/cuda/elementwise/test_unary.py | 168 ++++++++++++++++++++-
.../cuda/gemm_async/test_gemm_async.py | 23 +++
.../cuda/permute_layout/test_permute_layout.py | 7 +
.../cuda/reduction/test_reduction.py | 23 +++
tests/python/tirx/test_buffer_print.py | 4 +
tests/python/tirx/test_control_flow.py | 8 +
tests/python/tirx/test_layout.py | 35 +++++
tests/scripts/task_python_unittest.sh | 1 +
36 files changed, 1096 insertions(+), 42 deletions(-)
diff --git a/python/tvm/backend/cuda/lang/pipeline.py
b/python/tvm/backend/cuda/lang/pipeline.py
index ee86090398..40fd40c3fa 100644
--- a/python/tvm/backend/cuda/lang/pipeline.py
+++ b/python/tvm/backend/cuda/lang/pipeline.py
@@ -110,7 +110,7 @@ class MBarrier:
T.ptx.mbarrier.try_wait(self.buf.ptr_to([stage]), phase ^
self.phase_offset)
@T.inline
- def arrive(self, stage, cta_id=None, pred=None):
+ def arrive(self, stage, cta_id=None, pred=None, count=None):
# Default: local-CTA arrive — emits the simple
# ``mbarrier.arrive.shared.b64`` form. To arrive on a remote
# CTA's mbarrier in a cluster kernel, callers must pass
@@ -119,11 +119,18 @@ class MBarrier:
# the cross-CTA path was both surprising (``bar.arrive(stage)``
# silently ``mapa`` ed across the cluster) and a per-call cost
# of ~3 PTX ops on every single-CTA kernel.
+ #
+ # ``count`` (cross-CTA path only) emits the explicit arrival-count
+ # operand, i.e. ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``.
+ # When ``None`` the implicit count-of-1 form is emitted. Passing
+ # ``count=1`` is semantically identical but spells the count explicitly.
if cta_id is None:
T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]))
else:
actual_pred = True if pred is None else pred
- T.ptx.mbarrier.arrive(self.buf.ptr_to([stage]), cta_id=cta_id,
pred=actual_pred)
+ T.ptx.mbarrier.arrive(
+ self.buf.ptr_to([stage]), cta_id=cta_id, pred=actual_pred,
count=count
+ )
def ptr_to(self, idx):
return self.buf.ptr_to(idx)
diff --git a/python/tvm/backend/cuda/lang/tile_scheduler.py
b/python/tvm/backend/cuda/lang/tile_scheduler.py
index 3fd27f25ee..c6154f2462 100644
--- a/python/tvm/backend/cuda/lang/tile_scheduler.py
+++ b/python/tvm/backend/cuda/lang/tile_scheduler.py
@@ -20,6 +20,7 @@ These classes emit TIR via @T.inline. Decorate with
@T.meta_class so that
instances are automatically treated as meta values inside @T.prim_func.
"""
+from tvm.backend.cuda.lang.pipeline import Pipeline, PipelineState
from tvm.script import tirx as T
@@ -753,13 +754,20 @@ class FlashAttentionLPTScheduler(BaseTileScheduler):
"""
def __init__(
- self, prefix: str, num_batches: int, num_heads: int, num_m_blocks:
int, l2_swizzle: int
+ self,
+ prefix: str,
+ num_batches: int,
+ num_heads: int,
+ num_m_blocks: int,
+ l2_swizzle: int,
+ num_ctas: int | None = None,
):
super().__init__(prefix)
self._num_batches = num_batches
self._num_heads = num_heads
self._num_m_blocks = num_m_blocks
self._l2_swizzle = l2_swizzle
+ self._num_ctas = num_ctas
self._total_tasks = num_batches * num_heads * num_m_blocks
# Derived constants for L2 swizzle
@@ -807,10 +815,131 @@ class FlashAttentionLPTScheduler(BaseTileScheduler):
@T.inline
def next_tile(self):
- """Advance to next tile by striding by num_ctas."""
- self.linear_idx = self._total_tasks
+ """Advance to the next tile.
+
+ Single-tile mode (``num_ctas=None``, the default): each CTA owns one
+ task; terminate. Persistent mode (``num_ctas=N``): stride by N, like
+ :class:`FlashAttentionLinearScheduler`, while keeping the LPT + L2
+ swizzle index mapping.
+ """
+ if self._num_ctas is None:
+ self.linear_idx = self._total_tasks
+ else:
+ self.linear_idx = self.linear_idx + self._num_ctas
+ self.update_current_m_n_idx(self.linear_idx)
# fmt: on
def valid(self):
"""Check if there are more tiles to process."""
return self.linear_idx < self._total_tasks
+
+
+class _CLCWorker(ClusterPersistentScheduler2D):
+ """Per-role CLC handle: IS-A ClusterPersistentScheduler2D (so m_idx /
n_idx work as
+ usual) plus the role-local barrier phase and handshake. A coord-free role
(e.g. an
+ MMA warp consuming whatever a loader staged) arms the loop with reset()
not init().
+ """
+
+ def __init__(self, clc, prefix):
+ super().__init__(
+ prefix,
+ num_m_tiles=clc._num_m_tiles,
+ num_n_tiles=clc._num_n_tiles,
+ num_clusters=clc._num_m_tiles * clc._num_n_tiles,
+ l2_group_size=clc._l2_group_size,
+ )
+ self._clc = clc
+ self._sa = PipelineState(1, 0)
+ self._done = T.local_scalar("int32")
+ self._nxt = T.local_scalar("uint32")
+
+ @T.inline
+ def reset(self):
+ self._done = 0
+
+ @T.inline
+ def init(self, cluster_id):
+ # Explicit base call: TVMScript's parser has no zero-arg super().
+ ClusterPersistentScheduler2D.init(self, cluster_id)
+ self._done = 0
+
+ def valid(self):
+ return self._done == 0
+
+ @T.inline
+ def consume(self):
+ # Single-elected-thread scope: wait for the handle, decode, release
the slot.
+ self._clc.sched_arr.full.wait(0, self._sa.phase)
+ self._sa.advance()
+ self._nxt =
T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0]))
+ self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True)
+
+ @T.inline
+ def consume_wg(self, wg_id, warp_id, lane_id):
+ # Warpgroup scope: all threads decode; one elected lane releases the
slot.
+ self._clc.sched_arr.full.wait(0, self._sa.phase)
+ self._sa.advance()
+ self._nxt =
T.ptx.clc_query_cancel(T.address_of(self._clc.clc_handle[0]))
+ T.cuda.warpgroup_sync(wg_id + 1)
+ if (warp_id == 0) & (lane_id == 0):
+ self._clc.sched_fin.empty.arrive(0, cta_id=0, pred=True)
+
+ @T.inline
+ def advance_coords(self):
+ if self._nxt != 0xFFFFFFFF:
+ self.update_current_m_n_idx(self._nxt // self._clc._cta_group)
+
+ @T.inline
+ def mark_done_if_drained(self):
+ if self._nxt == 0xFFFFFFFF:
+ self._done = 1
+
+
+@T.meta_class
+class ClusterLaunchControlScheduler:
+ """Blackwell Cluster Launch Control (CLC) tile scheduler.
+
+ A scheduler warp runs ``run_scheduler`` (issues ``try_cancel`` to steal the next
+ cluster); worker roles each take a ``worker()`` handle and pull the stolen tile
+ through the shared smem handshake. Owns the CLC smem: the 16B response handle, the
+ arrival barrier (handle ready), and the finished barrier (slot consumed;
+ ``finish_arrivals`` arrivals per round). Tile-coord mapping is delegated to
+ ``ClusterPersistentScheduler2D`` (group-major L2 ordering).
+ """
+
+ def __init__(self, pool, num_m_tiles, num_n_tiles, l2_group_size,
cta_group, finish_arrivals):
+ self._num_m_tiles = num_m_tiles
+ self._num_n_tiles = num_n_tiles
+ self._l2_group_size = l2_group_size
+ self._cta_group = cta_group
+ self.sched_arr = Pipeline(pool, 1, full="tma", empty="mbar",
init_empty=1)
+ self.sched_fin = Pipeline(pool, 1, full="mbar", empty="mbar",
init_empty=finish_arrivals)
+ self.clc_handle = pool.alloc((4,), "uint32", align=16)
+ self._s_done = T.local_scalar("int32")
+ self._s_nxt = T.local_scalar("uint32")
+
+ def worker(self, prefix):
+ return _CLCWorker(self, prefix)
+
+ @T.inline
+ def run_scheduler(self, cbx):
+ # cta0 drives try_cancel; both CTAs expect_bytes + consume the handle so the
+ # finished-barrier count is met and the slot can be reissued.
+ if T.ptx.elect_sync():
+ sa = PipelineState(1, 0)
+ sf = PipelineState(1, 1)
+ self._s_done = 0
+ while self._s_done == 0:
+ if cbx == 0:
+ self.sched_fin.empty.wait(0, sf.phase)
+ sf.advance()
+ T.ptx.clc_try_cancel(
+ T.address_of(self.clc_handle[0]),
T.address_of(self.sched_arr.full.buf[0])
+ )
+ self.sched_arr.full.arrive(0, 16) # expect_bytes for the 16B
handle
+ self.sched_arr.full.wait(0, sa.phase)
+ sa.advance()
+ self._s_nxt =
T.ptx.clc_query_cancel(T.address_of(self.clc_handle[0]))
+ self.sched_fin.empty.arrive(0, cta_id=0, pred=True)
+ if self._s_nxt == 0xFFFFFFFF:
+ self._s_done = 1
diff --git a/python/tvm/backend/cuda/op.py b/python/tvm/backend/cuda/op.py
index e76d5fbe24..9570e26662 100644
--- a/python/tvm/backend/cuda/op.py
+++ b/python/tvm/backend/cuda/op.py
@@ -653,12 +653,12 @@ def ptx_mbarrier_init(bar, thread_count):
return call_intrin("", "tirx.ptx_mbarrier_init", bar, thread_count)
-def ptx_mbarrier_arrive(bar, cta_id=None, pred=None):
+def ptx_mbarrier_arrive(bar, cta_id=None, pred=None, count=None):
"""TVM intrinsic to call
mbarrier.arrive.shared::cta.b64
or
@p mapa.shared::cluster.u32
- @p mbarrier.arrive.shared::cluster.b64
+ @p mbarrier.arrive.shared::cluster.b64 [, count]
Parameters
----------
@@ -670,11 +670,29 @@ def ptx_mbarrier_arrive(bar, cta_id=None, pred=None):
pred : Optional[PrimExpr]
The predicate to guard the operation.
+
+ count : Optional[PrimExpr]
+ Explicit arrival count operand for the cross-CTA (cluster) form. When
+ ``None`` the implicit count-of-1 form is emitted; when given, emits
+ ``mbarrier.arrive.shared::cluster.b64 _, [addr], count``.
"""
if cta_id is None and pred is None:
return call_intrin("", "tirx.ptx_mbarrier_arrive", bar)
assert cta_id is not None and pred is not None
- return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred)
+ if count is None:
+ return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred)
+ return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, pred,
count)
+
+
+def ptx_mbarrier_arrive_cluster_count(bar, cta_id, count):
+ """Cross-CTA ``mbarrier.arrive`` on CTA ``cta_id`` with an explicit count.
+
+ Convenience for an already-elected thread: emits
+ ``@p mapa.shared::cluster.u32`` + ``@p mbarrier.arrive.shared::cluster.b64
_,
+ [addr], count`` with the guard defaulted to 1.
+ """
+ return call_intrin("", "tirx.ptx_mbarrier_arrive", bar, cta_id, True,
count)
+
def ptx_mbarrier_arrive_expect_tx(bar, byte_count, cta_id=None, pred=None):
@@ -706,7 +724,11 @@ def ptx_mbarrier_arrive_expect_tx(bar, byte_count,
cta_id=None, pred=None):
"""
if cta_id is None and pred is None:
return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar,
byte_count)
- assert cta_id is not None and pred is not None
+ assert cta_id is not None
+ # Cross-CTA expect_tx from an already-elected thread: default the guard to
1
+ # (the caller has elected a single lane), so callers can pass cta_id alone.
+ if pred is None:
+ pred = True
return call_intrin("", "tirx.ptx_mbarrier_arrive_expect_tx", bar,
byte_count, cta_id, pred)
@@ -729,6 +751,23 @@ def ptx_mbarrier_try_wait(bar, phase):
return call_intrin("", "tirx.ptx_mbarrier_try_wait", bar, phase)
+def ptx_mbarrier_try_wait_acquire_cluster(bar, phase):
+ """``mbarrier.try_wait.parity.acquire.cluster`` retry loop.
+
+ Cluster-scope acquire wait — used to wait on a barrier that a remote CTA in
+ the cluster arrives on (a group cluster wait).
+
+ Parameters
+ ----------
+ bar : Var
+ The pointer to barrier variable.
+
+ phase : int
+ The phase of the barrier.
+ """
+ return call_intrin("", "tirx.ptx_mbarrier_try_wait_acquire_cluster", bar,
phase)
+
+
def ptx_mbarrier_try_wait_once(bar, phase, ticks):
"""TVM intrinsic for one-shot non-blocking ``mbarrier.try_wait.parity``.
@@ -1261,6 +1300,38 @@ def ptx_barrier_cluster_wait(acquire=False,
aligned=True):
return call_intrin("", "tirx.ptx_barrier_cluster_wait", acquire, aligned)
+def ptx_clc_try_cancel(handle, mbar):
+ """TVM intrinsic to call clusterlaunchcontrol.try_cancel.
+
+ Async-requests cancelling the next cluster's launch (work-stealing):
writes the
+ 16B response handle to smem and signals ``mbar`` (complete_tx, multicast
to both
+ cluster CTAs).
+
+ Parameters
+ ----------
+ handle : PrimExpr
+ Pointer to the 16B (uint4) smem response handle.
+
+ mbar : PrimExpr
+ Pointer to the mbarrier signalled when the handle lands.
+ """
+ return call_intrin("", "tirx.ptx_clc_try_cancel", handle, mbar)
+
+
+def ptx_clc_query_cancel(handle):
+ """TVM intrinsic to call clusterlaunchcontrol.query_cancel.
+
+ Decodes the response handle written by :func:`ptx_clc_try_cancel`. Returns
the
+ cancelled cluster's first ``ctaid.x``, or ``0xFFFFFFFF`` when no work was
stolen.
+
+ Parameters
+ ----------
+ handle : PrimExpr
+ Pointer to the 16B (uint4) smem response handle.
+ """
+ return call_intrin("uint32", "tirx.ptx_clc_query_cancel", handle)
+
+
def ptx_elect_sync():
"""TVM intrinsic to call elect.sync"""
return call_intrin("uint32", "tirx.ptx_elect_sync")
diff --git a/python/tvm/backend/cuda/operator/intrinsics/sync.py
b/python/tvm/backend/cuda/operator/intrinsics/sync.py
index 0fcdb31a46..791d9cc981 100644
--- a/python/tvm/backend/cuda/operator/intrinsics/sync.py
+++ b/python/tvm/backend/cuda/operator/intrinsics/sync.py
@@ -168,6 +168,54 @@ device_intrinsic(
)
+# =============================================================================
+# clusterlaunchcontrol.try_cancel / query_cancel — Blackwell Cluster Launch
+# Control (CLC) work-stealing, written from the PTX ISA spec (section
+# "clusterlaunchcontrol", PTX ISA 8.6). try_cancel async-requests cancelling the
+# next cluster's launch, writing a 16B response to smem + signalling mbar. query
+# decodes the response: on success it extracts the cancelled cluster's first
+# ctaid.x (via the get_first_ctaid::x form); a single uint32 is returned, with
+# 0xFFFFFFFF as the "no work stolen" sentinel (a device helper returns one scalar).
+# =============================================================================
+device_intrinsic(
+ "ptx_clc_try_cancel",
+ c_signature="(void* handle, void* mbar)",
+ body=(
+ " unsigned int addr = (unsigned
int)__cvta_generic_to_shared(handle);\n"
+ " unsigned int bar = (unsigned
int)__cvta_generic_to_shared(mbar);\n"
+ " asm volatile(\n"
+ '
"clusterlaunchcontrol.try_cancel.async.shared::cta.mbarrier::complete_tx::bytes"\n'
+ ' ".multicast::cluster::all.b128 [%0], [%1];\\n"\n'
+ ' :: "r"(addr), "r"(bar) : "memory");'
+ ),
+)
+
+
+device_intrinsic(
+ "ptx_clc_query_cancel",
+ c_signature="(void* handle)",
+ return_type="uint32_t",
+ tvm_return_type="uint32",
+ body=(
+ " unsigned int addr = (unsigned
int)__cvta_generic_to_shared(handle);\n"
+ " unsigned int first_ctaid_x;\n"
+ " asm volatile(\n"
+ ' "{\\n"\n'
+ ' ".reg .pred canceled;\\n"\n'
+ ' ".reg .b128 response;\\n"\n'
+ ' "ld.shared.b128 response, [%1];\\n"\n'
+ ' "clusterlaunchcontrol.query_cancel.is_canceled.pred.b128
canceled, response;\\n"\n'
+ ' "mov.u32 %0, 0xffffffff;\\n"\n'
+ ' "@canceled
clusterlaunchcontrol.query_cancel.get_first_ctaid::x.b32.b128"\n'
+ ' " %0, response;\\n"\n'
+ ' "}\\n"\n'
+ ' : "=r"(first_ctaid_x) : "r"(addr) : "memory");\n'
+ ' asm volatile("fence.proxy.async.shared::cta;\\n" ::: "memory");\n'
+ " return first_ctaid_x;"
+ ),
+)
+
+
# =============================================================================
# mbarrier.init.shared.b64 [addr], count ; — 1 form.
# =============================================================================
@@ -208,7 +256,7 @@ device_intrinsic(
' "{\\n"\n'
' ".reg .pred p;\\n"\n'
' ".reg .b32 remAddr32;\\n"\n'
- ' "setp.eq.u32 p, %2, 1;\\n"\n'
+ ' "setp.ne.s32 p, %2, 0;\\n"\n'
' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n'
' "@p mbarrier.arrive.shared::cluster.b64 _,
[remAddr32];\\n"\n'
' "}\\n"\n'
@@ -217,15 +265,38 @@ device_intrinsic(
)
+# Same cross-CTA arrive, but with an explicit arrival-count operand
+# (``..., [remAddr32], count``). Matches the ``tma::cluster::arrive`` spelling.
+device_intrinsic(
+ "_ptx_mbarrier_arrive_remote_count",
+ helper_name="tvm_builtin_ptx_mbarrier_arrive_remote_count",
+ c_signature="(void* barrier, int cta_id, int pred, int count)",
+ body=(
+ " unsigned int barrier_addr = __cvta_generic_to_shared(barrier);\n"
+ " asm volatile(\n"
+ ' "{\\n"\n'
+ ' ".reg .pred p;\\n"\n'
+ ' ".reg .b32 remAddr32;\\n"\n'
+ ' "setp.ne.s32 p, %2, 0;\\n"\n'
+ ' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n'
+ ' "@p mbarrier.arrive.shared::cluster.b64 _, [remAddr32],
%3;\\n"\n'
+ ' "}\\n"\n'
+ ' :: "r"(barrier_addr), "r"(cta_id), "r"(pred), "r"(count) :
"memory");'
+ ),
+)
+
+
@register_codegen("ptx_mbarrier_arrive")
def _codegen_mbarrier_arrive(*args):
- """Dispatch by arg count: 1 -> local, 3 -> remote (cluster-mapped)."""
+ """Dispatch by arg count: 1 -> local, 3 -> remote, 4 -> remote+count."""
if len(args) == 1:
result =
CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_local"](list(args))
elif len(args) == 3:
result =
CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_remote"](list(args))
+ elif len(args) == 4:
+ result =
CODEGEN_REGISTRY["tirx._ptx_mbarrier_arrive_remote_count"](list(args))
else:
- raise ValueError(f"ptx_mbarrier_arrive expects 1 or 3 args, got
{len(args)}")
+ raise ValueError(f"ptx_mbarrier_arrive expects 1, 3, or 4 args, got
{len(args)}")
return result[0] if isinstance(result, tuple) else result
@@ -252,7 +323,7 @@ device_intrinsic(
' "{\\n"\n'
' ".reg .pred p;\\n"\n'
' ".reg .b32 remAddr32;\\n"\n'
- ' "setp.eq.u32 p, %2, 1;\\n"\n'
+ ' "setp.ne.s32 p, %2, 0;\\n"\n'
' "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\\n"\n'
' "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _,
[remAddr32], %3;\\n"\n'
' "}\\n"\n'
@@ -303,6 +374,27 @@ device_intrinsic(
)
+# mbarrier.try_wait.parity.acquire.cluster — cluster-scope acquire wait used
for
+# cross-CTA barrier handshakes (e.g. the tmem-finished handoff).
+device_intrinsic(
+ "ptx_mbarrier_try_wait_acquire_cluster",
+ c_signature="(void* barrier, int phase)",
+ body=(
+ " unsigned int barrier_addr_int =
__cvta_generic_to_shared(barrier);\n"
+ " asm volatile(\n"
+ ' "{\\n"\n'
+ ' ".reg .pred P1;\\n"\n'
+ ' "LAB_WAIT_AC:\\n"\n'
+ ' "mbarrier.try_wait.parity.acquire.cluster.shared::cta.b64 P1,
[%0], %1;\\n"\n'
+ ' "@P1 bra.uni DONE_AC;\\n"\n'
+ ' "bra.uni LAB_WAIT_AC;\\n"\n'
+ ' "DONE_AC:\\n"\n'
+ ' "}\\n"\n'
+ ' :: "r"(barrier_addr_int), "r"(phase) : "memory");'
+ ),
+)
+
+
# =============================================================================
# mbarrier.try_wait.parity — ONE-SHOT non-blocking variant. Returns true
# if the requested parity has already been reached, false otherwise.
diff --git
a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py
b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py
index ffd5e18a3a..081ea5a772 100644
--- a/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py
+++ b/python/tvm/backend/cuda/operator/tile_primitive/copy_async/tcgen05_ldst.py
@@ -369,20 +369,24 @@ def _emit_16xnb_path(
tmem_st, tmem_extent = get_st_extent(tmem_region)
local_st, local_extent = get_st_extent(local_region)
- # Local slice must be the full (frag_rows, K_cols) view.
+ # Rows must span the full frag. The COLUMN extent may be a sub-multiple of
+ # the atom's full width ``width_elems`` — i.e. a per-chunk column slice of
a
+ # wider frag (e.g. an epilogue that loads one big (128, MMA_N) frag in
+ # EPI_TILE-wide chunks). The atom layout maps consecutive columns to
+ # consecutive registers within each slab, so a column slice occupies a
+ # contiguous register window; we emit ``num_eff`` (the slice's atom rep) at
+ # the slab base + the column's register offset. When the slice IS the full
+ # atom (the common case), num_eff == num and reg offset == 0 (no change).
assert analyzer.can_prove_equal(local_st[0], 0)
assert analyzer.can_prove_equal(local_extent[0], frag_rows)
- assert analyzer.can_prove_equal(local_extent[1], width_elems)
-
- # TMEM slice must start at row 0 and span ``frag_rows`` rows. For Layout
- # F the buffer is already (64, W) so frag_rows=64 covers the full slice;
- # for Layout D + frag_rows=64 the slice reads the *first* half-slab and
- # the rest of the buffer's 128 rows is invisible to this atom. For
- # Layout D + frag_rows=128 the slice covers all 128 physical lanes via
- # two PTX issues (row=0 + row=16).
assert analyzer.can_prove_equal(tmem_st[0], 0)
assert analyzer.can_prove_equal(tmem_extent[0], frag_rows)
- assert analyzer.can_prove_equal(tmem_extent[1], width_elems)
+ # local and tmem column slices must match and divide the atom's full width.
+ assert analyzer.can_prove_equal(local_extent[1], tmem_extent[1])
+ slice_w = int(local_extent[1])
+ assert width_elems % slice_w == 0, f"slice width {slice_w} must divide
atom width {width_elems}"
+ num_eff = num * slice_w // width_elems
+ regs_eff = regs_per_thread_per_slab * slice_w // width_elems
del tmem_rows # only used for the structural check above
col_off = tmem_st[1]
@@ -410,13 +414,18 @@ def _emit_16xnb_path(
# for the register-pointer arguments of the PTX builtin.
local_storage = local_buf.view(per_thread_elems,
layout=TileLayout(S[per_thread_elems]))
local_32b = local_storage.view("uint32")
- local_reg_base = local_col_off_elems // elem_per_32b
+ # Register offset of the column slice within each slab. The old
+ # ``local_col_off // elem_per_32b`` is only correct when the slice IS
the
+ # full atom; in general consecutive columns advance registers at the
rate
+ # (regs_per_thread_per_slab / width_elems). For a full-atom load the
+ # offset is 0 either way, so existing callers are unaffected.
+ local_reg_base = local_col_off_elems * regs_per_thread_per_slab //
width_elems
for slab in range(n_slabs):
reg_base = slab * regs_per_thread_per_slab
op(
tmem_buf.allocated_addr[0],
- *[local_32b[local_reg_base + reg_base + i] for i in
range(regs_per_thread_per_slab)], # noqa: E501
- shape=shape, num=num, row=slab * 16, col=col_off_32b,
+ *[local_32b[local_reg_base + reg_base + i] for i in
range(regs_eff)],
+ shape=shape, num=num_eff, row=slab * 16, col=col_off_32b,
)
# fmt: on
return impl
diff --git a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py
b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py
index eddf9f3d8e..64d77a21cf 100644
--- a/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py
+++ b/python/tvm/backend/cuda/operator/tile_primitive/elementwise/reg.py
@@ -45,8 +45,10 @@ from ..common import get_st_extent
from ..copy._common import _carve_tail, _verify_s_tail_contig
from ..layout_utils import get_sublayout_from_region, layout_signature
from ._common import (
+ _TID_AXIS_FOR_SCOPE,
_all_threads_active,
_tensor_shape_of,
+ _thread_cnt,
align_operands_to_anchor,
buffer_regions,
compute_dtype_of,
@@ -67,6 +69,68 @@ def _validate_anchor_layout(anchor_br) -> tuple[bool, str |
None]:
return True, None
+def _validate_scope_level_anchor(anchor_br, sctx: DispatchContext) ->
tuple[bool, str | None]:
+ """For warp/warpgroup/cta scope, require dst to be scope-level: after
+ canonicalizing with the target its thread axes are the scope's intra-thread
+ axis (laneid/tid_in_wg/tx) and, sorted by stride, tile a complete ``T:1``
+ chain over all ``T`` threads of the scope. Rejects thread-local
``.local()``
+ views; thread scope is exempt.
+ """
+ scope = sctx.scope_kind
+ if scope == "thread":
+ return True, None
+ expected_axis = _TID_AXIS_FOR_SCOPE.get(scope)
+ if expected_axis is None:
+ return True, None
+ expected_cnt = _thread_cnt(sctx)
+
+ # Canonicalize the sliced anchor with the target so warp/lane axes fuse.
+ st, ext = get_st_extent(anchor_br)
+ sliced = get_sublayout_from_region(anchor_br.buffer.layout,
anchor_br.buffer.shape, st, ext)
+ with sctx.target:
+ canon = sliced.canonicalize() if hasattr(sliced, "canonicalize") else
sliced
+ shard = getattr(canon, "shard", None)
+ if shard is None:
+ return False, f"{scope}-scope op operand layout is not a TileLayout
after slicing"
+
+ thread_iters = [it for it in shard if it.axis.is_thread()]
+ if not thread_iters:
+ return (
+ False,
+ f"{scope}-scope op needs a {scope}-level operand whose layout
carries "
+ f"thread axes ({expected_axis} composing to {expected_cnt}:1); got
a "
+ f"thread-local view with no thread axes — pass the {scope}-level
tensor, "
+ f"not its `.local()` (per-thread) view",
+ )
+ bad = sorted({it.axis.name for it in thread_iters if it.axis.name !=
expected_axis})
+ if bad:
+ return (
+ False,
+ f"{scope}-scope op operand carries thread axes {bad}; after "
+ f"canonicalization a {scope}-level layout must use only
{expected_axis!r}",
+ )
+ # Sorted by stride the thread iters must tile a complete chain 1, e0,
+ # e0*e1, ... up to the scope thread count — i.e. cover all T threads with
+ # no gap or overlap (extents alone would miss gaps/overlaps).
+ running = 1
+ for it in sorted(thread_iters, key=lambda i: int(i.stride)):
+ stride, extent = int(it.stride), int(it.extent)
+ if stride != running:
+ return (
+ False,
+ f"{scope}-scope op operand thread axes do not tile a complete "
+ f"{expected_cnt}:1 (sorted by stride: expected {running}, got
{stride})",
+ )
+ running *= extent
+ if running != expected_cnt:
+ return (
+ False,
+ f"{scope}-scope op operand thread axes span {running} threads, not
the "
+ f"full {expected_cnt} of the {scope}",
+ )
+ return True, None
+
+
def _check_layout_operands_agree(plan) -> tuple[bool, str | None]:
"""Replica sigs must match across non-trivial-layout operands.
@@ -133,6 +197,9 @@ def is_reg_ewise(spec):
ok3, reason3 = _validate_anchor_layout(anchor)
if not ok3:
return False, reason3
+ ok_scope, reason_scope = _validate_scope_level_anchor(anchor, sctx)
+ if not ok_scope:
+ return False, reason_scope
# Shape compat (NumPy-style broadcast): anchor's tensor shape is the
# result shape; every operand must broadcast TO anchor.
anchor_tshape = _tensor_shape_of(anchor.region)
diff --git a/python/tvm/backend/cuda/script.py
b/python/tvm/backend/cuda/script.py
index a1148f9b67..a46aa7e7e4 100644
--- a/python/tvm/backend/cuda/script.py
+++ b/python/tvm/backend/cuda/script.py
@@ -53,6 +53,8 @@ class PTXNamespace:
self.stmatrix = _op_wrapper(_cuda_op.ptx_stmatrix)
self.setmaxnreg: Callable[..., Any] =
_op_wrapper(_cuda_op.ptx_setmaxnreg)
self.elect_sync: Callable[..., Any] =
_op_wrapper(_cuda_op.ptx_elect_sync)
+ self.clc_try_cancel = _op_wrapper(_cuda_op.ptx_clc_try_cancel)
+ self.clc_query_cancel = _op_wrapper(_cuda_op.ptx_clc_query_cancel)
self.fetch_register: Callable[..., Any] =
_op_wrapper(_cuda_op.ptx_fetch_register)
self.ld = _op_wrapper(_cuda_op.ptx_ld)
self.ld_acquire = _op_wrapper(_cuda_op.ptx_ld_acquire)
@@ -276,6 +278,9 @@ class MbarrierNamespace:
self.init = _op_wrapper(_cuda_op.ptx_mbarrier_init)
self.try_wait = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait)
self.try_wait_once = _op_wrapper(_cuda_op.ptx_mbarrier_try_wait_once)
+ self.try_wait_acquire_cluster = _op_wrapper(
+ _cuda_op.ptx_mbarrier_try_wait_acquire_cluster
+ )
self.arrive = MbarrierArriveNamespace()
@@ -284,6 +289,7 @@ class MbarrierArriveNamespace:
def __init__(self):
self.expect_tx = _op_wrapper(_cuda_op.ptx_mbarrier_arrive_expect_tx)
+ self.cluster_count =
_op_wrapper(_cuda_op.ptx_mbarrier_arrive_cluster_count)
def __call__(self, *args, **kwds):
return _op_wrapper(_cuda_op.ptx_mbarrier_arrive)(*args, **kwds)
diff --git a/python/tvm/support/nvcc.py b/python/tvm/support/nvcc.py
index ea5939fcef..b421042fb3 100644
--- a/python/tvm/support/nvcc.py
+++ b/python/tvm/support/nvcc.py
@@ -32,7 +32,7 @@ from . import utils
def compile_cuda(
- code, target_format=None, arch=None, options=None, path_target=None,
compiler="nvcc"
+ code, target_format=None, arch=None, options=None, path_target=None,
compiler="nvrtc"
):
"""Compile CUDA code with NVCC or NVRTC.
@@ -54,7 +54,7 @@ def compile_cuda(
Output file.
compiler : str, optional
- Compiler backend: "nvcc" or "nvrtc".
+ Compiler backend: "nvrtc" (default) or "nvcc".
This can be set by the TVM_CUDA_COMPILE_MODE environment variable.
Returns
@@ -191,7 +191,7 @@ def _compile_cuda_nvcc(
"--expt-extended-lambda",
"--use_fast_math",
"--ptxas-options=-v", # printing out number of registers
-
"--ptxas-options=--verbose,--register-usage-level=10,--warn-on-local-memory-usage",
# printing out number of registers # noqa: E501
+
f"--ptxas-options=--verbose,--register-usage-level={os.environ.get('TVM_CUDA_PTXAS_REG_LEVEL',
'10')},--warn-on-local-memory-usage", # noqa: E501
]
major, _ =
parse_compute_version(get_target_compute_version(Target.current(allow_none=True)))
@@ -342,14 +342,23 @@ def _compile_cuda_nvrtc(
line for line in code.splitlines() if line.strip() not in
headers_to_strip
)
- # NVRTC compiles device code and does not include the host-side cuda.h.
- # CUtensorMap is a host-side structure, to reference and use it in device
code,
- # we must forward-declare it for NVRTC.
+ # NVRTC compiles device code and does not include the host-side cuda.h
+ # (it is guarded behind ``#ifndef __CUDACC_RTC__`` in generated code and is
+ # stripped above), so the complete ``CUtensorMap_st`` layout that cuda.h
+ # normally provides is missing. TMA kernels take ``CUtensorMap`` by value
as
+ # ``__grid_constant__`` params, which requires the complete type. Define
the
+ # ``CUtensorMap_st`` tag with cuda.h's layout (64-byte aligned, 128 bytes)
+ # plus the typedef alias. This is compatible with cccl's
``<cuda/barrier>``,
+ # which only forward-declares ``struct CUtensorMap_st;`` and re-typedefs
the
+ # alias (a redundant typedef to the same type is legal in C++); defining
the
+ # tag rather than ``struct CUtensorMap`` avoids the previous redefinition
+ # clash with that header.
if "CUtensorMap" in code_filtered:
code_filtered = (
- "struct __align__(128) CUtensorMap {\n"
+ "struct alignas(64) CUtensorMap_st {\n"
" unsigned long long opaque[16];\n"
- "};\n\n" + code_filtered
+ "};\n"
+ "typedef struct CUtensorMap_st CUtensorMap;\n\n" + code_filtered
)
# Add standard type definitions and compatibility macros that NVRTC
doesn't provide.
@@ -371,6 +380,13 @@ using cuda::std::int64_t;
#define __volatile__ volatile
#endif
+// NVRTC does not pull in the host <math.h>, so INFINITY is undefined. Provide
it
+// from libcu++ (same float +inf value nvcc's <math.h> yields).
+#include <cuda/std/limits>
+#ifndef INFINITY
+#define INFINITY (::cuda::std::numeric_limits<float>::infinity())
+#endif
+
"""
code_filtered = nvrtc_preamble + code_filtered
@@ -406,6 +422,9 @@ namespace std {
compile_opts = [
f"--gpu-architecture={arch}".encode(),
b"-default-device",
+ # nvcc enables 128-bit integers by default on Linux; NVRTC requires the
+ # flag to be passed explicitly for kernels that use __int128_t.
+ b"--device-int128",
]
if use_nvshmem:
@@ -469,6 +488,21 @@ namespace std {
]
)
+ # Define the vector-deprecation silencing macros as no-ops for every NVRTC
+ # compile. These live in vector_types.h, which the fp4/fp6/fp8 headers use
+ # but do not include; depending on the include chain NVRTC pulls in, the
+ # macro can be left undefined and trigger a bogus "declaration has no storage
+ # class" error. Defining them empty is harmless (they only gate host-side
+ # deprecation warnings) and matches what the NVSHMEM path already did.
+ compile_opts.extend(
+ [
+ b"-D__NV_SILENCE_DEPRECATION_BEGIN=",
+ b"-D__NV_SILENCE_DEPRECATION_END=",
+ b"-D__NV_SILENCE_HOST_DEPRECATION_BEGIN=",
+ b"-D__NV_SILENCE_HOST_DEPRECATION_END=",
+ ]
+ )
+
compile_opts.extend(
[
b"-U__CUDA_NO_HALF_OPERATORS__",
@@ -481,6 +515,24 @@ namespace std {
]
)
+ # Mirror the nvcc path's ptxas options. register-usage-level drives ptxas
+ # register allocation / instruction scheduling and is perf-relevant (FA4 was
+ # tuned around it, hence the env-driven default); -v and
+ # --warn-on-local-memory-usage are diagnostic. NVRTC rejects -O3 and
+ # --register-usage-level as top-level flags but forwards them to its internal
+ # ptxas via --ptxas-options (ptxas already defaults to -O3). NB: unlike nvcc,
+ # NVRTC does not comma-split --ptxas-options, so each ptxas flag must be its
+ # own entry. The nvcc-only --expt-relaxed-constexpr / --expt-extended-lambda
+ # have no NVRTC equivalent and are intentionally not mirrored.
+ reg_level = os.environ.get("TVM_CUDA_PTXAS_REG_LEVEL", "10")
+ compile_opts.extend(
+ [
+ b"--ptxas-options=-v",
+ f"--ptxas-options=--register-usage-level={reg_level}".encode(),
+ b"--ptxas-options=--warn-on-local-memory-usage",
+ ]
+ )
+
# Add user-provided options, filtering out nvcc-specific flags that nvrtc doesn't support
if options:
nvcc_only_prefixes = (
@@ -802,7 +854,7 @@ def tvm_callback_cuda_compile(code):
Compile CUDA code using the configured backend (nvcc or nvrtc).
This callback is invoked by TVM's C++ backend during CUDA module compilation.
- By default, uses nvcc to generate fatbin. The current target is fetched
+ By default, uses nvrtc to generate cubin. The current target is fetched
inside the callback (via ``tvm.target.Target.current(allow_none=True)``)
so the caller does not need to push/pop a target scope around the
invocation.
@@ -810,9 +862,9 @@ def tvm_callback_cuda_compile(code):
Environment Variables
---------------------
TVM_CUDA_COMPILE_MODE : str
- Compiler backend: "nvcc" (default) or "nvrtc"
- - "nvcc": Use nvcc subprocess, generates fatbin
+ Compiler backend: "nvrtc" (default) or "nvcc"
- "nvrtc": Use NVRTC via cuda-bindings for faster JIT, generates cubin
+ - "nvcc": Use nvcc subprocess, generates fatbin
TVM_KERNEL_DUMP : str
If set, dump generated CUDA/intermediate files and append "-lineinfo" so profilers can
correlate SASS back to the dumped source.
@@ -830,7 +882,7 @@ def tvm_callback_cuda_compile(code):
# The current Target is fetched inside compile_cuda via
# tvm.target.Target.current(allow_none=True) when arch is unset; the
# caller no longer needs to push/pop a target scope.
- compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc").lower()
+ compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvrtc").lower()
if compiler == "nvrtc":
return compile_cuda(code, target_format="cubin", compiler="nvrtc")
diff --git a/python/tvm/tirx/script/builder/external_kernel.py
b/python/tvm/tirx/script/builder/external_kernel.py
index c1f5d58716..d56ed9ea03 100644
--- a/python/tvm/tirx/script/builder/external_kernel.py
+++ b/python/tvm/tirx/script/builder/external_kernel.py
@@ -159,7 +159,7 @@ class SourceKernel(BaseKernel): # pylint:
disable=too-few-public-methods
target_format = "cubin" if use_nvshmem else "ptx"
output_path = f"{temp_dir}/{kernel_name}.{target_format}"
- compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvcc")
+ compiler = os.environ.get("TVM_CUDA_COMPILE_MODE", "nvrtc")
nvcc.compile_cuda(
source_code,
target_format=target_format,
diff --git a/src/backend/cuda/op/target_builtin.cc
b/src/backend/cuda/op/target_builtin.cc
index 005fe5b322..353c04b501 100644
--- a/src/backend/cuda/op/target_builtin.cc
+++ b/src/backend/cuda/op/target_builtin.cc
@@ -152,6 +152,9 @@ TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_arrive_expect_tx)
TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait)
.set_attr<TCallEffectKind>("TCallEffectKind",
static_cast<int64_t>(CallEffectKind::kOpaque));
+TIRX_DEFINE_BUILTIN_FUNC(ptx_mbarrier_try_wait_acquire_cluster)
+ .set_attr<TCallEffectKind>("TCallEffectKind",
static_cast<int64_t>(CallEffectKind::kOpaque));
+
TIRX_DEFINE_BUILTIN_FUNC(ptx_bar_arrive)
.set_attr<TCallEffectKind>("TCallEffectKind",
static_cast<int64_t>(CallEffectKind::kOpaque));
@@ -497,6 +500,8 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = {
TIRX_DEVICE_INTRIN_ALIAS(ptx_bar_sync, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_arrive, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_barrier_cluster_wait, ptx, kOpaque),
+ TIRX_DEVICE_INTRIN_ALIAS(ptx_clc_query_cancel, ptx, kOpaque),
+ TIRX_DEVICE_INTRIN_ALIAS(ptx_clc_try_cancel, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_cp_async_bulk_commit_group, ptx, kOpaque),
@@ -540,6 +545,7 @@ const DeviceIntrinsicRegistration kDeviceIntrinsics[] = {
TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_init, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_test_wait_parity, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait, ptx, kOpaque),
+ TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_acquire_cluster, ptx,
kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_mbarrier_try_wait_once, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_mma, ptx, kOpaque),
TIRX_DEVICE_INTRIN_ALIAS(ptx_mma_legacy, ptx, kOpaque),
diff --git a/src/target/llvm/codegen_llvm.cc b/src/target/llvm/codegen_llvm.cc
index 88a28ebccb..f32dcdde11 100644
--- a/src/target/llvm/codegen_llvm.cc
+++ b/src/target/llvm/codegen_llvm.cc
@@ -133,6 +133,8 @@ void CodeGenLLVM::Init(const std::string& module_name,
LLVMTarget* llvm_target,
builder_.reset(new IRBuilder(*ctx));
module_.reset(new llvm::Module(module_name, *ctx));
md_builder_.reset(new llvm::MDBuilder(*ctx));
+ functions_.clear();
+ function_symbol_owners_.clear();
// types
t_void_ = llvm::Type::getVoidTy(*ctx);
t_void_p_ = llvmGetPointerTo(llvm::Type::getInt8Ty(*ctx),
GetGlobalAddressSpace());
@@ -260,6 +262,21 @@ llvm::Function* CodeGenLLVM::DeclareFunctionInternal(const
GlobalVar& gvar, cons
llvm::FunctionType::get(GetLLVMType(func->ret_type), param_types, false);
auto [symbol_name, linkage_type] = GetLinkage(gvar, func);
+ if (auto it = function_symbol_owners_.find(symbol_name); it !=
function_symbol_owners_.end()) {
+ constexpr const char* kFFISymbolPrefix = "__tvm_ffi_";
+ std::string user_symbol = symbol_name;
+ if (user_symbol.rfind(kFFISymbolPrefix, 0) == 0) {
+ user_symbol =
user_symbol.substr(std::char_traits<char>::length(kFFISymbolPrefix));
+ }
+ TVM_FFI_THROW(InternalError) << "Duplicate PrimFunc global_symbol '" <<
user_symbol
+ << "' in LLVM codegen: IRModule keys '" <<
it->second
+ << "' and '" << gvar->name_hint
+ << "' both lower to the same exported symbol
'" << symbol_name
+ << "'. "
+ << "Each exposed PrimFunc in one IRModule
must have a unique "
+ "global_symbol.";
+ }
+ function_symbol_owners_[symbol_name] = gvar->name_hint;
auto function = module_->getFunction(MakeStringRef(symbol_name));
if (function == nullptr) {
diff --git a/src/target/llvm/codegen_llvm.h b/src/target/llvm/codegen_llvm.h
index 8526b3f642..08396d596d 100644
--- a/src/target/llvm/codegen_llvm.h
+++ b/src/target/llvm/codegen_llvm.h
@@ -547,6 +547,9 @@ class CodeGenLLVM : public ExprFunctor<llvm::Value*(const
PrimExpr&)>,
// that function.
std::unordered_map<const GlobalVarNode*, llvm::Function*> functions_;
+ // Map from the generated LLVM function symbol to the name hint of the GlobalVar that owns it.
+ std::unordered_map<std::string, std::string> function_symbol_owners_;
+
// Whether current function is restricted
bool is_restricted_{true};
// The analyzer information
diff --git a/src/tirx/ir/layout/tile_slice.cc b/src/tirx/ir/layout/tile_slice.cc
index 3f4db48379..ce1809ae99 100644
--- a/src/tirx/ir/layout/tile_slice.cc
+++ b/src/tirx/ir/layout/tile_slice.cc
@@ -144,7 +144,11 @@ ffi::Optional<TileLayout> SlicePerGroup(TileLayout layout,
PrimExpr begin, PrimE
ffi::Optional<Layout> TileLayoutNode::Slice(const Array<PrimExpr>& shape,
const Region& region) const {
arith::Analyzer analyzer;
- auto [grouped_layout, seps] = Group(ffi::GetRef<TileLayout>(this), shape);
+ // Canonicalize the whole layout first so scope fusion (e.g. wid_in_wg+laneid
+ // -> tid_in_wg) runs globally; otherwise grouping can split sibling thread
+ // axes and SlicePerGroup's per-group fusion leaves an ill-formed mix.
+ TileLayout canon = this->Canonicalize().as<TileLayout>().value();
+ auto [grouped_layout, seps] = Group(canon, shape);
std::vector<Iter> new_shard;
ffi::Map<Axis, PrimExpr> new_offset;
for (size_t i = 0; i < seps.size() - 1; ++i) {
diff --git a/tests/python/codegen/test_target_codegen_llvm.py
b/tests/python/codegen/test_target_codegen_llvm.py
index 7c093f9be2..624d587b82 100644
--- a/tests/python/codegen/test_target_codegen_llvm.py
+++ b/tests/python/codegen/test_target_codegen_llvm.py
@@ -30,6 +30,45 @@ from tvm.target.codegen import llvm_get_intrinsic_name,
llvm_lookup_intrinsic_id
from tvm.testing import env
[email protected](not env.has_llvm(), reason="need llvm")
+def test_duplicate_primfunc_global_symbol_diagnostic():
+ @I.ir_module(s_tir=True)
+ class Module:
+ @T.prim_func(s_tir=True)
+ def first_unique_key(A: T.Buffer((1,), "float32")):
+ T.func_attr({"global_symbol": "dup_symbol", "tirx.noalias": True})
+ A[0] = T.float32(1)
+
+ @T.prim_func(s_tir=True)
+ def second_unique_key(A: T.Buffer((1,), "float32")):
+ T.func_attr({"global_symbol": "dup_symbol", "tirx.noalias": True})
+ A[0] = T.float32(2)
+
+ with pytest.raises(
+ tvm.error.InternalError, match="Duplicate PrimFunc global_symbol
'dup_symbol'"
+ ) as err:
+ tvm.compile(Module, target="llvm")
+ assert "first_unique_key" in str(err.value)
+ assert "second_unique_key" in str(err.value)
+
+
[email protected](not env.has_llvm(), reason="need llvm")
+def test_unique_primfunc_global_symbols_compile():
+ @I.ir_module(s_tir=True)
+ class Module:
+ @T.prim_func(s_tir=True)
+ def first_unique_key(A: T.Buffer((1,), "float32")):
+ T.func_attr({"global_symbol": "dup_symbol_a", "tirx.noalias":
True})
+ A[0] = T.float32(1)
+
+ @T.prim_func(s_tir=True)
+ def second_unique_key(A: T.Buffer((1,), "float32")):
+ T.func_attr({"global_symbol": "dup_symbol_b", "tirx.noalias":
True})
+ A[0] = T.float32(2)
+
+ tvm.compile(Module, target="llvm")
+
+
@pytest.mark.skipif(not env.has_llvm(), reason="need llvm")
def test_llvm_intrin():
@I.ir_module(s_tir=True)
diff --git a/tests/python/tirx/codegen/test_codegen_cuda.py
b/tests/python/tirx/codegen/test_codegen_cuda.py
index f253d6d375..521a72f6d7 100644
--- a/tests/python/tirx/codegen/test_codegen_cuda.py
+++ b/tests/python/tirx/codegen/test_codegen_cuda.py
@@ -21,6 +21,7 @@ import pytest
import tvm
import tvm.testing
from tvm.script import tirx as T
+from tvm.testing import env
DEV = tvm.device("cuda")
@@ -118,6 +119,8 @@ def test_cuda_handle_uint64_reinterpret_codegen():
assert "*(void* *)" not in src
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_cuda_atomic_add():
@T.prim_func
def main(A: T.Buffer((1,), "int32"), B: T.Buffer((1,), "float32")):
@@ -442,6 +445,8 @@ def test_cuda_atomic_cas():
assert "tvm_builtin_cuda_atomic_cas" in src
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_cuda_func_call():
def test_add_one():
add_one = """
@@ -497,6 +502,8 @@ __device__ void print(int32_t a) {
test_print()
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_warp_shuffle_xor_sync():
# fmt: off
@T.prim_func
@@ -532,6 +539,8 @@ def test_warp_shuffle_xor_sync():
np.testing.assert_allclose(A.numpy(), A_ref)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("cp_size", [4, 8, 16])
@pytest.mark.parametrize("cache_hint", ["", "evict_last"])
@pytest.mark.parametrize("prefetch_size", [-1, 64, 128, 256])
@@ -575,6 +584,8 @@ def test_ptx_cp_async(cp_size, cache_hint, prefetch_size,
predicate, fill_mode):
print(src)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("trans", [False, True])
@pytest.mark.parametrize("num", [1, 2, 4])
def test_ptx_ldmatrix(trans, num):
diff --git a/tests/python/tirx/codegen/test_codegen_nvshmem.py
b/tests/python/tirx/codegen/test_codegen_nvshmem.py
index ff9f17170d..d386907742 100644
--- a/tests/python/tirx/codegen/test_codegen_nvshmem.py
+++ b/tests/python/tirx/codegen/test_codegen_nvshmem.py
@@ -28,6 +28,7 @@ from tvm.runtime import ShapeTuple
from tvm.runtime import disco as di
from tvm.script import tirx as T
from tvm.support.popen_pool import PopenWorker
+from tvm.testing import env
NUM_WORKERS = 4
@@ -61,6 +62,8 @@ def create_nvshmem_array(sess, shape, dtype,
init_data_fn=None, zero_out=True):
return arr
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.skip(reason="nvshmem doesn't work with pytest")
def test_codegen_nvshmem():
def _test_func():
diff --git a/tests/python/tirx/codegen/test_cuda_copy.py
b/tests/python/tirx/codegen/test_cuda_copy.py
index cb08f42473..047eb1f12c 100644
--- a/tests/python/tirx/codegen/test_cuda_copy.py
+++ b/tests/python/tirx/codegen/test_cuda_copy.py
@@ -21,6 +21,7 @@ import pytest
import tvm
from tvm.script import tirx as T
+from tvm.testing import env
DEV = tvm.cuda(0)
TARGET = tvm.target.Target("cuda")
@@ -34,6 +35,8 @@ def _build_and_run(func, *np_args):
return (*tuple(a.numpy() for a in rt_args), mod)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_copy_128b():
"""copy_128b: copies 16 bytes (4 float32 elements) via uint4 load/store."""
@@ -63,6 +66,8 @@ def test_copy_128b():
assert "tvm_builtin_copy_128b" in mod.mod.imports[0].inspect_source()
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_copy_64b():
"""copy_64b: copies 8 bytes (2 float32 elements) via uint2 load/store."""
@@ -92,6 +97,8 @@ def test_copy_64b():
assert "tvm_builtin_copy_64b" in mod.mod.imports[0].inspect_source()
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_copy_32b():
"""copy_32b: copies 4 bytes (1 float32 element) via unsigned int
load/store."""
@@ -121,6 +128,8 @@ def test_copy_32b():
assert "tvm_builtin_copy_32b" in mod.mod.imports[0].inspect_source()
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_copy_16b():
"""copy_16b: copies 2 bytes (1 float16 element) via unsigned short
load/store."""
@@ -150,6 +159,8 @@ def test_copy_16b():
assert "tvm_builtin_copy_16b" in mod.mod.imports[0].inspect_source()
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_copy_8b():
"""copy_8b: copies 1 byte (1 uint8 element) via unsigned char
load/store."""
diff --git a/tests/python/tirx/codegen/test_cuda_cta_reduce.py
b/tests/python/tirx/codegen/test_cuda_cta_reduce.py
index 51b8f1099a..bf07da1b67 100644
--- a/tests/python/tirx/codegen/test_cuda_cta_reduce.py
+++ b/tests/python/tirx/codegen/test_cuda_cta_reduce.py
@@ -21,6 +21,7 @@ import pytest
import tvm
from tvm.script import tirx as T
+from tvm.testing import env
DEV = tvm.cuda(0)
TARGET = tvm.target.Target("cuda")
@@ -35,6 +36,8 @@ def _build_and_run(func, n):
return out.numpy(), mod
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_cta_sum_4_warps():
"""CTA sum with 4 warps (128 threads): all threads get the same sum."""
NUM_WARPS = 4
@@ -61,6 +64,8 @@ def test_cta_sum_4_warps():
assert "cta_reduce_sum_4" in mod.mod.imports[0].inspect_source()
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_cta_sum_8_warps():
"""CTA sum with 8 warps (256 threads)."""
NUM_WARPS = 8
@@ -86,6 +91,8 @@ def test_cta_sum_8_warps():
np.testing.assert_allclose(result, np.full(N, expected))
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_cta_max_4_warps():
"""CTA max with 4 warps: all threads get the maximum value."""
NUM_WARPS = 4
@@ -110,6 +117,8 @@ def test_cta_max_4_warps():
np.testing.assert_allclose(result, np.full(N, float(N)))
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_cta_min_4_warps():
"""CTA min with 4 warps: all threads get the minimum value."""
NUM_WARPS = 4
@@ -134,6 +143,8 @@ def test_cta_min_4_warps():
np.testing.assert_allclose(result, np.full(N, 1.0))
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_cta_sum_1_warp():
"""CTA sum with 1 warp: degenerates to a pure warp reduce."""
NUM_WARPS = 1
@@ -159,6 +170,8 @@ def test_cta_sum_1_warp():
np.testing.assert_allclose(result, np.full(N, expected))
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("num_warps", [1, 2, 4, 8, 16])
def test_cta_sum_all_warp_counts(num_warps):
"""Parametric test: cta_sum with various warp counts."""
diff --git a/tests/python/tirx/codegen/test_cuda_warp_reduce.py
b/tests/python/tirx/codegen/test_cuda_warp_reduce.py
index df568a95e4..e5167a055c 100644
--- a/tests/python/tirx/codegen/test_cuda_warp_reduce.py
+++ b/tests/python/tirx/codegen/test_cuda_warp_reduce.py
@@ -21,6 +21,7 @@ import pytest
import tvm
from tvm.script import tirx as T
+from tvm.testing import env
DEV = tvm.cuda(0)
TARGET = tvm.target.Target("cuda")
@@ -35,6 +36,8 @@ def _build_and_run(func, n=32):
return out.numpy(), mod
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_warp_sum_full():
"""Full warp sum (width=32): each lane gets the sum of all 32 values."""
@@ -57,6 +60,8 @@ def test_warp_sum_full():
assert "warp_reduce_sum_32" in mod.mod.imports[0].inspect_source()
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_warp_sum_partial_8():
"""Partial warp sum (width=8): 4 groups of 8 lanes, each group sums
independently."""
@@ -85,6 +90,8 @@ def test_warp_sum_partial_8():
np.testing.assert_allclose(result, expected)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_warp_max_partial_4():
"""Partial warp max (width=4): 8 groups of 4 lanes."""
@@ -109,6 +116,8 @@ def test_warp_max_partial_4():
np.testing.assert_allclose(result, expected)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_warp_min_full():
"""Full warp min (width=32)."""
@@ -129,6 +138,8 @@ def test_warp_min_full():
np.testing.assert_allclose(result, np.full(32, 1.0))
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_warp_sum_partial_2():
"""Smallest partial warp sum (width=2): 16 pairs of adjacent lanes."""
@@ -155,6 +166,8 @@ def test_warp_sum_partial_2():
np.testing.assert_allclose(result, expected)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("width", [2, 4, 8, 16, 32])
def test_warp_sum_all_widths(width):
"""Parametric test: warp_sum with every valid width."""
diff --git a/tests/python/tirx/conftest.py b/tests/python/tirx/conftest.py
new file mode 100644
index 0000000000..fb8ba62f4f
--- /dev/null
+++ b/tests/python/tirx/conftest.py
@@ -0,0 +1,40 @@
+# Licensed to the Apache Software Foundation (ASF) under one
+# or more contributor license agreements. See the NOTICE file
+# distributed with this work for additional information
+# regarding copyright ownership. The ASF licenses this file
+# to you under the Apache License, Version 2.0 (the
+# "License"); you may not use this file except in compliance
+# with the License. You may obtain a copy of the License at
+#
+# http://www.apache.org/licenses/LICENSE-2.0
+#
+# Unless required by applicable law or agreed to in writing,
+# software distributed under the License is distributed on an
+# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
+# KIND, either express or implied. See the License for the
+# specific language governing permissions and limitations
+# under the License.
+"""Suite-level hardware gate for the tirx tests.
+
+The tirx kernels and codegen paths target Blackwell (sm_100a) — they emit
+PTX/SASS (tcgen05, tmem, cp.async ``.async`` modifiers, fp8 conversions, ...)
+that ptxas/NVRTC reject for older targets, and many tests execute on the
+device. Running the suite on a CPU-only node or a pre-sm_100 GPU therefore
+fails at compile/run time rather than skipping. Gate the whole directory on a
+real sm_100a device so it skips cleanly where the hardware is absent and runs
+in full where it is present.
+"""
+
+import pytest
+
+from tvm.testing import env
+
+
+def pytest_collection_modifyitems(config, items):
+ if env.has_cuda_compute(10):
+ return
+ skip = pytest.mark.skip(
+ reason="tirx suite requires a CUDA compute capability 10.0 (sm_100a)
device"
+ )
+ for item in items:
+ item.add_marker(skip)
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py
b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py
index 75faf61366..1824b41eae 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_fallback.py
@@ -32,6 +32,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
# Force the fallback dispatch to register before any test compiles a kernel.
# Without this import, in fresh pytest workers the `copy/fallback` variant
@@ -128,6 +129,8 @@ def _build_round_trip_kernel(scope, n_threads, shape,
dtype):
return kernel
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
@pytest.mark.parametrize(
"scope,n_threads,shape,why",
[
@@ -158,6 +161,8 @@ def test_fallback_round_trip(scope, n_threads, shape, why):
np.testing.assert_array_equal(B.numpy(), A_np)
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
def test_fallback_thread_scope():
"""``T.thread()`` — single thread, no gate. Either ``gmem_smem`` picks
it up (n_elements % 1 == 0) or ``fallback`` does — both end up emitting
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py
b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py
index dc5a46a751..c31ca79db9 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_gmem_smem.py
@@ -103,6 +103,8 @@ TASKS = [
]
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
@pytest.mark.parametrize(
"scope,n_threads,shape",
[pytest.param(*t, id=f"{t[0]}-{t[1]}-{'x'.join(map(str, t[2]))}") for t in
TASKS],
@@ -194,6 +196,8 @@ def test_gmem_smem_roundtrip(scope, n_threads, shape,
dtype):
),
],
)
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
@pytest.mark.parametrize(
"dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16",
"float32"]
)
diff --git a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py
b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py
index 4516225303..26c4d5de9b 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy/test_reg.py
@@ -35,6 +35,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg, tx
@@ -228,6 +229,8 @@ def _expected(shape, dtype):
return out
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
@pytest.mark.parametrize("non_r_scope", ["shared", "global"])
@pytest.mark.parametrize(
"scope,n_threads,k",
@@ -287,6 +290,8 @@ def test_reg_roundtrip(scope, n_threads, k, dtype,
non_r_scope):
),
],
)
[email protected]
[email protected](not env.has_cuda_compute(9), reason="need cuda compute >=
9.0")
@pytest.mark.parametrize(
"dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16",
"float32"]
)
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py
index b4d54d2b41..96f9283253 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_ldgsts.py
@@ -24,6 +24,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.layout import S, TileLayout
@@ -65,6 +66,8 @@ from tvm.tirx.layout import S, TileLayout
),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize(
"dtype", ["int8", "float8_e4m3fn", "float8_e5m2", "float16", "bfloat16",
"float32"]
)
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
index 0f910a4376..55e32339c7 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem.py
@@ -24,10 +24,13 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.layout import S, TCol, TileLayout, TLane
from tvm.tirx.layout import tid_in_wg as axis_tid_in_wg
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("dtype", ["float16", "float32"])
@pytest.mark.parametrize("width_32b", [4, 8, 16, 32])
def test_copy_tmem2reg_async(dtype, width_32b):
@@ -132,6 +135,8 @@ def test_copy_tmem2reg_async(dtype, width_32b):
# ----------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("dtype", ["uint8", "float16", "float32"])
@pytest.mark.parametrize("width_32b", [2, 4, 8, 16, 32, 64, 128])
@pytest.mark.parametrize("offset_32b", [0, 3, 10])
@@ -224,6 +229,8 @@ def test_copy_tmem2reg(dtype, width_32b, offset_32b):
np.testing.assert_allclose(B.numpy(), A_np)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("dtype", ["float16", "float32"])
@pytest.mark.parametrize("width_32b", [4, 8, 16, 32])
@pytest.mark.parametrize("local_offset_32b", [0, 2, 4])
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
index 4209359460..aac93c0252 100644
---
a/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
+++
b/tests/python/tirx/operator/tile_primitive/cuda/copy_async/test_tmem_16xnb.py
@@ -43,6 +43,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.layout import (
S,
TCol,
@@ -152,6 +153,8 @@ def _expected_reg_value_16b(
# --------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("shape", list(_SHAPE_REPS))
@pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32]) # subset; full reps
below
@pytest.mark.parametrize("dtype", ["float32"])
@@ -162,6 +165,8 @@ def test_tcgen05_ld_16xnb_load_fp32(shape, rep, dtype):
_run_load_test(shape, rep, dtype)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize(
"shape, rep",
[
@@ -175,6 +180,8 @@ def test_tcgen05_ld_16xnb_load_fp32_large_rep(shape, rep):
_run_load_test(shape, rep, "float32")
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("shape", list(_SHAPE_REPS))
@pytest.mark.parametrize("rep", [1, 2, 4, 8, 16, 32])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@@ -201,6 +208,8 @@ def test_tcgen05_16xnb_roundtrip_16b(shape, rep, dtype):
# We only need to spot-check that the dispatch fires correctly and the per-
# thread reg ↔ TMEM mapping round-trips bit-exactly — the M=64 sweep above
# already covers the (lane, reg) decomposition, so a sparse rep set suffices.
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
@pytest.mark.parametrize("rep", [1, 2, 4])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@@ -214,6 +223,8 @@ def test_tcgen05_16xnb_roundtrip_16b_M128(shape, rep,
dtype):
# with the scatter-encoded TileLayout that ``tmem_datapath_layout("F", ...)``
# produces. ``.16x*b`` M=64 PTX has the matching scatter built in, so the
# round-trip is bit-exact in the same way as Layout D + M=64.
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("shape", ["16x64b", "16x128b", "16x256b"])
@pytest.mark.parametrize("rep", [1, 2, 4])
@pytest.mark.parametrize("dtype", ["float16", "bfloat16"])
@@ -639,6 +650,8 @@ def _run_load_test(shape: str, rep: int, dtype: str):
# --------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("shape", list(_SHAPE_REPS))
@pytest.mark.parametrize("rep", [1, 4, 16])
@pytest.mark.parametrize("dtype", ["float32"])
@@ -853,5 +866,136 @@ def test_alloc_tcgen05_frag_wrapper_compiles(shape,
frag_rows, K_cols):
)
+# --------------------------------------------------------------------------
+# Test 3: column-slice loads of a wider frag
+#
+# An epilogue may allocate one wide ``(128, K)`` frag and load it from TMEM in
+# EPI_TILE-wide column chunks (``frag[:, c:c+w]``) so all loads are in flight
+# before a single ``wait.ld``. The ``.16x*b`` dispatch must emit each slice as
+# its own atom (``num_eff`` derived from the slice width) at the correct
+# per-slab register offset. We verify this is *bit-exact identical* to one
+# full-width load of the same frag — which the sweeps above already validate
+# against the layout-derived expectation. M=128 here exercises the 2-slab path
+# (the slice's two slabs live ``regs_per_thread_per_slab`` apart, not adjacent).
+# --------------------------------------------------------------------------
+
+
+def _run_sliced_vs_full_load(shape, full_rep, n_chunks):
+ """Load one frag from TMEM full-width and in column slices, then compare.
+
+ Builds a kernel that stages a random fp32 tile into TMEM, loads the same
+ ``(frag_rows, K_cols_fp32)`` frag once with a single full-width copy and
+ once as ``n_chunks`` equal-width column slices, and dumps both per-thread
+ register arrays to separate output buffers. The two dumps are compared
+ bit-for-bit on the host.
+
+ Parameters
+ ----------
+ shape : key into ``_COL_FACTOR_FP32`` / ``_REGS_FACTOR``; also forwarded to
+ ``tcgen05_atom_layout``.
+ full_rep : multiplier applied to both lookup tables to size the frag.
+ n_chunks : number of column slices; must divide the frag's fp32 column
+ count (asserted below).
+ """
+ dtype = "float32"
+ K_cols_fp32 = _COL_FACTOR_FP32[shape] * full_rep
+ assert K_cols_fp32 % n_chunks == 0
+ chunk_elem = K_cols_fp32 // n_chunks # fp32: elem == fp32 col
+ frag_rows = 128 # M=128 => 2 slabs
+ per_thread_elems = _REGS_FACTOR[shape] * full_rep * 2 # *2 for the second
slab
+
+ # TMEM allocation width comes from ``_next_pow2`` of the frag width and is
+ # never below 32 columns; the staged input tile uses that same width.
+ tmem_col_width_32b = max(32, _next_pow2(K_cols_fp32))
+ stage_width_elem = tmem_col_width_32b
+ CHUNK_FP32 = 128
+ # Staging is done in pieces of at most CHUNK_FP32 columns: one piece when
+ # the allocation is <= CHUNK_FP32 wide, otherwise width // CHUNK_FP32 pieces.
+ n_stage = tmem_col_width_32b // CHUNK_FP32 if tmem_col_width_32b >
CHUNK_FP32 else 1
+ stage_w = tmem_col_width_32b if n_stage == 1 else CHUNK_FP32
+ VEC_LEN = 4 # 128-bit / fp32
+
+ # ``atom_view`` is the register layout shared by both loads (full and
+ # sliced); ``stage_view`` is the layout used only for the staging store.
+ atom_view = tcgen05_atom_layout(shape, (frag_rows, K_cols_fp32), dtype)
+ stage_view = TileLayout(S[(128, stage_w) : (1 @ axis_tid_in_wg, 1)])
+
+ @T.prim_func
+ def kernel(A_ptr: T.handle, Bf_ptr: T.handle, Bs_ptr: T.handle) -> None:
+ A = T.match_buffer(A_ptr, (128, stage_width_elem), dtype)
+ Bf = T.match_buffer(Bf_ptr, (128, per_thread_elems), dtype) #
full-load dump
+ Bs = T.match_buffer(Bs_ptr, (128, per_thread_elems), dtype) #
sliced-load dump
+ A_flat = A.view(-1)
+
+ T.device_entry()
+ warp_id = T.warp_id([4])
+ T.cta_id([2])
+ wg_id = T.warpgroup_id([1])
+ T.warp_id_in_wg([4])
+ T.lane_id([32])
+ tid_in_wg = T.thread_id([128])
+
+ # The TMEM base address is written to shared memory by the allocating
+ # warp and read back through ``tmem_addr[0]`` after the shared sync.
+ tmem_addr = T.alloc_shared([1], "uint32")
+ if wg_id == 0:
+ if warp_id == 0:
+ T.ptx.tcgen05.alloc(T.address_of(tmem_addr),
n_cols=tmem_col_width_32b, cta_group=1)
+ T.tvm_storage_sync("shared")
+ tmem = T.decl_buffer(
+ (128, stage_width_elem),
+ dtype,
+ scope="tmem",
+ allocated_addr=tmem_addr[0],
+ layout=TileLayout(S[(128, stage_width_elem) : (1 @ TLane, 1 @
TCol)]),
+ )
+ # Stage A -> TMEM via the standard .32x32b path.
+ stage_reg = T.alloc_local((stage_w,), dtype)
+ stage_local = stage_reg.view(128, stage_w, layout=stage_view)
+ for ci in range(n_stage):
+ coff = ci * stage_w
+ for i in range(stage_w // VEC_LEN):
+ g = T.meta_var(tid_in_wg * stage_width_elem + coff + i *
VEC_LEN)
+ Tx.copy(stage_reg[i * VEC_LEN : i * VEC_LEN + VEC_LEN],
A_flat[g : g + VEC_LEN])
+ T.cuda.cta_sync()
+ Tx.wg.copy_async(tmem[:, coff : coff + stage_w],
stage_local[:, :])
+ T.ptx.tcgen05.wait.st()
+ T.cuda.cta_sync()
+
+ # (a) one full-width load
+ ff = T.alloc_local((per_thread_elems,), dtype)
+ ffl = ff.view(frag_rows, K_cols_fp32, layout=atom_view)
+ Tx.wg.copy_async(ffl[:, :], tmem[0:frag_rows, 0:K_cols_fp32])
+ T.ptx.tcgen05.wait.ld()
+ T.cuda.cta_sync()
+ for i in range(per_thread_elems):
+ Bf[tid_in_wg, i] = ff[i]
+
+ # (b) the same frag loaded in n_chunks column slices
+ sf = T.alloc_local((per_thread_elems,), dtype)
+ sfl = sf.view(frag_rows, K_cols_fp32, layout=atom_view)
+ for ck in range(n_chunks):
+ lo = T.meta_var(ck * chunk_elem)
+ Tx.wg.copy_async(
+ sfl[:, lo : lo + chunk_elem], tmem[0:frag_rows, lo : lo +
chunk_elem]
+ )
+ T.ptx.tcgen05.wait.ld()
+ T.cuda.cta_sync()
+ for i in range(per_thread_elems):
+ Bs[tid_in_wg, i] = sf[i]
+
+ # Give back the TMEM columns allocated at the top of the kernel.
+ if warp_id == 0:
+ T.ptx.tcgen05.relinquish_alloc_permit(cta_group=1)
+ T.ptx.tcgen05.dealloc(tmem_addr[0], n_cols=tmem_col_width_32b,
cta_group=1)
+
+ # Host side: compile for CUDA with the tirx pipeline, then run on device 0
+ # with a random input tile and zero-initialised output dumps.
+ target = tvm.target.Target("cuda")
+ with target:
+ mod = tvm.IRModule({"main": kernel})
+ mod = tvm.compile(mod, target=target, tir_pipeline="tirx")
+ A_np = tvm.testing.generate_random_array(dtype, (128,
stage_width_elem))
+ Bf_np = np.zeros((128, per_thread_elems), dtype=dtype)
+ Bs_np = np.zeros((128, per_thread_elems), dtype=dtype)
+ DEV = tvm.cuda(0)
+ A = tvm.runtime.tensor(A_np, DEV)
+ Bf = tvm.runtime.tensor(Bf_np, DEV)
+ Bs = tvm.runtime.tensor(Bs_np, DEV)
+ mod(A, Bf, Bs)
+ # Sliced load must reproduce the full-width load bit-for-bit.
+ # Both dumps are reinterpreted as uint32 so the comparison is on raw bits.
+ np.testing.assert_array_equal(Bs.numpy().view(np.uint32),
Bf.numpy().view(np.uint32))
+
+
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
[email protected](
+ "full_rep, n_chunks",
+ [
+ (32, 8), # 16x256b.x32 (256 fp32 cols) loaded in 8 chunks of 32 cols
(nvfp4 EPI_TILE=32)
+ (32, 16), # ...in 16 chunks of 16 cols (nvfp4 EPI_TILE=16)
+ (32, 4), # ...in 4 chunks of 64 cols
+ (16, 8), # 16x256b.x16 (128 fp32 cols) in 8 chunks of 16 cols
+ (16, 2), # ...in 2 chunks of 64 cols
+ ],
+)
+def test_tcgen05_ld_16x256b_sliced_matches_full_M128(full_rep, n_chunks):
+ """Per-chunk column-slice load of a wide M=128 frag == full-width load.
+
+ Each ``(full_rep, n_chunks)`` case is forwarded unchanged to
+ ``_run_sliced_vs_full_load`` with the shape fixed to ``"16x256b"``; that
+ helper performs the compile, run and bitwise comparison.
+ """
+ _run_sliced_vs_full_load("16x256b", full_rep, n_chunks)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py
b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py
index 1ce0d34ea6..8d39ba3556 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_binary.py
@@ -23,6 +23,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.layout import S, TileLayout, wg_local_layout
@@ -67,6 +68,8 @@ from tvm.tirx.layout import S, TileLayout, wg_local_layout
),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"])
@pytest.mark.parametrize("operands_type", ["region_region", "region_const",
"const_region"])
@pytest.mark.parametrize("dtype", ["float16"])
@@ -223,6 +226,8 @@ def test_binary_non_commutative_const_lhs_rejected(op_type):
tvm.compile(mod, target=target, tir_pipeline="tirx")
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("exec_scope", ["warp", "warpgroup"])
@pytest.mark.parametrize("op_type", ["add", "mul"])
def test_binary_op_shared_subcta_scope(exec_scope, op_type):
@@ -276,6 +281,8 @@ def test_binary_op_shared_subcta_scope(exec_scope, op_type):
tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-3)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("exec_scope", ["cta", "warpgroup", "warp"])
@pytest.mark.parametrize("rhs_kind", ["region", "broadcast", "const"])
@pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"])
@@ -392,6 +399,8 @@ def test_binary_op_local_subcta_trivial(exec_scope,
rhs_kind, op_type):
),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("storage_scope", ["shared", "local"])
@pytest.mark.parametrize("exec_scope", ["cta", "thread"])
@pytest.mark.parametrize("op_type", ["add", "sub", "mul", "fdiv"])
@@ -495,6 +504,8 @@ def test_binary_op_vectorized(input, storage_scope,
exec_scope, op_type, dtype):
tvm.testing.assert_allclose(A_ref, A.numpy(), atol=atol)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["add", "sub", "mul"])
def test_binary_op_packed_f32x2_auto_dispatch(op_type):
target = tvm.target.Target("cuda")
@@ -568,6 +579,8 @@ def test_binary_op_packed_f32x2_auto_dispatch(op_type):
tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-3)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_name", ["add", "sub", "mul"])
def test_binary_op_warpgroup_wg_local_layout(op_name):
dtype = "float32"
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py
b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py
index aa0f5ced8f..02352638e4 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_fma.py
@@ -26,6 +26,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.layout import S, TileLayout, wg_local_layout
@@ -41,6 +42,8 @@ def _get_sm_version():
# ---------------------------------------------------------------------------
# FMA op: scalar scale + scalar bias
# ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_fma_scalar_scalar():
sm = _get_sm_version()
if sm < 100:
@@ -78,6 +81,8 @@ def test_fma_scalar_scalar():
# ---------------------------------------------------------------------------
# FMA op: buffer scale + scalar bias (Horner pattern)
# ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_fma_buffer_scale_scalar_bias():
sm = _get_sm_version()
if sm < 100:
@@ -119,6 +124,8 @@ def test_fma_buffer_scale_scalar_bias():
# ---------------------------------------------------------------------------
# Binary op with scalar broadcast (PrimExpr scalar, e.g. BufferLoad)
# ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_mul_scalar_broadcast():
sm = _get_sm_version()
if sm < 100:
@@ -158,6 +165,8 @@ def test_mul_scalar_broadcast():
# ---------------------------------------------------------------------------
# Binary add with rounding mode
# ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_add_rounding_mode():
sm = _get_sm_version()
if sm < 100:
@@ -199,6 +208,8 @@ def test_add_rounding_mode():
# ---------------------------------------------------------------------------
# FMA op: layout=None local buffer (no TileLayout)
# ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_fma_no_layout():
sm = _get_sm_version()
if sm < 100:
@@ -238,6 +249,8 @@ def test_fma_no_layout():
# ---------------------------------------------------------------------------
# Binary sub with rounding mode (buffer-buffer)
# ---------------------------------------------------------------------------
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_sub_buffer_buffer_rounding():
sm = _get_sm_version()
if sm < 100:
@@ -278,6 +291,8 @@ def test_sub_buffer_buffer_rounding():
tvm.testing.assert_allclose(expected, A_dev.numpy(), atol=1e-6)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
def test_fma_warpgroup_wg_local_layout():
rows, cols = 128, 8
dtype = "float32"
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
index c20df63beb..fb70b37541 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/elementwise/test_unary.py
@@ -23,6 +23,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.cuda.operator.tile_primitive.layout_utils import (
cast_layout_supported_for_local as _cast_layout_supported_for_local,
)
@@ -54,6 +55,8 @@ from tvm.tirx.layout import S, TileLayout, laneid, tid_in_wg,
tx, warpid
),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["zero", "sqrt"])
@pytest.mark.parametrize(
"src_dtype,dst_dtype", [("float16", "float16"), ("float32", "float16"),
("float32", "bfloat16")]
@@ -145,6 +148,8 @@ def test_unary_op_shared(input, op_type, src_dtype,
dst_dtype):
tvm.testing.assert_allclose(B_ref, B.numpy(), atol=1e-2, rtol=1e-2)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("exec_scope", ["warp", "warpgroup"])
def test_unary_op_shared_subcta_scope(exec_scope):
dtype = "float16"
@@ -209,6 +214,8 @@ def test_unary_op_shared_subcta_scope(exec_scope):
),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["sqrt", "exp"])
@pytest.mark.parametrize("bias_type", ["const", "region"])
@pytest.mark.parametrize(
@@ -432,6 +439,8 @@ def test_unary_op_shared_with_bias_scale(input, op_type,
bias_type, src_dtype, d
),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["reciprocal", "exp", "exp2"])
@pytest.mark.parametrize(
"src_dtype,dst_dtype", [("float16", "float16"), ("float32", "float16"),
("float32", "bfloat16")]
@@ -554,6 +563,8 @@ def test_unary_op_local(input, op_type, src_dtype,
dst_dtype):
),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["sqrt", "exp"])
@pytest.mark.parametrize("bias_type", ["const", "region"])
@pytest.mark.parametrize(
@@ -682,6 +693,8 @@ def test_unary_op_local_with_bias_scale(input, op_type,
bias_type, src_dtype, ds
tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("shape", [(128, 8), (128, 4, 16), (128, 5, 5)])
@pytest.mark.parametrize("op_type", ["fill"])
@pytest.mark.parametrize("exec_scope", ["thread", "cta"])
@@ -740,6 +753,8 @@ def test_unary_op_vectorized(shape, op_type, exec_scope,
storage_scope):
tvm.testing.assert_allclose(A.numpy(), np.full(shape, value.value),
atol=1e-2)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["zero", "sqrt", "reciprocal", "exp",
"silu"])
@pytest.mark.parametrize("dtype", ["float16"])
def test_unary_op_local_thread_wise(op_type, dtype):
@@ -791,6 +806,8 @@ def test_unary_op_local_thread_wise(op_type, dtype):
tvm.testing.assert_allclose(A_ref, A.numpy(), atol=1e-2, rtol=1e-2)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("shape", [(8,), (16, 16), (5, 5)])
@pytest.mark.parametrize("A_dtype", ["float16", "float32"])
@pytest.mark.parametrize("B_dtype", ["float16", "float32"])
@@ -831,6 +848,8 @@ def test_cast_thread_local(shape, A_dtype, B_dtype):
tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"),
("float32", "bfloat16")])
def test_cast_warpgroup_local_view(A_dtype, B_dtype):
"""T.cast in warpgroup scope with offset (tid_in_wg + layout offset).
Covers offset/tid_in_wg/warpgroup scope.""" # noqa: E501
@@ -884,6 +903,8 @@ def test_cast_warpgroup_local_view(A_dtype, B_dtype):
tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"),
("float32", "bfloat16")])
def test_cast_warpgroup_src_layout_to_flat_uses_vec2_intrinsic(A_dtype,
B_dtype):
"""Regression: GEMM-epilogue cast pattern must emit the packed vec2 cuda
intrinsic.
@@ -944,6 +965,8 @@ def
test_cast_warpgroup_src_layout_to_flat_uses_vec2_intrinsic(A_dtype, B_dtype)
tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"),
("float32", "bfloat16")])
def test_cast_cta_local_view(A_dtype, B_dtype):
"""T.cast with view+layout in CTA scope (128 threads,
register->register)."""
@@ -988,6 +1011,8 @@ def test_cast_cta_local_view(A_dtype, B_dtype):
tvm.testing.assert_allclose(B.numpy(), B_ref, atol=1e-2)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("A_dtype,B_dtype", [("float32", "float16"),
("float32", "bfloat16")])
@pytest.mark.parametrize("slice_start,slice_end", [(0, 4), (2, 6), (4, 8)])
def test_cast_local_view_sliced(A_dtype, B_dtype, slice_start, slice_end):
@@ -1087,6 +1112,8 @@ def test_cast_layout_partition_and_validation():
check(part)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("slice_start,slice_end", [(0, 2), (2, 4)])
def test_cast_mixed_axes_and_subregion(slice_start, slice_end):
"""Test cast with mixed axes and subregion."""
@@ -1095,7 +1122,7 @@ def test_cast_mixed_axes_and_subregion(slice_start,
slice_end):
LOCAL_LEN = 4
full_shape = (8, N_WARPS, 4, LOCAL_LEN)
g_layout = TileLayout(S[full_shape])
- cast_layout = TileLayout(S[full_shape : (4 @ laneid, 2 @ warpid, 1 @
laneid, 1)])
+ cast_layout = TileLayout(S[full_shape : (4 @ laneid, 1 @ warpid, 1 @
laneid, 1)])
A_ref = np.zeros(full_shape, dtype="float32")
for j in range(full_shape[0]):
@@ -1207,8 +1234,12 @@ def test_cast_validate_extent_mismatch_rejected():
target = tvm.target.Target("cuda")
with target:
mod = tvm.IRModule({"main": kernel})
+ # The mismatched dst also fails the scope-level check (thread axes
don't
+ # span the full CTA), which fires first — either rejection is fine.
with pytest.raises(
- Exception, match="tile_local_valid|layout signature
mismatch|thread part mismatch"
+ Exception,
+ match="tile_local_valid|layout signature mismatch|thread part
mismatch"
+ "|do not tile a complete|not the full",
):
tvm.compile(mod, target=target, tir_pipeline="tirx")
@@ -1277,5 +1308,138 @@ def test_cast_vec2_packed_dispatch(src_dtype,
dst_dtype, intrinsic):
), f"expected packed vec2 cast {intrinsic}; got:\n{src[:2000]}"
+# -----------------------------------------------------------------------------
+# Scope-level operand check: a warp/wg/cta reg op needs a scope-level layout
+# (thread axes spanning all the scope's threads), not a thread-local .local().
+# -----------------------------------------------------------------------------
+# Tile extents shared by the scope-level tests below: _SL_ROWS is also the
+# extent passed to T.thread_id / T.thread_id_in_wg, _SL_COLS the row length.
+_SL_ROWS, _SL_COLS = 128, 8
+
+
+def _sl_compile(fn):
+ """Compile ``fn`` as ``main`` for the CUDA target with the tirx pipeline.
+
+ The compiled module is discarded and nothing is executed; callers rely
+ only on whether ``tvm.compile`` raises.
+ """
+ target = tvm.target.Target("cuda")
+ with target:
+ tvm.compile(tvm.IRModule({"main": fn}), target=target,
tir_pipeline="tirx")
+
+
+def test_cast_wg_rejects_thread_local_view():
+ """Tx.wg.cast on a .local() (thread-axis-stripped) view is rejected."""
+
+ @T.prim_func
+ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+ A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ T.device_entry()
+ _bx = T.cta_id([1])
+ _wg = T.warpgroup_id([1])
+ tid = T.thread_id_in_wg([_SL_ROWS])
+ # src/dst themselves carry a tid_in_wg-distributed layout; only the
+ # views handed to the cast below are taken through .local().
+ src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ src_row = src.local(_SL_COLS)
+ for i in T.serial(_SL_COLS):
+ src_row[i] = A[tid, i]
+ # The operation under test: a warpgroup-scope cast given .local() views.
+ Tx.wg.cast(dst.local(), src.local())
+ dst_row = dst.local(_SL_COLS)
+ for i in T.serial(_SL_COLS):
+ B[tid, i] = dst_row[i]
+
+ # Compilation must fail with a message containing "thread-local view".
+ with pytest.raises(Exception, match="thread-local view"):
+ _sl_compile(kernel)
+
+
+def test_cast_cta_rejects_thread_local_view():
+ """Tx.cta.cast on a .local() view is rejected (cta -> tx)."""
+
+ @T.prim_func
+ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+ A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ T.device_entry()
+ _bx = T.cta_id([1])
+ tx_var = T.thread_id([_SL_ROWS])
+ # Same setup as the warpgroup variant, but the layouts are distributed
+ # over the ``tx`` axis and the threads are declared via T.thread_id.
+ src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]))
+ dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tx, 1)]))
+ src_row = src.local(_SL_COLS)
+ for i in T.serial(_SL_COLS):
+ src_row[i] = A[tx_var, i]
+ # The operation under test: a CTA-scope cast given .local() views.
+ Tx.cta.cast(dst.local(), src.local())
+ dst_row = dst.local(_SL_COLS)
+ for i in T.serial(_SL_COLS):
+ B[tx_var, i] = dst_row[i]
+
+ # Compilation must fail with a message containing "thread-local view".
+ with pytest.raises(Exception, match="thread-local view"):
+ _sl_compile(kernel)
+
+
+def test_cast_wg_rejects_partial_thread_coverage():
+ """A tid_in_wg layout covering only 64 of the 128 wg threads is
rejected."""
+ # Row extent of every buffer in this test; deliberately smaller than the
+ # _SL_ROWS threads declared through T.thread_id_in_wg below.
+ half = 64
+
+ @T.prim_func
+ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+ A = T.match_buffer(A_ptr, (half, _SL_COLS), "float32",
layout=TileLayout(S[(half, _SL_COLS)]))
+ B = T.match_buffer(B_ptr, (half, _SL_COLS), "float16",
layout=TileLayout(S[(half, _SL_COLS)]))
+ T.device_entry()
+ _bx = T.cta_id([1])
+ _wg = T.warpgroup_id([1])
+ tid = T.thread_id_in_wg([_SL_ROWS])
+ src = T.alloc_buffer((half, _SL_COLS), "float32", scope="local",
layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ dst = T.alloc_buffer((half, _SL_COLS), "float16", scope="local",
layout=TileLayout(S[(half, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ src_row = src.local(_SL_COLS)
+ for i in T.serial(_SL_COLS):
+ src_row[i] = A[tid, i]
+ # Unlike the .local() tests, the buffers are passed directly; only the
+ # tid_in_wg extent (``half``) differs from the accepted case.
+ Tx.wg.cast(dst, src)
+ dst_row = dst.local(_SL_COLS)
+ for i in T.serial(_SL_COLS):
+ B[tid, i] = dst_row[i]
+
+ # Compilation must fail with a message containing "not the full 128".
+ with pytest.raises(Exception, match="not the full 128"):
+ _sl_compile(kernel)
+
+
+def test_cast_wg_accepts_wg_level_layout():
+ """Tx.wg.cast on a wg-level (tid_in_wg-distributed) layout compiles."""
+
+ @T.prim_func
+ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+ A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ T.device_entry()
+ _bx = T.cta_id([1])
+ _wg = T.warpgroup_id([1])
+ tid = T.thread_id_in_wg([_SL_ROWS])
+ # The tid_in_wg extent (_SL_ROWS) equals the thread count declared
+ # through T.thread_id_in_wg above.
+ src = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float32", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ dst = T.alloc_buffer((_SL_ROWS, _SL_COLS), "float16", scope="local",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS) : (1 @ tid_in_wg, 1)]))
+ src_row = src.local(_SL_COLS)
+ for i in T.serial(_SL_COLS):
+ src_row[i] = A[tid, i]
+ # The buffers are passed directly, not through .local().
+ Tx.wg.cast(dst, src)
+ dst_row = dst.local(_SL_COLS)
+ for i in T.serial(_SL_COLS):
+ B[tid, i] = dst_row[i]
+
+ # No pytest.raises: the test passes when compilation raises nothing.
+ _sl_compile(kernel)
+
+
+def test_cast_thread_accepts_local_view():
+ """thread scope is exempt: a thread-axis-free local tile still compiles."""
+
+ @T.prim_func
+ def kernel(A_ptr: T.handle, B_ptr: T.handle) -> None:
+ A = T.match_buffer(A_ptr, (_SL_ROWS, _SL_COLS), "float32",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ B = T.match_buffer(B_ptr, (_SL_ROWS, _SL_COLS), "float16",
layout=TileLayout(S[(_SL_ROWS, _SL_COLS)]))
+ T.device_entry()
+ _bx = T.cta_id([1])
+ tx_var = T.thread_id([_SL_ROWS])
+ # 1-D local tiles of _SL_COLS elements whose layouts name no thread axis.
+ src = T.alloc_buffer((_SL_COLS,), "float32", scope="local",
layout=TileLayout(S[(_SL_COLS,)]))
+ dst = T.alloc_buffer((_SL_COLS,), "float16", scope="local",
layout=TileLayout(S[(_SL_COLS,)]))
+ for i in T.serial(_SL_COLS):
+ src[i] = A[tx_var, i]
+ # Plain Tx.cast, without the wg/cta prefix used by the rejecting tests.
+ Tx.cast(dst, src)
+ for i in T.serial(_SL_COLS):
+ B[tx_var, i] = dst[i]
+
+ # No pytest.raises: the test passes when compilation raises nothing.
+ _sl_compile(kernel)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
index e0a270e709..32ac00e39d 100644
---
a/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
+++
b/tests/python/tirx/operator/tile_primitive/cuda/gemm_async/test_gemm_async.py
@@ -32,6 +32,7 @@ import tvm.testing
from tvm.ir.type import PointerType, PrimType
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.cuda.operator.tile_primitive.gemm_async import sf_tmem_layout
from tvm.tirx.cuda.operator.tile_primitive.tma_utils import (
mma_atom_layout,
@@ -167,6 +168,8 @@ def pack_sf_fp8_uint32(sf_uint8, n_total=128):
return packed
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize(
"task",
[
@@ -293,6 +296,8 @@ def test_gemm_tcgen05_cta_group_1(task):
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
def test_gemm_tcgen05_cta_group_1_layout_f_m64():
"""M=64 MMA with C operand allocated as Layout F (datapath="F").
@@ -405,6 +410,8 @@ def test_gemm_tcgen05_cta_group_1_layout_f_m64():
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-2, rtol=1e-2)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize(
"task",
[
@@ -545,6 +552,8 @@ def test_gemm_tcgen05_cta_group_2(task):
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
def test_gemm_tcgen05_cta_group_2_layout_b():
"""Test cta_group=2 with Layout B (2x2 datapath, M=128 total, 64 per CTA).
@@ -675,6 +684,8 @@ def test_gemm_tcgen05_cta_group_2_layout_b():
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
@pytest.mark.parametrize(
"task",
@@ -864,6 +875,8 @@ def test_gemm_block_scaled_fp8_cta_group_1(task):
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
@pytest.mark.parametrize(
"task",
@@ -1089,6 +1102,8 @@ def test_gemm_block_scaled_fp8_cta_group_2(task):
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
def test_gemm_block_scaled_nvfp4_cta_group_1():
"""Test block-scaled nvfp4 GEMM with cta_group=1.
@@ -1258,6 +1273,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_1():
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
def test_gemm_block_scaled_nvfp4_cta_group_2():
"""Test block-scaled nvfp4 GEMM with cta_group=2.
@@ -1462,6 +1479,8 @@ def test_gemm_block_scaled_nvfp4_cta_group_2():
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1.0, rtol=0.15)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.skipif(ml_dtypes is None, reason="Requires ml_dtypes")
def test_gemm_block_scaled_fp8_sf_id():
"""Test sf_id auto-derivation from layout for fp8 block-scaled MMA.
@@ -1681,6 +1700,8 @@ def test_gemm_block_scaled_fp8_sf_id():
)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize(
"task",
[
@@ -1960,6 +1981,8 @@ def test_gemm_tcgen05_arbitrary_tiles(task):
np.testing.assert_allclose(C_tvm.numpy(), C_ref, atol=1e-3, rtol=1e-3)
[email protected]
[email protected](not env.has_cuda_compute(10), reason="need cuda compute >=
10.0")
@pytest.mark.parametrize("k_lo,k_hi", [(0, 16), (0, 32), (16, 32), (16, 48),
(32, 64)])
def test_gemm_tcgen05_contiguous_kslice_partial_k(k_lo, k_hi):
"""A slice on the *contiguous* (K) axis of a swizzled gemm_async operand
must
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py
b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py
index 67cc1e0bd6..0402719ba1 100644
---
a/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py
+++
b/tests/python/tirx/operator/tile_primitive/cuda/permute_layout/test_permute_layout.py
@@ -43,6 +43,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
# Helpers exposed by the dispatcher module for direct algorithm tests.
from tvm.tirx.cuda.operator.tile_primitive.permute_layout.warp_xor_swizzle
import (
@@ -167,6 +168,8 @@ def _compile_and_run(prim_func, np_inputs):
return [t.numpy() for t in tensors], mod.mod.imports[0].inspect_source()
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@needs_cuda
@pytest.mark.parametrize(
"name, pipe, blk, dtype",
@@ -231,6 +234,8 @@ def test_sf_blockwise_transpose(name, pipe, blk, dtype):
np.testing.assert_array_equal(B_flat, ref, err_msg=f"{name} stage {s}")
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@needs_cuda
def test_identity_passes_through_as_copy():
"""L_src == L_dst should still compile and produce a correct (identity)
copy."""
@@ -255,6 +260,8 @@ def test_identity_passes_through_as_copy():
np.testing.assert_array_equal(B_out, A_np)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@needs_cuda
@pytest.mark.parametrize("dtype", ["uint32", "int32", "float32"])
@pytest.mark.parametrize(
diff --git
a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
index 0474ad2dc4..9031aa4f48 100644
--- a/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
+++ b/tests/python/tirx/operator/tile_primitive/cuda/reduction/test_reduction.py
@@ -21,6 +21,7 @@ import tvm
import tvm.testing
from tvm.script import tirx as T
from tvm.script.tirx import tile as Tx
+from tvm.testing import env
from tvm.tirx.layout import R, S, TileLayout, laneid, wg_local_layout
@@ -41,6 +42,8 @@ from tvm.tirx.layout import R, S, TileLayout, laneid,
wg_local_layout
((32, 32), (32,), (-1,), (1, 1), (2,), (5, 8), (5,)),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["sum", "max", "min"])
@pytest.mark.parametrize("dtype", ["float32", "float16"])
@pytest.mark.parametrize("accum", [False, True])
@@ -129,6 +132,8 @@ def test_reduction_shared(
tvm.testing.assert_allclose(ref, B.numpy()[tuple(reduce_slice_dst)],
atol=atol)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("exec_scope", ["warp", "warpgroup", "thread"])
@pytest.mark.parametrize("op_type", ["sum", "max", "min"])
@pytest.mark.parametrize("accum", [False, True])
@@ -264,6 +269,8 @@ def test_reduction_shared_subscope(exec_scope, op_type,
accum):
((2, 3, 4), (3, 4), (0,)),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["sum", "max", "min"])
@pytest.mark.parametrize("accum", [False, True])
def test_reduction_local_thread_wise(src_shape, dst_shape, axes, op_type,
accum):
@@ -367,6 +374,8 @@ def test_reduction_local_thread_wise(src_shape, dst_shape,
axes, op_type, accum)
((4, 8), (1, 8), (1,), False, None),
],
)
[email protected]
[email protected](not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["sum", "max", "min"])
def test_reduction_local_view_basic(inner_dims, dst_dims, axes, accum,
slice_end, op_type):
"""Test view-based local reduction with simple purely-local layouts."""
@@ -484,6 +493,8 @@ def test_reduction_local_view_basic(inner_dims, dst_dims,
axes, accum, slice_end
tvm.testing.assert_allclose(ref, B.numpy(), atol=1e-5)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("n_groups, n_warps", [(1, 1), (1, 4), (2, 8)])
@pytest.mark.parametrize("op_type", ["sum", "max", "min"])
@pytest.mark.parametrize("dtype", ["float32", "float16"])
@@ -616,6 +627,8 @@ def test_reduction_local_view_complex(n_groups, n_warps,
op_type, dtype, shuffle
tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 7, 10, 15,
100])
@pytest.mark.parametrize("op_type", ["max", "min"])
@pytest.mark.parametrize("accum", [False, True])
@@ -685,6 +698,8 @@ def
test_reduction_local_optimized_3input_maxmin(reduction_len, op_type, accum):
tvm.testing.assert_allclose(B_ref, B.numpy()[0], atol=1e-5)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("reduction_len", [8, 16, 64, 128, 256, 9, 17, 63, 65,
100])
@pytest.mark.parametrize("accum", [False, True])
def test_reduction_local_optimized_packed_add_sum(reduction_len, accum):
@@ -746,6 +761,8 @@ def
test_reduction_local_optimized_packed_add_sum(reduction_len, accum):
tvm.testing.assert_allclose(B_ref, B.numpy()[0], atol=1e-4)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["sum", "max"])
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_reduction_op_warp_shuffle(op_type, dtype):
@@ -807,6 +824,8 @@ def test_reduction_op_warp_shuffle(op_type, dtype):
tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_type", ["sum", "max"])
@pytest.mark.parametrize("dtype", ["float32", "float16"])
def test_reduction_op_warp_shuffle_multi_elem(op_type, dtype):
@@ -875,6 +894,8 @@ def test_reduction_op_warp_shuffle_multi_elem(op_type,
dtype):
tvm.testing.assert_allclose(B_ref, B.numpy(), atol=atol)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
def test_reduction_warp_shuffle_multi_warp_loop():
"""Test intra-warp + cross-warp reduction via T.sum in a for loop with
multiple warps.
@@ -951,6 +972,8 @@ def test_reduction_warp_shuffle_multi_warp_loop():
tvm.testing.assert_allclose(B_ref, B_dev.numpy(), atol=1e-3)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
@pytest.mark.parametrize("op_name", ["sum", "max"])
def test_reduction_warpgroup_wg_local_layout(op_name):
rows, cols = 128, 16
diff --git a/tests/python/tirx/test_buffer_print.py b/tests/python/tirx/test_buffer_print.py
index 211f4d3903..dbd0da8f84 100644
--- a/tests/python/tirx/test_buffer_print.py
+++ b/tests/python/tirx/test_buffer_print.py
@@ -18,10 +18,12 @@
import re
import numpy as np
+import pytest
import tvm
import tvm.testing
from tvm.script import tirx as T
+from tvm.testing import env
def generate_random_data(shape, dtype):
@@ -181,6 +183,8 @@ def verify_cuda_code_string(func, expected_var_name,
expected_string_literal):
)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
def test_print():
DEV = tvm.cuda()
target = tvm.target.Target("cuda")
diff --git a/tests/python/tirx/test_control_flow.py b/tests/python/tirx/test_control_flow.py
index 1f905bd03c..9085c2b021 100644
--- a/tests/python/tirx/test_control_flow.py
+++ b/tests/python/tirx/test_control_flow.py
@@ -15,9 +15,11 @@
# specific language governing permissions and limitations
# under the License.
import numpy as np
+import pytest
import tvm
from tvm.script import tirx as T
+from tvm.testing import env
def run_test_break_continue(func, shape, expected):
@@ -32,6 +34,8 @@ def run_test_break_continue(func, shape, expected):
np.testing.assert_allclose(arr.numpy(), expected)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
def test_break_continue1():
# fmt: off
@T.prim_func
@@ -53,6 +57,8 @@ def test_break_continue1():
run_test_break_continue(func, (10,), expected)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
def test_break_continue2():
# fmt: off
@T.prim_func
@@ -79,6 +85,8 @@ def test_break_continue2():
run_test_break_continue(func, (9,), expected)
+@pytest.mark.gpu
+@pytest.mark.skipif(not env.has_cuda(), reason="need cuda")
def test_break_continue3():
# fmt: off
@T.prim_func
diff --git a/tests/python/tirx/test_layout.py b/tests/python/tirx/test_layout.py
index e3711cb00c..0dcf212ce2 100644
--- a/tests/python/tirx/test_layout.py
+++ b/tests/python/tirx/test_layout.py
@@ -1733,5 +1733,40 @@ def test_slice_single_shard_skips_defensive_floormod():
# we just assert offset is non-empty and structurally sane (not None).
+def test_slice_tcgen05_frag_layout_scope_consistent():
+ """Slicing a wid_in_wg+laneid frag layout (tcgen05 16x256b) must stay
+ scope-consistent: the sliced result canonicalizes to a single tid_in_wg
+ chain over the full 128 threads (regression for the per-group-fusion bug).
+ """
+ frag = TileLayout(
+ S[(4, 2, 2, 8, 4, 4, 2) : (1 @ wid_in_wg, 16, 2, 4 @ laneid, 4, 1 @ laneid, 1)]
+ )
+
+ def thread_chain(layout):
+ canon = layout.canonicalize()
+ names = {it.axis.name for it in canon.shard if it.axis.is_thread()}
+ titers = sorted(
+ ((int(it.stride), int(it.extent)) for it in canon.shard if it.axis.is_thread()),
+ )
+ running = 1
+ for stride, extent in titers:
+ assert stride == running, f"non-contiguous thread chain: {titers}"
+ running *= extent
+ return names, running
+
+ with tvm.target.Target("cuda"):
+ # Full-region slice and a column sub-slice must both canonicalize to a
+ # single tid_in_wg chain covering all 128 warpgroup threads.
+ full = frag.slice([128, 32], [(0, 128), (0, 32)])
+ names, total = thread_chain(full)
+ assert names == {"tid_in_wg"}, names
+ assert total == 128, total
+
+ col = frag.slice([128, 32], [(0, 128), (16, 32)])
+ names_c, total_c = thread_chain(col)
+ assert names_c == {"tid_in_wg"}, names_c
+ assert total_c == 128, total_c
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git a/tests/scripts/task_python_unittest.sh b/tests/scripts/task_python_unittest.sh
index ec052281ad..15bb51bdf7 100755
--- a/tests/scripts/task_python_unittest.sh
+++ b/tests/scripts/task_python_unittest.sh
@@ -55,6 +55,7 @@ TEST_FILES=(
"tirx-analysis"
"tirx-base"
"tirx-transform"
+ "tirx"
"tvmscript"
"relax"
)