This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 82293c8c11 [Relax][Frontend][KVCache] Extend masked sequence prefill
to causal left-padding (#19431)
82293c8c11 is described below
commit 82293c8c1157743ef69549cb7bf23f52b9f342be
Author: Xijing Wang <[email protected]>
AuthorDate: Fri Apr 24 22:52:34 2026 -0400
[Relax][Frontend][KVCache] Extend masked sequence prefill to causal
left-padding (#19431)
This PR extends `_attention_sequence_prefill_with_mask` to support a
second mask regime for decoder-style embedding workloads.
### Summary
- Keep the existing right-padded bidirectional behavior as
`mask_mode="padded"`.
- Add `mask_mode="causal_padded_left"` for left-padded causal sequence
prefill.
- Add a `softmax_update_causal_padded_left` macro that applies the
left-padding + causal mask inside the online softmax update.
- Add tests for causal left-padding covering zero, full, and mixed valid
lengths, grouped-query attention (GQA), and Q vs. K/V sequence lengths
that differ.
### Motivation
This is a TVM-side kernel dependency for the first-class embedding
serving work tracked in mlc-ai/mlc-llm#3451.
The existing masked sequence prefill kernel supports encoder-style
batches where real tokens occupy the valid prefix `[0, valid_len)` and
padding is on the right.
Decoder-style embedding batches (for example, decoder-only embedding
models) commonly left-pad variable-length inputs so that the final real
token / EOS lands in the same final column across the batch. This lets
last-token pooling read `output[:, -1, :]`, while each row still requires
causal masking within its valid suffix.
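For each batch row:
- `mask_mode="padded"`: real tokens are `[0, valid_len)`; attention is
bidirectional within that prefix.
- `mask_mode="causal_padded_left"`: real query tokens are
`[qo_len - valid_len, qo_len)` and real key/value tokens are
`[kv_len - valid_len, kv_len)`; the causal mask keeps
`col <= row + (kv_len - qo_len)`, which reduces to `col <= row` when
`qo_len == kv_len` (real tokens are then `[seq_len - valid_len, seq_len)`).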
### Testing
- `git diff --check`
- Attempted:
`python -m pytest -q
tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py -k
'causal_padded_left or valid_len_mixed'`
---
python/tvm/relax/frontend/nn/llm/_kernel_common.py | 53 ++++++-
.../tvm/relax/frontend/nn/llm/_prefill_kernels.py | 86 +++++++---
...test_frontend_nn_llm_sequence_prefill_masked.py | 173 +++++++++++++++++++--
3 files changed, 268 insertions(+), 44 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/_kernel_common.py
b/python/tvm/relax/frontend/nn/llm/_kernel_common.py
index 3f881daf70..e7a526cf19 100644
--- a/python/tvm/relax/frontend/nn/llm/_kernel_common.py
+++ b/python/tvm/relax/frontend/nn/llm/_kernel_common.py
@@ -347,7 +347,7 @@ def _make_prefill_macros(tile_x, tile_y, tile_z, tile_o,
bdx, num_warps, group_s
S_smem: T.Buffer, m_smem: T.Buffer, d_smem: T.Buffer, m_prev_smem:
T.Buffer,
m_new: T.Buffer, m_prev: T.Buffer, d_new: T.Buffer,
ty: T.int32, tx: T.int32, LH_start: T.int32, L_kv_start: T.int32,
- valid_len: T.int32, qo_len: T.int32,
+ valid_len: T.int32, qo_len: T.int32, kv_len: T.int32,
):
# Same three-phase online softmax as softmax_update_causal but with a
# per-batch right-padding mask in place of causal masking.
@@ -383,7 +383,56 @@ def _make_prefill_macros(tile_x, tile_y, tile_z, tile_o,
bdx, num_warps, group_s
m_prev_smem[row] = m_prev[i]
T.tvm_storage_sync("shared")
- return init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm,
softmax_update_valid_length, advance_tile_batch, paged_store_output_lse
+ @T.macro
+ def softmax_update_causal_padded_left(
+ S_smem: T.Buffer, m_smem: T.Buffer, d_smem: T.Buffer, m_prev_smem:
T.Buffer,
+ m_new: T.Buffer, m_prev: T.Buffer, d_new: T.Buffer,
+ ty: T.int32, tx: T.int32, LH_start: T.int32, L_kv_start: T.int32,
+ valid_len: T.int32, qo_len: T.int32, kv_len: T.int32,
+ ):
+ # Three-phase online softmax with left-padding + causal mask. Real
+ # queries occupy [qo_len - valid_len, qo_len); real keys occupy
+ # [kv_len - valid_len, kv_len). Causal keeps
+ # col <= row + (kv_len - qo_len) within those valid suffixes.
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
+ if row < tile_x:
+ with T.sblock("update1"):
+ m_prev[i] = m_smem[row]
+ m_new[i] = m_smem[row]
+ row_: T.int32 = (LH_start + row) // group_size
+ pad_q: T.int32 = qo_len - valid_len
+ pad_kv: T.int32 = kv_len - valid_len
+ for j in T.serial(tile_z):
+ col_: T.int32 = L_kv_start + j
+ if tirx.And(tirx.And(row_ < qo_len, row_ >= pad_q),
tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
+ m_new[i] = T.max(m_new[i], S_smem[row, j])
+ d_new[i] = d_smem[row] * T.exp2(m_prev[i] - m_new[i])
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
+ with T.sblock("update"):
+ for j in T.serial(tile_z):
+ if row < tile_x:
+ row_: T.int32 = (LH_start + row) // group_size
+ pad_q: T.int32 = qo_len - valid_len
+ pad_kv: T.int32 = kv_len - valid_len
+ col_: T.int32 = L_kv_start + j
+ if tirx.And(tirx.And(row_ < qo_len, row_ >= pad_q),
tirx.And(col_ >= pad_kv, col_ < kv_len - qo_len + row_ + 1)):
+ S_smem[row, j] = T.exp2(S_smem[row, j] - m_new[i])
+ else:
+ S_smem[row, j] = T.exp2(-5e4 - m_new[i])
+ for i in T.serial(T.ceildiv(tile_x, bdx * num_warps)):
+ row: T.int32 = i * bdx * num_warps + ty * bdx + tx
+ if row < tile_x:
+ with T.sblock("update"):
+ for j in T.serial(tile_z):
+ d_new[i] += S_smem[row, j]
+ m_smem[row] = m_new[i]
+ d_smem[row] = d_new[i]
+ m_prev_smem[row] = m_prev[i]
+ T.tvm_storage_sync("shared")
+
+ return init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm,
softmax_update_valid_length, advance_tile_batch, paged_store_output_lse,
softmax_update_causal_padded_left
def _get_prefill_kernel_config(h_kv, h_q, d, dtype, target: Target):
diff --git a/python/tvm/relax/frontend/nn/llm/_prefill_kernels.py
b/python/tvm/relax/frontend/nn/llm/_prefill_kernels.py
index eba7e21133..2068db5bb4 100644
--- a/python/tvm/relax/frontend/nn/llm/_prefill_kernels.py
+++ b/python/tvm/relax/frontend/nn/llm/_prefill_kernels.py
@@ -27,7 +27,7 @@ K/V loading path that is specific to its storage layout.
# pylint:
disable=too-many-statements,too-many-arguments,invalid-name,line-too-long
import math
-from typing import Any
+from typing import Any, Literal
import tvm
from tvm import tirx
@@ -212,7 +212,7 @@ def _attention_prefill(h_kv, h_q, d, dtype, sliding_window:
bool, rope_scaling:
if sliding_window:
global_symbol += "_sliding_window"
- init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _,
advance_tile_batch, paged_store_output_lse = _make_prefill_macros(tile_x,
tile_y, tile_z, tile_y, bdx, num_warps, group_size)
+ init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _,
advance_tile_batch, paged_store_output_lse, *_ = _make_prefill_macros(tile_x,
tile_y, tile_z, tile_y, bdx, num_warps, group_size)
# pylint: disable=too-many-branches
@T.prim_func
@@ -489,24 +489,60 @@ def _attention_sequence_prefill(h_kv, h_q, d, dtype,
target: Target, causal=0, s
-def _attention_sequence_prefill_with_mask(h_kv, h_q, d, dtype, target: Target,
sm_scale=1.0):
- """Tiled sequence prefill kernel with a per-batch right-padding mask.
-
- This is the counterpart of :func:`_attention_sequence_prefill` for batched
- encoder-style inputs where each sample in the batch is padded to a common
- ``seq_len`` but only the first ``valid_lens[b]`` tokens carry real content.
- The kernel takes an extra ``valid_lens`` buffer of shape ``(batch_size,)``
- and applies the padding mask inside the QKV load path and the online
- softmax update, so no explicit mask tensor broadcast or additive bias is
- needed on the host side.
-
- Semantics: for batch ``b``, positions ``[0, valid_lens[b])`` are real and
- positions ``[valid_lens[b], seq_len)`` are padding. Padding queries and
- keys/values are zeroed at load time; padded ``(row, col)`` pairs are
- excluded from the max/sum of the online softmax via a ``-inf`` slot.
+def _attention_sequence_prefill_with_mask(
+ h_kv, h_q, d, dtype, target: Target, sm_scale=1.0, *,
+ mask_mode: Literal["padded", "causal_padded_left"] = "padded",
+):
+ """Tiled sequence prefill kernel with a per-batch padding mask.
+
+ Supports two mask regimes selected by ``mask_mode``:
+
+ * ``"padded"`` (default) — bidirectional attention with right-padding.
+ For batch ``b``, positions ``[0, valid_lens[b])`` are real and
+ positions ``[valid_lens[b], seq_len)`` are padding. This is the
+ encoder-style batch regime.
+ * ``"causal_padded_left"`` — causal attention with left-padding. For
+ batch ``b``, positions ``[seq_len - valid_lens[b], seq_len)`` are
+ real and positions ``[0, seq_len - valid_lens[b])`` are padding;
+ the causal constraint additionally keeps ``col <= row`` within the
+ valid range. This is the decoder-embedding batch regime, where
+ last-token pooling is a cheap slice of the final row.
+
+ In both modes the kernel takes an extra ``valid_lens`` buffer of shape
+ ``(batch_size,)`` and applies the mask inside the QKV load path and the
+ online softmax update, so no explicit mask tensor broadcast or additive
+ bias is needed on the host side. Padding queries and keys/values are
+ zeroed at load time; masked ``(row, col)`` pairs are excluded from the
+ max/sum of the online softmax via a ``-inf`` slot. ``valid_len`` is the
+ per-batch real token count shared by Q and K/V; cross-attention with
+ independent Q/K validity is out of scope.
"""
_, LOAD_VEC, group_size, bdx, num_warps, tile_x, tile_y, tile_z =
_get_prefill_kernel_config(h_kv, h_q, d, dtype, target)
- init_states, compute_s_gemm, _, compute_o_gemm,
softmax_update_valid_length, *_ = _make_prefill_macros(tile_x, tile_y, tile_z,
tile_y, bdx, num_warps, group_size)
+ (
+ init_states, compute_s_gemm, _, compute_o_gemm,
softmax_update_valid_length,
+ _, _, softmax_update_causal_padded_left,
+ ) = _make_prefill_macros(tile_x, tile_y, tile_z, tile_y, bdx, num_warps,
group_size)
+
+ softmax_update = (
+ softmax_update_valid_length
+ if mask_mode == "padded"
+ else softmax_update_causal_padded_left
+ )
+
+ def _q_row_valid(row, valid_len, qo_len):
+ # Row-validity predicate for Q load (TIR expression); mask_mode is
+ # captured at closure time so the prim_func body stays specialised.
+ if mask_mode == "padded":
+ return tirx.And(row < qo_len, row < valid_len)
+ pad = qo_len - valid_len
+ return tirx.And(row < qo_len, row >= pad)
+
+ def _kv_col_valid(col, valid_len, kv_len):
+ # Column-validity predicate for K/V load (TIR expression).
+ if mask_mode == "padded":
+ return tirx.And(col < kv_len, col < valid_len)
+ pad = kv_len - valid_len
+ return tirx.And(col < kv_len, col >= pad)
@T.prim_func
def batch_sequence_prefill_kv_masked( # pylint: disable=too-many-branches
@@ -551,7 +587,7 @@ def _attention_sequence_prefill_with_mask(h_kv, h_q, d,
dtype, target: Target, s
init_states(m_smem, d_smem, O_local, ty, tx)
- # Load Q; padded rows are zeroed so they
contribute nothing downstream.
+ # Load Q; rows outside the valid range are zeroed
so they contribute nothing downstream.
for li, lj in T.grid(tile_x, tile_y):
with T.sblock("Q_load"):
i, j = T.axis.remap("SS", [li, lj])
@@ -559,7 +595,7 @@ def _attention_sequence_prefill_with_mask(h_kv, h_q, d,
dtype, target: Target, s
T.writes()
cur_L = (LH_start + i) // group_size
cur_H_qo = by * group_size + (LH_start +
i) % group_size
- if tirx.And(cur_L < qo_len, cur_L <
valid_len):
+ if _q_row_valid(cur_L, valid_len, qo_len):
Q_smem[i, j] = q[b_idx, cur_L,
cur_H_qo, j]
else:
Q_smem[i, j] = 0.0
@@ -574,7 +610,7 @@ def _attention_sequence_prefill_with_mask(h_kv, h_q, d,
dtype, target: Target, s
T.reads()
T.writes()
cur_L = L_kv_start + i
- if tirx.And(cur_L < kv_len, cur_L <
valid_len):
+ if _kv_col_valid(cur_L, valid_len,
kv_len):
K_smem[i, j] = k[b_idx, L_kv_base
+ cur_L, by, j]
else:
K_smem[i, j] = 0.0
@@ -585,14 +621,14 @@ def _attention_sequence_prefill_with_mask(h_kv, h_q, d,
dtype, target: Target, s
T.reads()
T.writes()
cur_L = L_kv_start + i
- if tirx.And(cur_L < kv_len, cur_L <
valid_len):
+ if _kv_col_valid(cur_L, valid_len,
kv_len):
V_smem[i, j] = v[b_idx, L_kv_base
+ cur_L, by, j]
else:
V_smem[i, j] = 0.0
T.tvm_storage_sync("shared")
compute_s_gemm(Q_smem, K_smem, S_local,
S_smem, sm_scale)
- softmax_update_valid_length(S_smem, m_smem,
d_smem, m_prev_smem, m_new, m_prev, d_new, ty, tx, LH_start, L_kv_start,
valid_len, qo_len)
+ softmax_update(S_smem, m_smem, d_smem,
m_prev_smem, m_new, m_prev, d_new, ty, tx, LH_start, L_kv_start, valid_len,
qo_len, kv_len)
compute_o_gemm(S_smem, V_smem, O_local,
m_prev_smem, m_smem)
# Store O
@@ -741,7 +777,7 @@ def _attention_prefill_ragged_cpu(h_kv, h_q, d_qk, d_v,
dtype, rope_scaling: dic
def _attention_prefill_ragged(h_kv, h_q, d_qk, d_v, dtype, rope_scaling:
dict[str, Any], target: Target):
NUM_BLKS, LOAD_VEC, group_size, bdx, num_warps, tile_x, tile_y, tile_z =
_get_prefill_kernel_config(h_kv, h_q, d_qk, dtype, target)
- init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _,
advance_tile_batch, paged_store_output_lse = _make_prefill_macros(tile_x,
tile_y, tile_z, d_v, bdx, num_warps, group_size)
+ init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _,
advance_tile_batch, paged_store_output_lse, *_ = _make_prefill_macros(tile_x,
tile_y, tile_z, d_v, bdx, num_warps, group_size)
@T.prim_func
def batch_prefill_ragged_kv( # pylint: disable=too-many-branches
@@ -874,7 +910,7 @@ def _attention_prefill_ragged(h_kv, h_q, d_qk, d_v, dtype,
rope_scaling: dict[st
def _attention_prefill_mla(h_q, d_latent, d_rope, dtype, sliding_window: bool,
target: Target, page_size: int = 16):
d_qk = d_latent + d_rope
NUM_BLKS, LOAD_VEC, group_size, bdx, num_warps, tile_x, tile_y, tile_z =
_get_prefill_kernel_config(1, h_q, d_qk, dtype, target)
- init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _,
advance_tile_batch, paged_store_output_lse = _make_prefill_macros(tile_x,
tile_y, tile_z, d_latent, bdx, num_warps, group_size)
+ init_states, compute_s_gemm, softmax_update_causal, compute_o_gemm, _,
advance_tile_batch, paged_store_output_lse, *_ = _make_prefill_macros(tile_x,
tile_y, tile_z, d_latent, bdx, num_warps, group_size)
global_symbol = "batch_prefill_paged_kv_mla"
if sliding_window:
diff --git a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
index b64ef459d8..51ea992268 100644
--- a/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
+++ b/tests/python/relax/test_frontend_nn_llm_sequence_prefill_masked.py
@@ -16,22 +16,28 @@
# under the License.
"""Focused correctness tests for ``_attention_sequence_prefill_with_mask``.
-The masked variant is the encoder-style counterpart of
-``_attention_sequence_prefill``: each sample in a padded batch carries its
-own ``valid_len`` and the kernel applies the padding mask inside the QKV
-load path and the online softmax update. These tests cover the four shape
-/ mask regimes that can break the kernel independently of any scheduler
+The masked variant supports two regimes selected by ``mask_mode``:
+
+* ``"padded"`` — encoder-style right-padded bidirectional attention.
+* ``"causal_padded_left"`` — decoder-embedding-style left-padded causal
+ attention. Real tokens occupy ``[seq_len - valid_len, seq_len)`` and
+ the causal constraint keeps ``col <= row`` within the valid range.
+
+In both regimes each sample in a padded batch carries its own
+``valid_len`` and the kernel applies the mask inside the QKV load path
+and the online softmax update. These tests cover the four shape / mask
+regimes that can break each kernel independently of any scheduler
tuning:
* ``valid_len == 0`` — entire batch row is padding
* ``valid_len == seq_len`` — full-length row, must match the unmasked kernel
-* mixed ``valid_lens`` — typical encoder batch
+* mixed ``valid_lens`` — typical padded batch
* grouped-query attention — ``h_q > h_kv`` with ``group_size > 1``
-The reference is a float32 NumPy implementation of masked softmax attention
-restricted to the valid prefix, so the kernel is only compared on the
-unpadded positions (padded positions are intentionally free to contain
-arbitrary garbage).
+The references are float32 NumPy implementations of masked softmax
+attention restricted to the valid prefix/suffix, so the kernel is only
+compared on the unpadded positions (padded positions are intentionally
+free to contain arbitrary garbage).
"""
# ruff: noqa: E501
import math
@@ -44,7 +50,7 @@ from tvm.relax.frontend.nn.llm.kv_cache import
_attention_sequence_prefill_with_
def _reference_masked_attention(q, k, v, valid_lens, sm_scale):
- """NumPy fp32 reference. Only the first ``valid_lens[b]`` rows are
written."""
+ """Right-pad bidirectional reference. Only the first ``valid_lens[b]``
rows are written."""
batch, seq_q, h_q, d = q.shape
_, seq_kv, h_kv, _ = k.shape
group_size = h_q // h_kv
@@ -69,7 +75,42 @@ def _reference_masked_attention(q, k, v, valid_lens,
sm_scale):
return out
-def _build_masked_prefill(h_kv, h_q, d, dtype, target):
+def _reference_masked_attention_causal_padded_left(q, k, v, valid_lens,
sm_scale):
+ """Left-pad causal reference.
+
+ Real tokens occupy ``[seq_q - valid_len, seq_q)`` for queries and
+ ``[seq_kv - valid_len, seq_kv)`` for keys/values. Only the valid query
+ suffix rows are written; padded rows stay zeroed.
+ """
+ batch, seq_q, h_q, d = q.shape
+ _, seq_kv, h_kv, _ = k.shape
+ group_size = h_q // h_kv
+ out = np.zeros_like(q, dtype=np.float32)
+ q32 = q.astype(np.float32)
+ k32 = k.astype(np.float32)
+ v32 = v.astype(np.float32)
+ for b in range(batch):
+ L = int(valid_lens[b])
+ if L == 0:
+ continue
+ pad_q = seq_q - L
+ pad_kv = seq_kv - L
+ for h in range(h_q):
+ hk = h // group_size
+ qh = q32[b, pad_q:, h, :] # [L, d]
+ kh = k32[b, pad_kv:, hk, :] # [L, d]
+ vh = v32[b, pad_kv:, hk, :] # [L, d]
+ s = (qh @ kh.T) * sm_scale # [L, L]
+ # Causal on the LxL valid block: mask upper triangle to -inf.
+ s = s + np.triu(np.full((L, L), -np.inf), k=1)
+ m = s.max(axis=-1, keepdims=True)
+ e = np.exp(s - m)
+ p = e / e.sum(axis=-1, keepdims=True)
+ out[b, pad_q:, h, :] = p @ vh
+ return out
+
+
+def _build_masked_prefill(h_kv, h_q, d, dtype, target, mask_mode="padded"):
sm_scale = 1.0 / math.sqrt(d)
tir_func = _attention_sequence_prefill_with_mask(
h_kv=h_kv,
@@ -78,6 +119,7 @@ def _build_masked_prefill(h_kv, h_q, d, dtype, target):
dtype=dtype,
target=target,
sm_scale=sm_scale,
+ mask_mode=mask_mode,
)
mod = tvm.IRModule({"main": tir_func})
return tvm.tirx.build(mod["main"], target=target), sm_scale
@@ -93,17 +135,21 @@ def _run_case(
batch,
seq,
valid_lens,
+ seq_kv=None,
dtype="float16",
seed=0,
+ mask_mode="padded",
):
target = tvm.target.Target(target)
- built, sm_scale = _build_masked_prefill(h_kv, h_q, d, dtype, target)
+ built, sm_scale = _build_masked_prefill(h_kv, h_q, d, dtype, target,
mask_mode=mask_mode)
+ if seq_kv is None:
+ seq_kv = seq
np_dtype = {"float16": np.float16, "float32": np.float32}[dtype]
rng = np.random.default_rng(seed)
q_np = (rng.standard_normal((batch, seq, h_q, d)) * 0.1).astype(np_dtype)
- k_np = (rng.standard_normal((batch, seq, h_kv, d)) * 0.1).astype(np_dtype)
- v_np = (rng.standard_normal((batch, seq, h_kv, d)) * 0.1).astype(np_dtype)
+ k_np = (rng.standard_normal((batch, seq_kv, h_kv, d)) *
0.1).astype(np_dtype)
+ v_np = (rng.standard_normal((batch, seq_kv, h_kv, d)) *
0.1).astype(np_dtype)
valid_np = np.asarray(valid_lens, dtype=np.int32)
out_np = np.zeros((batch, seq, h_q, d), dtype=np_dtype)
lse_np = np.zeros((batch, seq, h_q), dtype=np_dtype)
@@ -118,7 +164,10 @@ def _run_case(
built.main(q_nd, k_nd, v_nd, valid_nd, out_nd, lse_nd)
got = out_nd.numpy().astype(np.float32)
- ref = _reference_masked_attention(q_np, k_np, v_np, valid_np, sm_scale)
+ if mask_mode == "padded":
+ ref = _reference_masked_attention(q_np, k_np, v_np, valid_np, sm_scale)
+ else:
+ ref = _reference_masked_attention_causal_padded_left(q_np, k_np, v_np,
valid_np, sm_scale)
# Only compare valid rows. Padding rows are undefined by design.
rtol, atol = (2e-2, 2e-2) if dtype == "float16" else (1e-4, 1e-4)
@@ -126,7 +175,11 @@ def _run_case(
L = int(valid_np[b])
if L == 0:
continue
- np.testing.assert_allclose(got[b, :L], ref[b, :L], rtol=rtol,
atol=atol)
+ if mask_mode == "padded":
+ np.testing.assert_allclose(got[b, :L], ref[b, :L], rtol=rtol,
atol=atol)
+ else:
+ pad_q = seq - L
+ np.testing.assert_allclose(got[b, pad_q:], ref[b, pad_q:],
rtol=rtol, atol=atol)
@tvm.testing.requires_gpu
@@ -193,5 +246,91 @@ def test_valid_len_mixed_gqa(target, dev):
)
[email protected]_gpu
[email protected]_targets("cuda", "metal")
+def test_causal_padded_left_valid_len_zero(target, dev):
+ """Causal left-pad: all samples are fully padded."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=4,
+ h_q=4,
+ d=64,
+ batch=2,
+ seq=16,
+ valid_lens=[0, 0],
+ mask_mode="causal_padded_left",
+ )
+
+
[email protected]_gpu
[email protected]_targets("cuda", "metal")
+def test_causal_padded_left_valid_len_full(target, dev):
+ """Causal left-pad: all samples are fully valid — degenerates to plain
causal attention."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=4,
+ h_q=4,
+ d=64,
+ batch=2,
+ seq=32,
+ valid_lens=[32, 32],
+ mask_mode="causal_padded_left",
+ )
+
+
[email protected]_gpu
[email protected]_targets("cuda", "metal")
+def test_causal_padded_left_valid_len_mixed(target, dev):
+ """Causal left-pad: typical decoder-embedding batch with mixed lengths."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=4,
+ h_q=4,
+ d=64,
+ batch=4,
+ seq=64,
+ valid_lens=[10, 64, 5, 33],
+ mask_mode="causal_padded_left",
+ )
+
+
[email protected]_gpu
[email protected]_targets("cuda", "metal")
+def test_causal_padded_left_valid_len_mixed_gqa(target, dev):
+ """Causal left-pad: grouped-query attention with mixed lengths."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=2,
+ h_q=4,
+ d=64,
+ batch=3,
+ seq=32,
+ valid_lens=[8, 32, 17],
+ mask_mode="causal_padded_left",
+ )
+
+
[email protected]_gpu
[email protected]_targets("cuda", "metal")
+def test_causal_padded_left_qo_len_differs_from_kv_len(target, dev):
+ """Causal left-pad: Q and K/V may have different padded lengths."""
+ _run_case(
+ target=target,
+ dev=dev,
+ h_kv=2,
+ h_q=4,
+ d=64,
+ batch=3,
+ seq=32,
+ seq_kv=48,
+ valid_lens=[8, 32, 17],
+ mask_mode="causal_padded_left",
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()