This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 2c72b902d7 [Relax][NN] Use int64 for RoPE apply flag (#19430)
2c72b902d7 is described below

commit 2c72b902d746e463d97d6bc48f721c8c499ef04d
Author: Xijing Wang <[email protected]>
AuthorDate: Thu Apr 23 20:31:55 2026 -0400

    [Relax][NN] Use int64 for RoPE apply flag (#19430)
    
    This patch aligns the dtype of the `apply_rope` flag used by
    `llama_rope_with_position_map` with the host-side value passed through
    Relax `call_tir`.
    
    Previously the PrimFunc declared `apply_rope` as `T.int32`, while the
    caller-side scalar value is represented as an int64 Relax PrimValue /
    ShapeExpr value. This caused Relax well-formed analysis to reject the IR
    with:
    
        Argument N type mismatch: expected R.Prim("int32"), given R.Prim(value=1)
    
    The mismatch can be reproduced through downstream `nn.Module.export_tvm`
    paths such as MLC-LLM `convert_weight` / `compile`.
    
    This change updates:
    - `llama_rope_with_position_map`: `apply_rope: T.int32` -> `T.int64`
    - `PagedKVCache`: pass the split-rotary flag as `int64_t`
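The dtype mismatch described above can be illustrated with a toy Python model. This is a hypothetical sketch, not TVM code. `PrimInfo` and `check_args` are invented names. The sketch assumes that, for a scalar argument, the check reduces to comparing the dtype of each declared PrimFunc parameter with the dtype of the value the caller passes. Relax's real well-formed analysis does considerably more than this.

```python
from dataclasses import dataclass
from typing import List, Optional


# Hypothetical stand-in for scalar Relax struct info such as R.Prim("int64")
# or R.Prim(value=1). NOT a TVM class: it only tracks a dtype and an
# optional known value.
@dataclass(frozen=True)
class PrimInfo:
    dtype: str
    value: Optional[int] = None

    def __str__(self) -> str:
        if self.value is not None:
            return f"R.Prim(value={self.value})"
        return f'R.Prim("{self.dtype}")'


def check_args(params: List[PrimInfo], args: List[PrimInfo]) -> List[str]:
    """Hypothetical simplification of the call-site check: report every
    argument whose dtype differs from the declared parameter dtype."""
    errors = []
    for i, (param, arg) in enumerate(zip(params, args)):
        if param.dtype != arg.dtype:
            errors.append(
                f"Argument {i} type mismatch: expected {param}, given {arg}"
            )
    return errors


# Caller-side flag: a boolean turned into an int64 scalar, analogous to
# static_cast<int64_t>(rope_mode_ == RoPEMode::kNormal) on the C++ side.
rope_mode_is_normal = True
apply_rope_arg = PrimInfo("int64", int(rope_mode_is_normal))

# Parameter declared as T.int32 (before) versus T.int64 (after).
before = check_args([PrimInfo("int32")], [apply_rope_arg])
after = check_args([PrimInfo("int64")], [apply_rope_arg])

print(before)
print(after)
```

With the parameter declared as int32, the model reports a single mismatch in the same form as the error quoted above. Declaring it as int64 yields no errors.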
---
 python/tvm/relax/frontend/nn/llm/position_embedding.py | 2 +-
 src/runtime/vm/paged_kv_cache.cc                       | 2 +-
 2 files changed, 2 insertions(+), 2 deletions(-)

diff --git a/python/tvm/relax/frontend/nn/llm/position_embedding.py b/python/tvm/relax/frontend/nn/llm/position_embedding.py
index dec80c50c2..cec2ba65dc 100644
--- a/python/tvm/relax/frontend/nn/llm/position_embedding.py
+++ b/python/tvm/relax/frontend/nn/llm/position_embedding.py
@@ -529,7 +529,7 @@ def llama_rope_with_position_map(  # pylint: disable=too-many-arguments
         var_q: T.handle,
         var_k: T.handle,
         var_v: T.handle,
-        apply_rope: T.int32,
+        apply_rope: T.int64,
     ):
         T.func_attr(
             {
diff --git a/src/runtime/vm/paged_kv_cache.cc b/src/runtime/vm/paged_kv_cache.cc
index 36f7697237..696f80e0c8 100644
--- a/src/runtime/vm/paged_kv_cache.cc
+++ b/src/runtime/vm/paged_kv_cache.cc
@@ -1337,7 +1337,7 @@ class PagedAttentionKVCacheObj : public AttentionKVCacheObj {
     }
     if (!rope_ext_factors_.defined()) {
       f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data,
-                      static_cast<int>(rope_mode_ == RoPEMode::kNormal));
+                      static_cast<int64_t>(rope_mode_ == RoPEMode::kNormal));
     } else {
       f_split_rotary_(qkv_data_view, q_rope_position_map_view_, q_data, k_data, v_data,
                       rope_ext_factors_.value());
