gemini-code-assist[bot] commented on code in PR #18630:
URL: https://github.com/apache/tvm/pull/18630#discussion_r2657496028
##########
src/runtime/vm/paged_kv_cache.cc:
##########
@@ -1408,8 +1417,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     // The auxiliary data structure on device must have been synchronized.
     ICHECK(!dirty_aux_data_device_);
-    if (attn_kind == AttnKind::kMHA) {
-      MHASelfAttnInternal(q_data, k_data, v_data, o_data, lse_data, sm_scale);
+    if (attn_kind == AttnKind::kMHA || attn_kind == AttnKind::kMHASliding) {
+      MHASelfAttnInternal(layer_id, q_data, k_data, v_data, o_data, lse_data, sm_scale);
Review Comment:

There appears to be a bug here. `MHASelfAttnInternal` expects a
`local_layer_id`, but it's being called with the global `layer_id`. The
`local_layer_id` is computed on line 1387 but is not used in this call. This
will lead to incorrect indexing of `attn_kinds_` inside `MHASelfAttnInternal`
and potentially other issues. You should pass `local_layer_id` instead.
```suggestion
      MHASelfAttnInternal(local_layer_id, q_data, k_data, v_data, o_data, lse_data, sm_scale);
```
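To make the mismatch concrete, here is a minimal standalone sketch (plain Python, not the actual TVM implementation; `layer_id_begin_offset` and the layer layout are made up for illustration). When the cache manages only a slice of the model's layers, per-layer metadata such as `attn_kinds_` is indexed by the local layer id, so passing the global `layer_id` reads the wrong entry or runs out of bounds:
```python
# Hypothetical sketch: a cache managing global layers [4, 8) keeps
# per-layer metadata for its 4 local layers only.
layer_id_begin_offset = 4                          # assumed offset, for illustration
attn_kinds = ["mha", "mha_sliding", "mha", "mha"]  # indexed by LOCAL layer id

def self_attn_kind(local_layer_id: int) -> str:
    # Stands in for MHASelfAttnInternal's expectation of a *local* index.
    return attn_kinds[local_layer_id]

global_layer_id = 5
local_layer_id = global_layer_id - layer_id_begin_offset

print(self_attn_kind(local_layer_id))       # "mha_sliding" -- the intended entry
try:
    print(self_attn_kind(global_layer_id))  # wrong entry or out of range -- the reported bug
except IndexError as err:
    print("global id is out of range:", err)
```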
##########
src/runtime/vm/paged_kv_cache.cc:
##########
@@ -766,6 +769,8 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
    // introduce more sink. Therefore, we update the given attn sink size.
    it->second.last_block_attn_sink_size = std::max(attn_sink_size - prefix_length, 0);
    it->second.sliding_window_size = sliding_window_size;
+    if (sliding_window_size_ == -1)
+      sliding_window_size_ = sliding_window_size;
Review Comment:

This logic to set `sliding_window_size_` is a bit fragile. If
`EnableSlidingWindowForSeq` is called for different sequences with different
`sliding_window_size` values, only the first one will take effect, and
subsequent calls with different sizes will be silently ignored. This could lead
to unexpected behavior. To make this assumption explicit and prevent silent
errors, I'd suggest adding a check to ensure that all calls use a consistent
`sliding_window_size`.
```suggestion
    if (sliding_window_size_ == -1) {
      sliding_window_size_ = sliding_window_size;
    } else {
      ICHECK_EQ(sliding_window_size_, sliding_window_size)
          << "Inconsistent sliding window sizes are not supported. Previously got "
          << sliding_window_size_ << ", but now got " << sliding_window_size;
    }
```
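As a standalone illustration of the suggested set-once-then-check pattern (plain Python, not the TVM API; the class and method names here are hypothetical), the first call fixes the cache-wide value and any later call with a different size fails loudly instead of being silently ignored:
```python
class KVCacheSketch:
    """Toy stand-in for the cache-wide sliding-window bookkeeping (hypothetical)."""

    def __init__(self) -> None:
        self.sliding_window_size = -1  # -1 means "not configured yet"

    def enable_sliding_window_for_seq(self, seq_id: int, sliding_window_size: int) -> None:
        if self.sliding_window_size == -1:
            self.sliding_window_size = sliding_window_size
        else:
            # Fail loudly on a conflicting size, mirroring the ICHECK_EQ above.
            assert self.sliding_window_size == sliding_window_size, (
                f"Inconsistent sliding window sizes: previously "
                f"{self.sliding_window_size}, now {sliding_window_size}"
            )

cache = KVCacheSketch()
cache.enable_sliding_window_for_seq(0, 4096)
cache.enable_sliding_window_for_seq(1, 4096)    # fine: consistent size
# cache.enable_sliding_window_for_seq(2, 2048)  # would raise AssertionError
```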
##########
python/tvm/relax/frontend/nn/llm/kv_cache.py:
##########
@@ -860,10 +860,15 @@ def _var(dtype):
     return T.alloc_buffer((1,), dtype, scope="local")

-def _causal_mask(causal, row, col, kv_len, qo_len):
+def _causal_mask(causal, row, col, kv_len, qo_len, sliding_window_size=-1):
+    lower_bound_condition = T.if_then_else(
+        sliding_window_size > 0,
+        col >= kv_len - qo_len + row - sliding_window_size,
Review Comment:

There might be an off-by-one error in the lower bound for the sliding window. Writing the query token's absolute position as `row_global = kv_len - qo_len + row`, the condition `col >= row_global - sliding_window_size`, combined with the causal upper bound, admits columns `row_global - sliding_window_size` through `row_global`, i.e. `sliding_window_size + 1` tokens. If the intent is a window of exactly `sliding_window_size` tokens, the condition should be `col >= kv_len - qo_len + row - sliding_window_size + 1`. Many models, such as Mistral, define the window inclusively and end up attending to `sliding_window_size + 1` tokens, so the current form may well be correct; it's worth confirming that it matches the behavior expected by the models this will be used with.
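A quick standalone check of the counting (plain Python mirroring the mask arithmetic; `kv_len`, `qo_len`, `row`, and the window value are arbitrary example numbers, not taken from the PR):
```python
def allowed_cols(row, kv_len, qo_len, window, plus_one=False):
    """Columns admitted by the causal upper bound plus the sliding-window lower bound."""
    row_global = kv_len - qo_len + row             # absolute position of the query token
    lower = row_global - window + (1 if plus_one else 0)
    return [col for col in range(kv_len) if lower <= col <= row_global]

kv_len, qo_len, row, window = 16, 1, 0, 4          # example numbers only

as_written = allowed_cols(row, kv_len, qo_len, window)           # condition as in the PR
with_plus_one = allowed_cols(row, kv_len, qo_len, window, True)  # alternative with "+ 1"

print(len(as_written), as_written)        # 5 tokens [11..15] -> window + 1
print(len(with_plus_one), with_plus_one)  # 4 tokens [12..15] -> exactly `window`
```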
--
This is an automated message from the Apache Git Service.
To respond to the message, please log on to GitHub and use the
URL above to go to the specific comment.