This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 82293c8c11 [Relax][Frontend][KVCache] Extend masked sequence prefill to causal left-padding (#19431)
82293c8c11 is described below

commit 82293c8c1157743ef69549cb7bf23f52b9f342be
Author: Xijing Wang <[email protected]>
AuthorDate: Fri Apr 24 22:52:34 2026 -0400

    [Relax][Frontend][KVCache] Extend masked sequence prefill to causal left-padding (#19431)
    
    This PR extends `_attention_sequence_prefill_with_mask` to support a
    second mask regime for decoder-style embedding workloads.
    
    ### Summary
    
    - Keep the existing right-padded bidirectional behavior as
    `mask_mode="padded"`.
    - Add `mask_mode="causal_padded_left"` for left-padded causal sequence
    prefill.
    - Add a `softmax_update_causal_padded_left` macro for the online softmax
    mask.
    - Add tests for causal left-padding with zero, full, mixed, and GQA
    valid lengths, plus a case where Q and K/V have different padded
    lengths.
    
    ### Motivation
    
    This is a TVM-side kernel dependency for the first-class embedding
    serving work tracked in mlc-ai/mlc-llm#3451.
    
    The existing masked sequence prefill kernel supports encoder-style
    batches where real tokens occupy the valid prefix `[0, valid_len)` and
    padding is on the right.
    
    Decoder-style embedding batches (for example, those on the decoder-only
    embedding path) commonly left-pad variable-length inputs so that the
    final real token / EOS lands in the same last column for every row of
    the batch. Last-token pooling can then simply read `output[:, -1, :]`,
    while causal masking must still be applied within each valid suffix.
    
    For each batch row:
    
    - `mask_mode="padded"`: real tokens are `[0, valid_len)`.
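    - `mask_mode="causal_padded_left"`: real tokens are `[seq_len -
    valid_len, seq_len)`, with causal masking `col <= row`. When the Q and
    K/V lengths differ, the kernel keeps `col <= row + (kv_len - qo_len)`,
    so the causal diagonal is aligned at the end of both sequences.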
    
    ### Testing
    
    - `git diff --check`
    - Attempted:
    `python -m pytest -q
    tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py -k
    'causal_padded_left or valid_len_mixed'`
---
 python/tvm/relax/frontend/nn/llm/_kernel_common.py |  53 ++++++-
 .../tvm/relax/frontend/nn/llm/_prefill_kernels.py  |  86 +++++++---
 ...test_frontend_nn_llm_sequence_prefill_masked.py | 173 +++++++++++++++++++--
 3 files changed, 268 insertions(+), 44 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/_kernel_common.py 
b/python/tvm/relax/frontend/nn/llm/_kernel_common.py
index 3f881daf70..e7a526cf19 100644
--- a/python/tvm/relax/frontend/nn/llm/_kernel_common.py
+++ b/python/tvm/relax/frontend/nn/llm/_kernel_common.py
@@ -347,7 +347,7 @@ def _make_prefill_macros(tile_x, tile_y, tile_z, tile_o, 
bdx, num_warps, group_s
         S_smem: T.Buffer, m_smem: T.Buffer, d_smem: T.Buffer, m_prev_smem: 
T.Buffer,
         m_new: T.Buffer, m_prev: T.Buffer, d_new: T.Buffer,
         ty: T.int32, tx: T.int32, LH_start: T.int32, L_kv_start: T.int32,
-        valid_len: T.int32, qo_len: T.int32,
+        valid_len: T.int32, qo_len: T.int32, kv_len: T.int32,
     ):
         # Same three-phase online softmax as softmax_update_causal but with a
         # per-batch right-padding mask in place of causal masking.
@@ -383,7 +383,56 @@ def _make_prefill_macros(tile_x, tile_y, tile_z, tile_o, 
bdx, num_warps, group_s
                     m_prev_smem[row] = m_prev[i]
         T.tvm_storage_sync("shared")
 
-    return init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, 
softmax_update_valid_length, advance_tile_batch, paged_store_output_lse
+    @T.macro
+    def softmax_update_causal_padded_left(
+        S_smem: T.Buffer, m_smem: T.Buffer, d_smem: T.Buffer, m_prev_smem: T.Buffer,
+        m_new: T.Buffer, m_prev: T.Buffer, d_new: T.Buffer,
+        ty: T.int32, tx: T.int32, LH_start: T.int32, L_kv_start: T.int32,
+        valid_len: T.int32, qo_len: T.int32, kv_len: T.int32,
+    ):
+        # Three-phase online softmax with left-padding + causal mask. Real
+        # queries occupy [qo_len - valid_len, qo_len); real keys occupy
+        # [kv_len - valid_len, kv_len). Causal keeps
+        # col <= row + (kv_len - qo_len) within those valid suffixes.
+        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
+            if row < tile_x:
+                with T.sblock("update1"):
+                    m_prev[i] = m_smem[row]
+                    m_new[i] = m_smem[row]
+                    row_: T.int32 = (LH_start + row) // group_size
+                    pad_q: T.int32 = qo_len - valid_len
+                    pad_kv: T.int32 = kv_len - valid_len
+                    for j in T.serial(tile_z):
+                        col_: T.int32 = L_kv_start + j
+                        if tirx.And(tirx.And(row_ < qo_len, row_ >= pad_q), tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
+                            m_new[i] = T.max(m_new[i], S_smem[row, j])
+                    d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
+        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
+            with T.sblock("update"):
+                for j in T.serial(tile_z):
+                    if row < tile_x:
+                        row_: T.int32 = (LH_start + row) // group_size
+                        pad_q: T.int32 = qo_len - valid_len
+                        pad_kv: T.int32 = kv_len - valid_len
+                        col_: T.int32 = L_kv_start + j
+                        if tirx.And(tirx.And(row_ < qo_len, row_ >= pad_q), tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
+                            S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
+                        else:
+                            S_smem[row, j] = T.exp2(-5e4 - m_new[i])
+        for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+            row: T.int32 = i * bdx * num_warps + ty * bdx + tx
+            if row < tile_x:
+                with T.sblock("update"):
+                    for j in T.serial(tile_z):
+                        d_new[i] += S_smem[row, j]
+                    m_smem[row] = m_new[i]
+                    d_smem[row] = d_new[i]
+                    m_prev_smem[row] = m_prev[i]
+        T.tvm_storage_sync("shared")
+
+    return init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, 
softmax_update_valid_length, advance_tile_batch, paged_store_output_lse, 
softmax_update_causal_padded_left
 
 
 def _get_prefill_kernel_config(h_kv, h_q, d, dtype, target: Target):
diff --git a/python/tvm/relax/frontend/nn/llm/_prefill_kernels.py 
b/python/tvm/relax/frontend/nn/llm/_prefill_kernels.py
index eba7e21133..2068db5bb4 100644
--- a/python/tvm/relax/frontend/nn/llm/_prefill_kernels.py
+++ b/python/tvm/relax/frontend/nn/llm/_prefill_kernels.py
@@ -27,7 +27,7 @@ K/V loading path that is specific to its storage layout.
 
 # pylint: 
disable=too-many-statements,too-many-arguments,invalid-name,line-too-long
 import math
-from typing import Any
+from typing import Any, Literal
 
 import tvm
 from tvm import tirx
@@ -212,7 +212,7 @@ def _attention_prefill(h_kv, h_q, d, dtype, sliding_window: 
bool, rope_scaling:
     if sliding_window:
         global_symbol += "_sliding_window"
 
-    init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _, 
advance_tile_batch, paged_store_output_lse = _make_prefill_macros(tile_x, 
tile_y, tile_z, tile_y, bdx, num_warps, group_size)
+    init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _, 
advance_tile_batch, paged_store_output_lse, *_ = _make_prefill_macros(tile_x, 
tile_y, tile_z, tile_y, bdx, num_warps, group_size)
 
     # pylint: disable=too-many-branches
     @T.prim_func
@@ -489,24 +489,60 @@ def _attention_sequence_prefill(h_kv, h_q, d, dtype, 
target: Target, causal=0, s
 
 
 
-def _attention_sequence_prefill_with_mask(h_kv, h_q, d, dtype, target: Target, 
sm_scale=1.0):
-    """Tiled sequence prefill kernel with a per-batch right-padding mask.
-
-    This is the counterpart of :func:`_attention_sequence_prefill` for batched
-    encoder-style inputs where each sample in the batch is padded to a common
-    ``seq_len`` but only the first ``valid_lens[b]`` tokens carry real content.
-    The kernel takes an extra ``valid_lens`` buffer of shape ``(batch_size,)``
-    and applies the padding mask inside the QKV load path and the online
-    softmax update, so no explicit mask tensor broadcast or additive bias is
-    needed on the host side.
-
-    Semantics: for batch ``b``, positions ``[0, valid_lens[b])`` are real and
-    positions ``[valid_lens[b], seq_len)`` are padding. Padding queries and
-    keys/values are zeroed at load time; padded ``(row, col)`` pairs are
-    excluded from the max/sum of the online softmax via a ``-inf`` slot.
+def _attention_sequence_prefill_with_mask(
+    h_kv, h_q, d, dtype, target: Target, sm_scale=1.0, *,
+    mask_mode: Literal["padded", "causal_padded_left"] = "padded",
+):
+    """Tiled sequence prefill kernel with a per-batch padding mask.
+
+    Supports two mask regimes selected by ``mask_mode``:
+
+    * ``"padded"`` (default) — bidirectional attention with right-padding.
+      For batch ``b``, positions ``[0, valid_lens[b])`` are real and
+      positions ``[valid_lens[b], seq_len)`` are padding. This is the
+      encoder-style batch regime.
+    * ``"causal_padded_left"`` — causal attention with left-padding. For
+      batch ``b``, positions ``[seq_len - valid_lens[b], seq_len)`` are
+      real and positions ``[0, seq_len - valid_lens[b])`` are padding;
+      the causal constraint additionally keeps ``col <= row`` within the
+      valid range. This is the decoder-embedding batch regime, where
+      last-token pooling is a cheap slice of the final row.
+
+    In both modes the kernel takes an extra ``valid_lens`` buffer of shape
+    ``(batch_size,)`` and applies the mask inside the QKV load path and the
+    online softmax update, so no explicit mask tensor broadcast or additive
+    bias is needed on the host side. Padding queries and keys/values are
+    zeroed at load time; masked ``(row, col)`` pairs are excluded from the
+    max/sum of the online softmax via a ``-inf`` slot. ``valid_len`` is the
+    per-batch real token count shared by Q and K/V; cross-attention with
+    independent Q/K validity is out of scope.
     """
     _, LOAD_VEC, group_size, bdx, num_warps, tile_x, tile_y, tile_z = 
_get_prefill_kernel_config(h_kv, h_q, d, dtype, target)
-    init_states, compute_s_gemm, _, compute_o_gemm, 
softmax_update_valid_length, *_ = _make_prefill_macros(tile_x, tile_y, tile_z, 
tile_y, bdx, num_warps, group_size)
+    (
+        init_states, compute_s_gemm, _, compute_o_gemm, 
softmax_update_valid_length,
+        _, _, softmax_update_causal_padded_left,
+    ) = _make_prefill_macros(tile_x, tile_y, tile_z, tile_y, bdx, num_warps, 
group_size)
+
+    softmax_update = (
+        softmax_update_valid_length
+        if mask_mode == "padded"
+        else softmax_update_causal_padded_left
+    )
+
+    def _q_row_valid(row, valid_len, qo_len):
+        # Q-row validity as a TIR expression; mask_mode is a closed-over Python value, so the prim_func is specialised per mode:
+        # "padded" keeps rows [0, valid_len); any other mode keeps the left-padded suffix [qo_len - valid_len, qo_len).
+        if mask_mode == "padded":
+            return tirx.And(row < qo_len, row < valid_len)
+        pad = qo_len - valid_len
+        return tirx.And(row < qo_len, row >= pad)
+
+    def _kv_col_valid(col, valid_len, kv_len):
+        # K/V column validity as a TIR expression: "padded" keeps [0, valid_len); any other mode keeps [kv_len - valid_len, kv_len).
+        if mask_mode == "padded":
+            return tirx.And(col < kv_len, col < valid_len)
+        pad = kv_len - valid_len
+        return tirx.And(col < kv_len, col >= pad)
 
     @T.prim_func
     def batch_sequence_prefill_kv_masked(  # pylint: disable=too-many-branches
@@ -551,7 +587,7 @@ def _attention_sequence_prefill_with_mask(h_kv, h_q, d, 
dtype, target: Target, s
 
                             init_states(m_smem, d_smem, O_local, ty, tx)
 
-                            # Load Q; padded rows are zeroed so they 
contribute nothing downstream.
+                            # Load Q; rows outside the valid range are zeroed 
so they contribute nothing downstream.
                             for li, lj in T.grid(tile_x, tile_y):
                                 with T.sblock("Q_load"):
                                     i, j = T.axis.remap("SS", [li, lj])
@@ -559,7 +595,7 @@ def _attention_sequence_prefill_with_mask(h_kv, h_q, d, 
dtype, target: Target, s
                                     T.writes()
                                     cur_L = (LH_start + i) // group_size
                                     cur_H_qo = by * group_size + (LH_start + 
i) % group_size
-                                    if tirx.And(cur_L < qo_len, cur_L < 
valid_len):
+                                    if _q_row_valid(cur_L, valid_len, qo_len):
                                         Q_smem[i, j] = q[b_idx, cur_L, 
cur_H_qo, j]
                                     else:
                                         Q_smem[i, j] = 0.0
@@ -574,7 +610,7 @@ def _attention_sequence_prefill_with_mask(h_kv, h_q, d, 
dtype, target: Target, s
                                         T.reads()
                                         T.writes()
                                         cur_L = L_kv_start + i
-                                        if tirx.And(cur_L < kv_len, cur_L < 
valid_len):
+                                        if _kv_col_valid(cur_L, valid_len, 
kv_len):
                                             K_smem[i, j] = k[b_idx, L_kv_base 
+ cur_L, by, j]
                                         else:
                                             K_smem[i, j] = 0.0
@@ -585,14 +621,14 @@ def _attention_sequence_prefill_with_mask(h_kv, h_q, d, 
dtype, target: Target, s
                                         T.reads()
                                         T.writes()
                                         cur_L = L_kv_start + i
-                                        if tirx.And(cur_L < kv_len, cur_L < 
valid_len):
+                                        if _kv_col_valid(cur_L, valid_len, 
kv_len):
                                             V_smem[i, j] = v[b_idx, L_kv_base 
+ cur_L, by, j]
                                         else:
                                             V_smem[i, j] = 0.0
                                 T.tvm_storage_sync("shared")
 
                                 compute_s_gemm(Q_smem, K_smem, S_local, 
S_smem, sm_scale)
-                                softmax_update_valid_length(S_smem, m_smem, 
d_smem, m_prev_smem, m_new, m_prev, d_new, ty, tx, LH_start, L_kv_start, 
valid_len, qo_len)
+                                softmax_update(S_smem, m_smem, d_smem, 
m_prev_smem, m_new, m_prev, d_new, ty, tx, LH_start, L_kv_start, valid_len, 
qo_len, kv_len)
                                 compute_o_gemm(S_smem, V_smem, O_local, 
m_prev_smem, m_smem)
 
                             # Store O
@@ -741,7 +777,7 @@ def _attention_prefill_ragged_cpu(h_kv, h_q, d_qk, d_v, 
dtype, rope_scaling: dic
 
 def _attention_prefill_ragged(h_kv, h_q, d_qk, d_v, dtype, rope_scaling: 
dict[str, Any], target: Target):
     NUM_BLKS, LOAD_VEC, group_size, bdx, num_warps, tile_x, tile_y, tile_z = 
_get_prefill_kernel_config(h_kv, h_q, d_qk, dtype, target)
-    init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _, 
advance_tile_batch, paged_store_output_lse = _make_prefill_macros(tile_x, 
tile_y, tile_z, d_v, bdx, num_warps, group_size)
+    init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _, 
advance_tile_batch, paged_store_output_lse, *_ = _make_prefill_macros(tile_x, 
tile_y, tile_z, d_v, bdx, num_warps, group_size)
 
     @T.prim_func
     def batch_prefill_ragged_kv(  # pylint: disable=too-many-branches
@@ -874,7 +910,7 @@ def _attention_prefill_ragged(h_kv, h_q, d_qk, d_v, dtype, 
rope_scaling: dict[st
 def _attention_prefill_mla(h_q, d_latent, d_rope, dtype, sliding_window: bool, 
target: Target, page_size: int = 16):
     d_qk = d_latent + d_rope
     NUM_BLKS, LOAD_VEC, group_size, bdx, num_warps, tile_x, tile_y, tile_z = 
_get_prefill_kernel_config(1, h_q, d_qk, dtype, target)
-    init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _, 
advance_tile_batch, paged_store_output_lse = _make_prefill_macros(tile_x, 
tile_y, tile_z, d_latent, bdx, num_warps, group_size)
+    init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _, 
advance_tile_batch, paged_store_output_lse, *_ = _make_prefill_macros(tile_x, 
tile_y, tile_z, d_latent, bdx, num_warps, group_size)
 
     global_symbol = "batch_prefill_paged_kv_mla"
     if sliding_window:
diff --git a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py 
b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
index b64ef459d8..51ea992268 100644
--- a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
+++ b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
@@ -16,22 +16,28 @@
 # under the License.
 """Focused correctness tests for ``_attention_sequence_prefill_with_mask``.
 
-The masked variant is the encoder-style counterpart of
-``_attention_sequence_prefill``: each sample in a padded batch carries its
-own ``valid_len`` and the kernel applies the padding mask inside the QKV
-load path and the online softmax update. These tests cover the four shape
-/ mask regimes that can break the kernel independently of any scheduler
+The masked variant supports two regimes selected by ``mask_mode``:
+
+* ``"padded"`` — encoder-style right-padded bidirectional attention.
+* ``"causal_padded_left"`` — decoder-embedding-style left-padded causal
+  attention. Real tokens occupy ``[seq_len - valid_len, seq_len)`` and
+  the causal constraint keeps ``col <= row`` within the valid range.
+
+In both regimes each sample in a padded batch carries its own
+``valid_len`` and the kernel applies the mask inside the QKV load path
+and the online softmax update. These tests cover the four shape / mask
+regimes that can break each kernel independently of any scheduler
 tuning:
 
 * ``valid_len == 0``       — entire batch row is padding
 * ``valid_len == seq_len`` — full-length row, must match the unmasked kernel
-* mixed ``valid_lens``     — typical encoder batch
+* mixed ``valid_lens``     — typical padded batch
 * grouped-query attention  — ``h_q > h_kv`` with ``group_size > 1``
 
-The reference is a float32 NumPy implementation of masked softmax attention
-restricted to the valid prefix, so the kernel is only compared on the
-unpadded positions (padded positions are intentionally free to contain
-arbitrary garbage).
+The references are float32 NumPy implementations of masked softmax
+attention restricted to the valid prefix/suffix, so the kernel is only
+compared on the unpadded positions (padded positions are intentionally
+free to contain arbitrary garbage).
 """
 # ruff: noqa: E501
 import math
@@ -44,7 +50,7 @@ from tvm.relax.frontend.nn.llm.kv_cache import 
_attention_sequence_prefill_with_
 
 
 def _reference_masked_attention(q, k, v, valid_lens, sm_scale):
-    """NumPy fp32 reference. Only the first ``valid_lens[b]`` rows are 
written."""
+    """Right-pad bidirectional reference. Only the first ``valid_lens[b]`` 
rows are written."""
     batch, seq_q, h_q, d = q.shape
     _, seq_kv, h_kv, _ = k.shape
     group_size = h_q // h_kv
@@ -69,7 +75,42 @@ def _reference_masked_attention(q, k, v, valid_lens, 
sm_scale):
     return out
 
 
-def _build_masked_prefill(h_kv, h_q, d, dtype, target):
+def _reference_masked_attention_causal_padded_left(q, k, v, valid_lens, sm_scale):
+    """Left-pad causal reference.
+
+    Real tokens occupy ``[seq_q - valid_len, seq_q)`` for queries and
+    ``[seq_kv - valid_len, seq_kv)`` for keys/values. Only the valid query
+    suffix rows are written; padded rows stay zeroed.
+    """
+    batch, seq_q, h_q, d = q.shape
+    _, seq_kv, h_kv, _ = k.shape
+    group_size = h_q // h_kv
+    out = np.zeros_like(q, dtype=np.float32)
+    q32 = q.astype(np.float32)
+    k32 = k.astype(np.float32)
+    v32 = v.astype(np.float32)
+    for b in range(batch):
+        L = int(valid_lens[b])
+        if L == 0:
+            continue
+        pad_q = seq_q - L
+        pad_kv = seq_kv - L
+        for h in range(h_q):
+            hk = h // group_size
+            qh = q32[b, pad_q:, h, :]  # [L, d]
+            kh = k32[b, pad_kv:, hk, :]  # [L, d]
+            vh = v32[b, pad_kv:, hk, :]  # [L, d]
+            s = (qh @ kh.T) * sm_scale  # [L, L]
+            # Causal on the LxL valid block: mask upper triangle to -inf.
+            s = s + np.triu(np.full((L, L), -np.inf), k=1)
+            m = s.max(axis=-1, keepdims=True)
+            e = np.exp(s - m)
+            p = e / e.sum(axis=-1, keepdims=True)
+            out[b, pad_q:, h, :] = p @ vh
+    return out
+
+
+def _build_masked_prefill(h_kv, h_q, d, dtype, target, mask_mode="padded"):
     sm_scale = 1.0 / math.sqrt(d)
     tir_func = _attention_sequence_prefill_with_mask(
         h_kv=h_kv,
@@ -78,6 +119,7 @@ def _build_masked_prefill(h_kv, h_q, d, dtype, target):
         dtype=dtype,
         target=target,
         sm_scale=sm_scale,
+        mask_mode=mask_mode,
     )
     mod = tvm.IRModule({"main": tir_func})
     return tvm.tirx.build(mod["main"], target=target), sm_scale
@@ -93,17 +135,21 @@ def _run_case(
     batch,
     seq,
     valid_lens,
+    seq_kv=None,
     dtype="float16",
     seed=0,
+    mask_mode="padded",
 ):
     target = tvm.target.Target(target)
-    built, sm_scale = _build_masked_prefill(h_kv, h_q, d, dtype, target)
+    built, sm_scale = _build_masked_prefill(h_kv, h_q, d, dtype, target, 
mask_mode=mask_mode)
 
+    if seq_kv is None:
+        seq_kv = seq
     np_dtype = {"float16": np.float16, "float32": np.float32}[dtype]
     rng = np.random.default_rng(seed)
     q_np = (rng.standard_normal((batch, seq, h_q, d)) * 0.1).astype(np_dtype)
-    k_np = (rng.standard_normal((batch, seq, h_kv, d)) * 0.1).astype(np_dtype)
-    v_np = (rng.standard_normal((batch, seq, h_kv, d)) * 0.1).astype(np_dtype)
+    k_np = (rng.standard_normal((batch, seq_kv, h_kv, d)) * 
0.1).astype(np_dtype)
+    v_np = (rng.standard_normal((batch, seq_kv, h_kv, d)) * 
0.1).astype(np_dtype)
     valid_np = np.asarray(valid_lens, dtype=np.int32)
     out_np = np.zeros((batch, seq, h_q, d), dtype=np_dtype)
     lse_np = np.zeros((batch, seq, h_q), dtype=np_dtype)
@@ -118,7 +164,10 @@ def _run_case(
     built.main(q_nd, k_nd, v_nd, valid_nd, out_nd, lse_nd)
 
     got = out_nd.numpy().astype(np.float32)
-    ref = _reference_masked_attention(q_np, k_np, v_np, valid_np, sm_scale)
+    if mask_mode == "padded":
+        ref = _reference_masked_attention(q_np, k_np, v_np, valid_np, sm_scale)
+    else:
+        ref = _reference_masked_attention_causal_padded_left(q_np, k_np, v_np, 
valid_np, sm_scale)
 
     # Only compare valid rows. Padding rows are undefined by design.
     rtol, atol = (2e-2, 2e-2) if dtype == "float16" else (1e-4, 1e-4)
@@ -126,7 +175,11 @@ def _run_case(
         L = int(valid_np[b])
         if L == 0:
             continue
-        np.testing.assert_allclose(got[b, :L], ref[b, :L], rtol=rtol, 
atol=atol)
+        if mask_mode == "padded":
+            np.testing.assert_allclose(got[b, :L], ref[b, :L], rtol=rtol, 
atol=atol)
+        else:
+            pad_q = seq - L
+            np.testing.assert_allclose(got[b, pad_q:], ref[b, pad_q:], 
rtol=rtol, atol=atol)
 
 
 @tvm.testing.requires_gpu
@@ -193,5 +246,91 @@ def test_valid_len_mixed_gqa(target, dev):
     )
 
 
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_causal_padded_left_valid_len_zero(target, dev):
+    """Causal left-pad: all samples are fully padded."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=4,
+        h_q=4,
+        d=64,
+        batch=2,
+        seq=16,
+        valid_lens=[0, 0],
+        mask_mode="causal_padded_left",
+    )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_causal_padded_left_valid_len_full(target, dev):
+    """Causal left-pad: all samples are fully valid — degenerates to plain causal attention."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=4,
+        h_q=4,
+        d=64,
+        batch=2,
+        seq=32,
+        valid_lens=[32, 32],
+        mask_mode="causal_padded_left",
+    )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_causal_padded_left_valid_len_mixed(target, dev):
+    """Causal left-pad: typical decoder-embedding batch with mixed lengths."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=4,
+        h_q=4,
+        d=64,
+        batch=4,
+        seq=64,
+        valid_lens=[10, 64, 5, 33],
+        mask_mode="causal_padded_left",
+    )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_causal_padded_left_valid_len_mixed_gqa(target, dev):
+    """Causal left-pad: grouped-query attention with mixed lengths."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=2,
+        h_q=4,
+        d=64,
+        batch=3,
+        seq=32,
+        valid_lens=[8, 32, 17],
+        mask_mode="causal_padded_left",
+    )
+
+
+@tvm.testing.requires_gpu
+@tvm.testing.parametrize_targets("cuda", "metal")
+def test_causal_padded_left_qo_len_differs_from_kv_len(target, dev):
+    """Causal left-pad: Q and K/V may have different padded lengths."""
+    _run_case(
+        target=target,
+        dev=dev,
+        h_kv=2,
+        h_q=4,
+        d=64,
+        batch=3,
+        seq=32,
+        seq_kv=48,
+        valid_lens=[8, 32, 17],
+        mask_mode="causal_padded_left",
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to