grf53 opened a new pull request, #18630: URL: https://github.com/apache/tvm/pull/18630
## Description

This PR extends sliding-window support, previously applied only in `_attention_prefill()` and `_attention_decode()` (`CrossAttention`), to the `_attention_prefill_ragged()` (`SelfAttention`) function as well. This lets us compute attention correctly for requests with long prefill lengths.

## Background

(Note that in TVM, `CrossAttention` refers to the attention of Q in a new sequence over the K and V already stored in the KV cache, while `SelfAttention` refers to the self-attention among Q, K, and V within the new sequence.)

In layers that apply a sliding window, in models with SWA (e.g., Mistral) or per-layer SWA (e.g., Gemma 3), `CrossAttention` uses a function with `sliding_window=True`. In the current implementation, however, `SelfAttention` behaves the same whether or not a sliding window is applied. Since the attention output for a new sequence is obtained by merging the `SelfAttention` and `CrossAttention` results, the result is incorrect whenever the new input sequence is longer than the sliding window size.

Earlier SWA-based models typically used a large `sliding_window_size`, so this was unlikely to surface as long as each prefill input stayed short. With per-layer SWA, however, models with much shorter `sliding_window_size` values have emerged (such as gpt-oss with `128`), and with the prevalence of agentic usage a single prefill can become very long, which can lead to problems with performance or stability.

## Fixes

In the `CrossAttention` functions with `sliding_window=True`, the sliding window is applied by adjusting the length of the KV chunk to be referenced using `sliding_window_offset` and `sink_size`, which are additionally provided in `length_info`. `_attention_prefill_ragged()` in `SelfAttention`, by contrast, must apply the sliding window through an appropriate causal mask on the given tensors. Currently, `_causal_mask()` uses only the upper-bound condition, which forms a lower triangular matrix.
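To make the Background argument concrete, the following sketch counts how many keys each query of a new sequence sees when `CrossAttention` is windowed but `SelfAttention` uses only the upper-bound condition `col <= row`. It is a simplified model, not TVM code: the helper names are hypothetical, `sink_size` is assumed to be 0, and the window is assumed to cover positions `pos - window_size` through `pos` (the convention implied by the TO-BE condition below).

```python
def keys_seen_as_is(cached_len: int, row: int, window_size: int) -> int:
    """Keys visible to new-sequence row `row` when only CrossAttention is windowed."""
    pos = cached_len + row  # absolute position of the query
    # Windowed CrossAttention: cached positions c with pos - window_size <= c < cached_len.
    cross = max(0, cached_len - max(0, pos - window_size))
    # AS-IS SelfAttention mask `col <= row`: every new token up to and including this one.
    self_part = row + 1
    return cross + self_part


def keys_allowed(cached_len: int, row: int, window_size: int) -> int:
    """Keys a sliding window should expose: positions pos - window_size .. pos."""
    pos = cached_len + row
    return min(pos, window_size) + 1


def over_attending_rows(cached_len: int, new_len: int, window_size: int) -> list:
    # Rows of the new sequence that see more keys than the window allows.
    return [
        r
        for r in range(new_len)
        if keys_seen_as_is(cached_len, r, window_size) > keys_allowed(cached_len, r, window_size)
    ]
```

For example, `over_attending_rows(10, 8, 3)` returns `[4, 5, 6, 7]`: every row past `window_size` over-attends, regardless of how much is cached. With an empty cache, a 1000-token prefill, and a window of 128, the last query sees `keys_seen_as_is(0, 999, 128) == 1000` keys instead of the `keys_allowed(0, 999, 128) == 129` the window permits.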
To apply a sliding window, a lower-bound condition must be added so that the mask forms a diagonal band matrix.

#### AS-IS

`condition = col <= row`

#### TO-BE

`condition = (col <= row) and (col >= row - window_size)`

However, since the `sliding_window_size` value was not passed to the function, it had to be provided additionally. (I confirmed that using the `sliding_window_offset` value directly is not appropriate.)

## Notes

Some hard-coded values for Gemma 3 models were added in #17928, and this PR removes them. Since there was mention of introducing optional parameters for those values, the approach I chose may not match the existing plan or intention. There may also be a preferred way to add the lower-bound condition to the causal mask. Please let me know if you have any comments.
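The TO-BE band mask from the Fixes section can be checked end to end with a small pure-Python reference. The sketch assumes a single sequence, `sink_size = 0`, and that the `CrossAttention` and `SelfAttention` partial results are combined with a standard log-sum-exp merge; it models the windowed `CrossAttention` as a mask over cached positions rather than the chunk-length adjustment the kernels use. All function names are hypothetical and are not TVM APIs.

```python
import math
import random


def band_mask(row: int, col: int, window_size: int) -> bool:
    # TO-BE condition: causal upper bound plus sliding-window lower bound.
    return col <= row and col >= row - window_size


def causal_mask(row: int, col: int, window_size: int) -> bool:
    # AS-IS condition: causal upper bound only (window_size is ignored).
    return col <= row


def attend(q, keys, values, visible):
    """Softmax attention of one query over the keys flagged in `visible`.

    Values share the query dimension here. Returns (output, lse);
    lse is -inf when no key is visible.
    """
    idx = [j for j, ok in enumerate(visible) if ok]
    if not idx:
        return [0.0] * len(q), -math.inf
    scores = [sum(a * b for a, b in zip(q, keys[j])) for j in idx]
    m = max(scores)
    weights = [math.exp(s - m) for s in scores]
    total = sum(weights)
    out = [sum(w * values[j][d] for w, j in zip(weights, idx)) / total for d in range(len(q))]
    return out, m + math.log(total)


def merge(out_a, lse_a, out_b, lse_b):
    # Log-sum-exp merge of two partial results computed over disjoint key sets.
    if lse_a == -math.inf:
        return out_b, lse_b
    if lse_b == -math.inf:
        return out_a, lse_a
    m = max(lse_a, lse_b)
    wa, wb = math.exp(lse_a - m), math.exp(lse_b - m)
    out = [(wa * a + wb * b) / (wa + wb) for a, b in zip(out_a, out_b)]
    return out, m + math.log(wa + wb)


def max_merge_error(cached_len=5, new_len=7, window_size=3, dim=4, windowed_self=True, seed=0):
    """Max abs difference between merged cross+self attention and a reference
    sliding-window attention over the whole (cached + new) sequence."""
    rng = random.Random(seed)

    def rand_vec():
        return [rng.uniform(-1.0, 1.0) for _ in range(dim)]

    total = cached_len + new_len
    qs = [rand_vec() for _ in range(new_len)]
    ks = [rand_vec() for _ in range(total)]
    vs = [rand_vec() for _ in range(total)]
    self_mask = band_mask if windowed_self else causal_mask
    err = 0.0
    for row, q in enumerate(qs):
        pos = cached_len + row  # absolute position of this query
        ref, _ = attend(q, ks, vs, [band_mask(pos, c, window_size) for c in range(total)])
        # CrossAttention: cached KV only, windowed by absolute position.
        cross = attend(q, ks[:cached_len], vs[:cached_len],
                       [band_mask(pos, c, window_size) for c in range(cached_len)])
        # SelfAttention: new tokens only, using sequence-local row/col indices.
        new = attend(q, ks[cached_len:], vs[cached_len:],
                     [self_mask(row, c, window_size) for c in range(new_len)])
        merged, _ = merge(*cross, *new)
        err = max(err, max(abs(a - b) for a, b in zip(ref, merged)))
    return err
```

With the band mask, `max_merge_error(windowed_self=True)` stays at floating-point rounding level, because the band condition on sequence-local indices is a shifted copy of the same condition on absolute positions. With the AS-IS mask, `max_merge_error(windowed_self=False)` is clearly nonzero, since rows beyond `window_size` pick up keys outside the window. Note that `col >= row - window_size` admits `window_size + 1` keys per query (the current token plus `window_size` preceding ones), so whether `window_size` should be the configured size or one less depends on the model's convention.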
