This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2c72b902d7 [Relax][NN] Use int64 for RoPE apply flag (#19430)
2c72b902d7 is described below
commit 2c72b902d746e463d97d6bc48f721c8c499ef04d
Author: Xijing Wang <[email protected]>
AuthorDate: Thu Apr 23 20:31:55 2026 -0400
[Relax][NN] Use int64 for RoPE apply flag (#19430)
This patch aligns the dtype of the `apply_rope` flag used by
`llama_rope_with_position_map` with the host-side value passed through
Relax `call_tir`.
Previously, the PrimFunc declared `apply_rope` as `T.int32`, while the
caller-side scalar value is represented as an int64 Relax PrimValue /
ShapeExpr value. This caused the Relax well-formedness analysis to reject
the IR with:
Argument N type mismatch: expected R.Prim("int32"), given
R.Prim(value=1)
The mismatch can be reproduced through downstream `nn.Module.export_tvm`
paths such as MLC-LLM `convert_weight` / `compile`.
This change updates:
- `llama_rope_with_position_map`: `apply_rope: T.int32` -> `T.int64`
- `PagedKVCache`: pass the split-rotary flag as `int64_t`
---
python/tvm/relax/frontend/nn/llm/position_embedding.py | 2 +-
src/runtime/vm/paged_kv_cache.cc | 2 +-
2 files changed, 2 insertions(+), 2 deletions(-)
diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index dec80c50c2..cec2ba65dc 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -529,7 +529,7 @@ def llama_rope_with_position_map( # pylint: disable=too-many-arguments
var_q: T.handle,
var_k: T.handle,
var_v: T.handle,
- apply_rope: T.int32,
+ apply_rope: T.int64,
):
T.func_attr(
{
diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc
index 36f7697237..696f80e0c8 100644
--- a/src/runtime/vm/paged_kv_cache.cc
+++ b/src/runtime/vm/paged_kv_cache.cc
@@ -1337,7 +1337,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
}
if (!rope_ext_factors_.defined()) {
f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data,
k_data, v_data,
- static_cast<int>(rope_mode_ == RoPEMode::kNormal));
+ static_cast<int64_t>(rope_mode_ == RoPEMode::kNormal));
} else {
f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data,
k_data, v_data,
rope_ext_factors_.value());