This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new ec3171ab7a [REFACTOR][TIR] Tie AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch together (#19605)
ec3171ab7a is described below

commit ec3171ab7a4c06fff4e9c1e441d28ef4e9a5831b
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue May 26 22:10:54 2026 -0400

    [REFACTOR][TIR] Tie AnnotateDeviceRegions/SplitHostDevice/LowerDeviceKernelLaunch together (#19605)
    
    ## Summary
    
    These three passes are logically a single host/device split step.
    Having other passes interleaved between them obscures that model and
    prevents folding them into one pass. This PR moves each of those
    interleaved passes to the position its actual ordering constraint
    allows, so that `AnnotateDeviceRegions`, `SplitHostDevice`, and
    `LowerDeviceKernelLaunch` run consecutively in every pipeline.
    
    ## Rationale
    
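    - `MergeSharedMemoryAllocations` moves **before**
      `AnnotateDeviceRegions`. This is the only legal position:
      `LowerDeviceKernelLaunch` requires at most one dyn-shmem allocation
      per kernel, so Merge must run before Lower, and it can no longer sit
      between the three now-consecutive passes. Because Merge now runs
      before the host/device split, a single PrimFunc may contain several
      sibling `thread_extent` blocks (kernel launches). The pass was
      therefore reworked to plan and merge shared-memory allocations per
      kernel-launch scope instead of once per PrimFunc. Kernels with at
      most one such allocation are left unchanged.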
    - `MakePackedAPI` moves **after** `LowerDeviceKernelLaunch`. Lower
      sets `kCallingConv = kDeviceKernelLaunch` on device kernels, which
      makes `MakePackedAPI` correctly skip them. The `tvm_call_packed`
      that Lower emits in the host body is transparent to
      `MakePackedAPI`'s subroutine rewriter.
    - `FP8StorageLegalize` / `BF16StorageLegalize` stay immediately after
      `MakePackedAPI`, and therefore now also run after
      `LowerDeviceKernelLaunch`. Their `buffer_map.size()==0` ICHECK
      requires `MakePackedAPI` to have already cleared the buffer map.
    
    This is a prerequisite for Phase 2: collapsing the three consecutive
    passes into a single `tirx.transform.SplitHostDevice` pass with three
    commented regions.
    
    ## Test plan
    
    - [x] tests/python/tirx-transform/ target-pass unit tests (25 pass)
    - [x]
    tests/python/s_tir/transform/test_merge_dynamic_shared_memory_allocations.py
    (5 pass)
    - [x] tests/python/tirx-transform/test_tir_transform_fp8_legalize.py /
          test_tir_transform_bf16_legalize.py (13 pass)
    - [x] tests/python/codegen/test_target_codegen_c_host.py /
          test_target_codegen_device.py (6 pass, including
          test_subroutine_call, which covers Risk #2: the `MakePackedAPI`
          subroutine-rewriter interaction after the reorder)
    - [x] pre-commit run --all-files clean
    - [ ] CI: lint / Windows / MacOS
---
 python/tvm/s_tir/backend/adreno/pipeline.py        |   5 +-
 python/tvm/s_tir/pipeline.py                       |   5 +-
 python/tvm/tirx/compilation_pipeline.py            |   6 +-
 .../transform/merge_shared_memory_allocations.cc   | 447 +++++++++++++--------
 src/tirx/transform/lower_device_kernel_launch.cc   |  91 ++++-
 ...form_merge_dynamic_shared_memory_allocations.py |  95 ++++-
 6 files changed, 434 insertions(+), 215 deletions(-)

diff --git a/python/tvm/s_tir/backend/adreno/pipeline.py 
b/python/tvm/s_tir/backend/adreno/pipeline.py
index 85359b1d35..618970b37e 100644
--- a/python/tvm/s_tir/backend/adreno/pipeline.py
+++ b/python/tvm/s_tir/backend/adreno/pipeline.py
@@ -108,14 +108,13 @@ def default_tir_pipeline():
             passes.append(s_tir.transform.InjectPTXLDG32())
         passes.extend(
             [
+                s_tir.transform.MergeSharedMemoryAllocations(),
                 tirx.transform.AnnotateDeviceRegions(),
                 tirx.transform.SplitHostDevice(),
-                # MergeSharedMemoryAllocations must follow SplitHostDevice.
-                s_tir.transform.MergeSharedMemoryAllocations(),
+                tirx.transform.LowerDeviceKernelLaunch(),
                 tirx.transform.MakePackedAPI(),
                 tirx.transform.FP8StorageLegalize(),
                 tirx.transform.BF16StorageLegalize(),
-                tirx.transform.LowerDeviceKernelLaunch(),
             ]
         )
         mod = tvm.ir.transform.Sequential(passes)(mod)
diff --git a/python/tvm/s_tir/pipeline.py b/python/tvm/s_tir/pipeline.py
index 33a16b381f..a127e43a0e 100644
--- a/python/tvm/s_tir/pipeline.py
+++ b/python/tvm/s_tir/pipeline.py
@@ -108,14 +108,13 @@ def default_s_tir_pipeline():
             passes.append(s_tir.transform.InjectPTXLDG32())
         passes.extend(
             [
+                s_tir.transform.MergeSharedMemoryAllocations(),
                 tirx.transform.AnnotateDeviceRegions(),
                 tirx.transform.SplitHostDevice(),
-                # MergeSharedMemoryAllocations must follow SplitHostDevice.
-                s_tir.transform.MergeSharedMemoryAllocations(),
+                tirx.transform.LowerDeviceKernelLaunch(),
                 tirx.transform.MakePackedAPI(),
                 tirx.transform.FP8StorageLegalize(),
                 tirx.transform.BF16StorageLegalize(),
-                tirx.transform.LowerDeviceKernelLaunch(),
             ]
         )
         mod = tvm.ir.transform.Sequential(passes)(mod)
diff --git a/python/tvm/tirx/compilation_pipeline.py 
b/python/tvm/tirx/compilation_pipeline.py
index 30facc2663..f964f50668 100644
--- a/python/tvm/tirx/compilation_pipeline.py
+++ b/python/tvm/tirx/compilation_pipeline.py
@@ -50,10 +50,10 @@ def default_tir_pipeline():
                 tirx.transform.AnnotateEntryFunc(),
                 tirx.transform.AnnotateDeviceRegions(),
                 tirx.transform.SplitHostDevice(),
+                tirx.transform.LowerDeviceKernelLaunch(),
                 tirx.transform.MakePackedAPI(),
                 tirx.transform.FP8StorageLegalize(),
                 tirx.transform.BF16StorageLegalize(),
-                tirx.transform.LowerDeviceKernelLaunch(),
             ]
         )
         mod = tvm.ir.transform.Sequential(passes)(mod)
@@ -91,10 +91,10 @@ def tirx_pipeline():
                 tirx.transform.AnnotateEntryFunc(),
                 tirx.transform.AnnotateDeviceRegions(),
                 tirx.transform.SplitHostDevice(),
+                tirx.transform.LowerDeviceKernelLaunch(),
                 tirx.transform.MakePackedAPI(),
                 tirx.transform.FP8StorageLegalize(),
                 tirx.transform.BF16StorageLegalize(),
-                tirx.transform.LowerDeviceKernelLaunch(),
             ]
         )
         mod = tvm.ir.transform.Sequential(passes)(mod)
@@ -124,8 +124,8 @@ def trn_pipeline():
             tirx.transform.AnnotateEntryFunc(),
             tirx.transform.AnnotateDeviceRegions(),
             tirx.transform.SplitHostDevice(),
-            tirx.transform.MakePackedAPI(),
             tirx.transform.LowerDeviceKernelLaunch(),
+            tirx.transform.MakePackedAPI(),
         ]
         return tvm.ir.transform.Sequential(passes)(mod)
 
diff --git a/src/s_tir/transform/merge_shared_memory_allocations.cc 
b/src/s_tir/transform/merge_shared_memory_allocations.cc
index c680eb38ac..d1417943c3 100644
--- a/src/s_tir/transform/merge_shared_memory_allocations.cc
+++ b/src/s_tir/transform/merge_shared_memory_allocations.cc
@@ -77,24 +77,26 @@ static int64_t ConstantAllocationSize(const 
ffi::Array<PrimExpr>& extents) {
 }
 
 /*!
- * \brief collect the mapping from the buffer var to its Buffer
+ * \brief collect the mapping from the buffer var to its Buffer within a 
subtree
  */
 class AllocateCollector : public StmtExprVisitor {
  public:
+  explicit AllocateCollector(bool is_dynamic) : is_dynamic_(is_dynamic) {}
+
   void VisitStmt_(const AllocBufferNode* op) final {
-    if (IsDynamicSharedMemory(op->buffer->data) || 
IsStaticSharedMemory(op->buffer->data)) {
-      if (IsDynamicSharedMemory(op->buffer->data)) {
-        dyn_shmem_allocs_[op->buffer->data.get()] = op->buffer;
-      } else {
-        static_shmem_allocs_[op->buffer->data.get()] = op->buffer;
-      }
+    if (is_dynamic_ && IsDynamicSharedMemory(op->buffer->data)) {
+      shmem_allocs_[op->buffer->data.get()] = op->buffer;
+    } else if (!is_dynamic_ && IsStaticSharedMemory(op->buffer->data)) {
+      shmem_allocs_[op->buffer->data.get()] = op->buffer;
     }
     StmtExprVisitor::VisitStmt_(op);
   }
-  // The dynamic mapping from the original buffer var to its Buffer
-  std::unordered_map<const VarNode*, Buffer> dyn_shmem_allocs_;
-  // The static mapping from the original buffer var to its Buffer
-  std::unordered_map<const VarNode*, Buffer> static_shmem_allocs_;
+
+  // The mapping from the original buffer var to its Buffer
+  std::unordered_map<const VarNode*, Buffer> shmem_allocs_;
+
+ private:
+  bool is_dynamic_;
 };
 
 // Find a linear pattern of storage access
@@ -274,89 +276,131 @@ class SharedMemLinearAccessPatternFinder final : public 
StmtExprVisitor {
 
 /*!
  * \brief merge the buffers whose live range has no intersection and rewrite 
the body
+ *
+ * Uses a scope-stack design: each thread_extent block (kernel launch) gets its
+ * own KernelScope that owns the merged buffer var and all per-launch 
bookkeeping.
+ * This correctly handles PrimFuncs with multiple sibling thread_extent blocks.
  */
 class SharedMemoryRewriter : public StmtExprMutator {
  public:
-  explicit SharedMemoryRewriter(const std::unordered_map<const VarNode*, 
Buffer>& shmem_allocs,
-                                bool is_dynamic = true)
-      : is_dynamic_{is_dynamic}, shmem_allocs_{shmem_allocs} {
-    if (!is_dynamic) {
-      merged_buf_var_ = Var("buf_shmem", 
PointerType(PrimType(DataType::UInt(8)), "shared"));
-    }
-  }
+  explicit SharedMemoryRewriter(bool is_dynamic = true) : 
is_dynamic_{is_dynamic} {}
+
+ private:
+  using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
+
+  struct StorageEntry {
+    // The constant size of the buffer in bits, only used if it is constant
+    uint64_t const_nbits{0};
+    // Allocs that shares this entry.
+    // The inner vector means a "layer"
+    // For example, it we need to allocate C in the memory of A and B:
+    // |  A: 4096 bytes |  B: 4096 bytes |
+    // |            C: 8192 bytes        |
+    // Then the allocs = {{A, B}, {C}}
+    std::vector<std::vector<const VarNode*>> allocs;
+  };
+
+  // Event entry in liveness analysis
+  struct EventEntry {
+    // variables we generate
+    std::vector<const VarNode*> gen;
+    // variables we kill
+    std::vector<const VarNode*> kill;
+  };
 
   /*!
-   * \brief plan the memory reuse for all the buffer allocated in the statement
-   * \param stmt the statement
+   * \brief Per-kernel-launch scope holding all state for one thread_extent 
block.
    */
-  void PlanReuse(const Stmt& stmt, bool is_dynamic = true) {
-    SharedMemLinearAccessPatternFinder finder(is_dynamic);
-    finder(stmt);
-    this->LivenessAnalysis(finder.linear_seq_);
-    this->PlanMemory(finder.linear_seq_);
+  struct KernelScope {
+    // The merged buffer var for THIS kernel launch.
+    Var merged_buf_var;
+    // Total byte size of THIS kernel's merged buffer.
+    PrimExpr merged_alloc_size{0};
+    // Allocations from THIS kernel's subtree.
+    std::unordered_map<const VarNode*, Buffer> shmem_allocs;
+    // Per-buffer byte offset into merged_buf_var.
+    std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets;
+    // Buffer-object remap: original Buffer -> merged-data-var Buffer.
+    std::unordered_map<const BufferNode*, Buffer> buffer_remap;
+    // Has any original alloc in this scope been marked volatile?
+    bool has_volatile_alloc{false};
+    // Liveness data (event_map, alloc_map, const_free_map, sym_free_list) — 
all per-scope.
+    std::unordered_map<const ffi::Object*, EventEntry> event_map;
+    std::multimap<uint64_t, StorageEntry*> const_free_map;
+    std::list<StorageEntry*> sym_free_list;
+    std::unordered_map<const VarNode*, StorageEntry*> alloc_map;
+  };
+
+  /*!
+   * \brief Create a fresh merged buffer Var for a new kernel scope.
+   *        Same name string is fine — Var identity is by pointer, not name.
+   */
+  Var MakeMergedBufferVar() {
+    if (is_dynamic_) {
+      return Var("buf_dyn_shmem", PointerType(PrimType(DataType::UInt(8)), 
"shared.dyn"));
+    } else {
+      return Var("buf_shmem", PointerType(PrimType(DataType::UInt(8)), 
"shared"));
+    }
   }
 
- private:
   Stmt VisitStmt_(const AttrStmtNode* op) final {
-    if (op->attr_key == tirx::attr::thread_extent && !allocated_) {
-      // Allocate one dynamic shared memory allocation at the beginning of 
thread scope
-      int max_layer_num = 0;
-      std::vector<const StorageEntry*> all_entry;
-      for (const auto& e : const_free_map_) {
-        all_entry.push_back(e.second);
-      }
-      for (const StorageEntry* e : sym_free_list_) {
-        all_entry.push_back(e);
-      }
-      for (const StorageEntry* e : all_entry) {
-        max_layer_num = std::max(max_layer_num, 
static_cast<int>(e->allocs.size()));
-      }
-      // calculate align for each layer of each storage entry.
-      std::vector<int> align(max_layer_num, 0);
-      for (const StorageEntry* e : all_entry) {
-        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
-          for (const VarNode* buffer : e->allocs[i]) {
-            const Buffer& buf = shmem_allocs_.at(buffer);
-            align[i] = std::max(align[i], buf->dtype.bytes());
-          }
-        }
-      }
-      // calculate offset for each buffer based on the align of each layer
-      for (const StorageEntry* e : all_entry) {
-        PrimExpr max_inner_offset = 0;
-        for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
-          PrimExpr inner_offset = 0;
-          for (const VarNode* buffer : e->allocs[i]) {
-            const Buffer& buf = shmem_allocs_.at(buffer);
-            ffi::Array<PrimExpr> alloc_shape = GetBufferAllocationShape(buf);
-            int align_bytes = std::max(align[i], buf->dtype.bytes());
-            if (buf->data_alignment > 0) {
-              TVM_FFI_ICHECK(buf->data_alignment % align_bytes == 0)
-                  << "The alignment of the buffer is not a multiple of the 
data type size.";
-              align_bytes = buf->data_alignment;
-            }
-            PrimExpr buffer_bytes = alloc_shape[0] * buf->dtype.bytes();
-            inner_offset +=
-                indexmod(align_bytes - indexmod(merged_alloc_size_ + 
inner_offset, align_bytes),
-                         align_bytes);
-            buffer_byte_offsets_[buffer] = merged_alloc_size_ + inner_offset;
-            inner_offset += buffer_bytes;
-          }
-          max_inner_offset = max(max_inner_offset, inner_offset);
-        }
-        merged_alloc_size_ += max_inner_offset;
+    if (op->attr_key == tirx::attr::thread_extent && !in_thread_env_) {
+      in_thread_env_ = true;
+
+      // 1. Push a fresh scope.
+      scope_stack_.emplace_back();
+      KernelScope& scope = scope_stack_.back();
+      scope.merged_buf_var = MakeMergedBufferVar();
+
+      // 2. Collect shmem allocs that belong to THIS subtree.
+      AllocateCollector collector(is_dynamic_);
+      collector(op->body);
+      scope.shmem_allocs = std::move(collector.shmem_allocs_);
+
+      // Per-scope early bail-out: if this thread_extent block has ≤1 shmem
+      // allocation, there is nothing to merge.  Skip liveness analysis,
+      // memory planning, and rewriting entirely.
+      if (scope.shmem_allocs.size() <= 1) {
+        scope_stack_.pop_back();
+        in_thread_env_ = false;
+        return StmtExprMutator::VisitStmt_(op);
       }
 
-      allocated_ = true;
-      Buffer merged_buf(merged_buf_var_, DataType::UInt(8), 
{merged_alloc_size_}, {}, PrimExpr(),
-                        merged_buf_var_->name_hint, 0, 0, 
BufferType::kDefault);
+      // 3. Liveness + reuse plan over this subtree only.
+      // Run the finder on the full AttrStmt (not just op->body) so that
+      // VisitNewScope creates the proper scope pair entry for the 
thread_extent.
+      SharedMemLinearAccessPatternFinder finder(is_dynamic_);
+      finder(ffi::GetRef<Stmt>(op));
+      this->LivenessAnalysis(finder.linear_seq_, scope);
+      this->PlanMemory(finder.linear_seq_, scope);
+
+      // 4. Compute byte offsets / merged_alloc_size.
+      this->ComputeOffsets(scope);
+
+      // 5. Recursively mutate the body — reads scope_stack_.back() for all 
rewrites.
       Stmt visited_body = StmtExprMutator::VisitStmt(op->body);
+
+      in_thread_env_ = false;
+
+      // 6. If this scope has no shmem allocs, skip the wrapper.
+      if (scope.shmem_allocs.empty()) {
+        scope_stack_.pop_back();
+        return AttrStmt(op->node, op->attr_key, op->value, visited_body, 
op->span);
+      }
+
+      // 7. Wrap with the merged-buffer AllocBuffer.
+      Buffer merged_buf(scope.merged_buf_var, DataType::UInt(8), 
{scope.merged_alloc_size}, {},
+                        PrimExpr(), scope.merged_buf_var->name_hint, 0, 0, 
BufferType::kDefault);
       ffi::Map<ffi::String, ffi::Any> annotations;
-      if (has_volatile_alloc_) {
+      if (scope.has_volatile_alloc) {
         annotations.Set(tirx::attr::kVolatile, true);
       }
       Stmt alloc_stmt = AllocBuffer(merged_buf, annotations);
       Stmt new_body = SeqStmt::Flatten(alloc_stmt, visited_body);
+
+      // 8. Pop the scope.
+      scope_stack_.pop_back();
+
       return AttrStmt(op->node, op->attr_key, op->value, new_body, op->span);
     }
     return StmtMutator::VisitStmt_(op);
@@ -364,10 +408,17 @@ class SharedMemoryRewriter : public StmtExprMutator {
 
   Stmt VisitStmt_(const AllocBufferNode* op) final {
     if (IsAppropriateSharedMemory(op->buffer->data)) {
-      if (op->annotations.count(tirx::attr::kVolatile)) {
-        has_volatile_alloc_ = true;
+      if (!scope_stack_.empty()) {
+        KernelScope& scope = scope_stack_.back();
+        if (scope.shmem_allocs.count(op->buffer->data.get())) {
+          if (op->annotations.count(tirx::attr::kVolatile)) {
+            scope.has_volatile_alloc = true;
+          }
+          return Evaluate(0);
+        }
       }
-      return Evaluate(0);
+      // Outside any thread_extent scope — leave as-is.
+      return StmtExprMutator::VisitStmt_(op);
     }
     return StmtExprMutator::VisitStmt_(op);
   }
@@ -392,7 +443,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
 
   template <typename Node>
   Node VisitBufferAccess(Node node) {
-    if (IsAppropriateSharedMemory(node->buffer->data)) {
+    if (IsAppropriateSharedMemory(node->buffer->data) && !scope_stack_.empty() 
&&
+        scope_stack_.back().shmem_allocs.count(node->buffer->data.get())) {
       TVM_FFI_ICHECK_EQ(node->indices.size(), 1)
           << "MergeSharedMemoryAllocations expects flat memory buffers, "
           << "and is to be run after "
@@ -409,9 +461,13 @@ class SharedMemoryRewriter : public StmtExprMutator {
   }
 
   Buffer GetUpdatedBuffer(Buffer buffer) {
+    if (scope_stack_.empty()) return buffer;
+    KernelScope& scope = scope_stack_.back();
+    if (!scope.shmem_allocs.count(buffer->data.get())) return buffer;
+
     auto key = buffer.get();
-    auto it = buffer_remap_.find(key);
-    if (it != buffer_remap_.end()) {
+    auto it = scope.buffer_remap.find(key);
+    if (it != scope.buffer_remap.end()) {
       return it->second;
     }
 
@@ -422,10 +478,10 @@ class SharedMemoryRewriter : public StmtExprMutator {
           << "and is to be run after "
           << "FlattenBuffer";
       auto writer = buffer.CopyOnWrite();
-      writer->data = merged_buf_var_;
+      writer->data = scope.merged_buf_var;
     }
 
-    buffer_remap_[key] = buffer;
+    scope.buffer_remap[key] = buffer;
     return buffer;
   }
 
@@ -434,7 +490,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
       TVM_FFI_ICHECK_EQ(op->args.size(), 5U);
       DataType dtype = op->args[0].dtype();
       Var buffer = Downcast<Var>(op->args[1]);
-      if (!IsAppropriateSharedMemory(buffer)) {
+      if (!IsAppropriateSharedMemory(buffer) || scope_stack_.empty() ||
+          !scope_stack_.back().shmem_allocs.count(buffer.get())) {
         return StmtExprMutator::VisitExpr_(op);
       }
       PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
@@ -442,7 +499,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
       PrimExpr offset = this->VisitExpr(op->args[2]);
       PrimExpr extent = this->VisitExpr(op->args[3]);
       return Call(op->dtype, op->op,
-                  {op->args[0], merged_buf_var_, extra_offset + offset, 
extent, op->args[4]});
+                  {op->args[0], scope_stack_.back().merged_buf_var, 
extra_offset + offset, extent,
+                   op->args[4]});
     } else if (op->op.same_as(builtin::ptx_cp_async())) {
       TVM_FFI_ICHECK((op->args.size() == 5U) || (op->args.size() == 6U));
       Var buffer = Downcast<Var>(op->args[0]);
@@ -451,7 +509,8 @@ class SharedMemoryRewriter : public StmtExprMutator {
       const auto* prim_type = ptr_type->element_type.as<PrimTypeNode>();
       TVM_FFI_ICHECK(prim_type) << "The buffer should be a pointer to a 
primitive type.";
       DataType dtype = DataType(prim_type->dtype);
-      if (!IsAppropriateSharedMemory(buffer)) {
+      if (!IsAppropriateSharedMemory(buffer) || scope_stack_.empty() ||
+          !scope_stack_.back().shmem_allocs.count(buffer.get())) {
         return StmtExprMutator::VisitExpr_(op);
       }
       PrimExpr extra_offset = GetBufferOffset(buffer, dtype);
@@ -461,21 +520,25 @@ class SharedMemoryRewriter : public StmtExprMutator {
       // the correct offset of merged shared buffer.
       int index_factor = dtype.bytes();
       if (op->args.size() == 5)
-        return Call(dtype, op->op,
-                    {merged_buf_var_, mul(extra_offset + offset, 
PrimExpr(index_factor)),
-                     op->args[2], op->args[3], op->args[4]});
+        return Call(
+            dtype, op->op,
+            {scope_stack_.back().merged_buf_var, mul(extra_offset + offset, 
PrimExpr(index_factor)),
+             op->args[2], op->args[3], op->args[4]});
       else
-        return Call(dtype, op->op,
-                    {merged_buf_var_, mul(extra_offset + offset, 
PrimExpr(index_factor)),
-                     op->args[2], op->args[3], op->args[4], op->args[5]});
+        return Call(
+            dtype, op->op,
+            {scope_stack_.back().merged_buf_var, mul(extra_offset + offset, 
PrimExpr(index_factor)),
+             op->args[2], op->args[3], op->args[4], op->args[5]});
     } else {
       return StmtExprMutator::VisitExpr_(op);
     }
   }
 
   PrimExpr GetBufferOffset(Var buffer_var, DataType dtype) {
-    auto it = buffer_byte_offsets_.find(buffer_var.get());
-    TVM_FFI_ICHECK(it != buffer_byte_offsets_.end());
+    TVM_FFI_ICHECK(!scope_stack_.empty());
+    KernelScope& scope = scope_stack_.back();
+    auto it = scope.buffer_byte_offsets.find(buffer_var.get());
+    TVM_FFI_ICHECK(it != scope.buffer_byte_offsets.end());
     return indexdiv(it->second, dtype.bytes());
   }
 
@@ -484,32 +547,12 @@ class SharedMemoryRewriter : public StmtExprMutator {
     return is_dynamic_ ? IsDynamicSharedMemory(var) : 
IsStaticSharedMemory(var);
   }
 
-  using StmtEntry = SharedMemLinearAccessPatternFinder::StmtEntry;
-  struct StorageEntry {
-    // The constant size of the buffer in bits, only used if it is constant
-    uint64_t const_nbits{0};
-    // Allocs that shares this entry.
-    // The inner vector means a "layer"
-    // For example, it we need to allocate C in the memory of A and B:
-    // |  A: 4096 bytes |  B: 4096 bytes |
-    // |            C: 8192 bytes        |
-    // Then the allocs = {{A, B}, {C}}
-    std::vector<std::vector<const VarNode*>> allocs;
-  };
-
-  // Event entry in liveness analysis
-  struct EventEntry {
-    // variables we generate
-    std::vector<const VarNode*> gen;
-    // variables we kill
-    std::vector<const VarNode*> kill;
-  };
-
   /*!
    * \brief Liveness analysis to find gen and kill point of each variable.
    * \param seq the linear pattern of storage access
+   * \param scope the kernel scope to write results into
    */
-  void LivenessAnalysis(const std::vector<StmtEntry>& seq) {
+  void LivenessAnalysis(const std::vector<StmtEntry>& seq, KernelScope& scope) 
{
     // find kill point, do a reverse linear scan.
     std::unordered_set<const VarNode*> touched;
     for (size_t i = seq.size(); i != 0; --i) {
@@ -517,7 +560,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
       for (const VarNode* buffer : s.touched) {
         if (!touched.count(buffer)) {
           touched.insert(buffer);
-          event_map_[s.stmt].kill.push_back(buffer);
+          scope.event_map[s.stmt].kill.push_back(buffer);
         }
       }
     }
@@ -530,7 +573,7 @@ class SharedMemoryRewriter : public StmtExprMutator {
       for (const VarNode* buffer : s.touched) {
         if (!touched.count(buffer)) {
           touched.insert(buffer);
-          event_map_[s.stmt].gen.push_back(buffer);
+          scope.event_map[s.stmt].gen.push_back(buffer);
         }
       }
     }
@@ -539,12 +582,13 @@ class SharedMemoryRewriter : public StmtExprMutator {
   /*!
    * \brief Memory plan algorithm
    * \param seq the linear pattern of storage access
+   * \param scope the kernel scope to write results into
    */
-  void PlanMemory(const std::vector<StmtEntry>& seq) {
+  void PlanMemory(const std::vector<StmtEntry>& seq, KernelScope& scope) {
     std::unordered_set<const VarNode*> inplace_flag;
 
     for (size_t i = 0; i < seq.size(); ++i) {
-      auto it = event_map_.find(seq[i].stmt);
+      auto it = scope.event_map.find(seq[i].stmt);
       // scope_pair_offset <= 0 means it is either
       // - leaf stmt(offset = 0)
       // - end of scope(offset < 0)
@@ -553,30 +597,84 @@ class SharedMemoryRewriter : public StmtExprMutator {
         return seq[i].scope_pair_offset == 0 &&
                std::find(it->second.gen.begin(), it->second.gen.end(), var) != 
it->second.gen.end();
       };
-      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
+      if (it != scope.event_map.end() && seq[i].scope_pair_offset <= 0) {
         for (const VarNode* var : it->second.kill) {
-          if (!is_leaf_alloc(var)) this->Free(var);
+          if (!is_leaf_alloc(var)) this->Free(var, scope);
         }
       }
       // scope_pair_offset >= 0 means it is either
       // - leaf stmt(offset = 0)
       // - beginning of scope(offset < 0)
       // In both cases, we need to handle the gen event correctly
-      if (it != event_map_.end() && seq[i].scope_pair_offset >= 0) {
+      if (it != scope.event_map.end() && seq[i].scope_pair_offset >= 0) {
         for (const VarNode* var : it->second.gen) {
-          TVM_FFI_ICHECK(shmem_allocs_.count(var));
-          const Buffer& buf = shmem_allocs_.at(var);
-          StorageEntry* dst_entry = FindAlloc(buf);
-          alloc_map_[var] = dst_entry;
+          TVM_FFI_ICHECK(scope.shmem_allocs.count(var));
+          const Buffer& buf = scope.shmem_allocs.at(var);
+          StorageEntry* dst_entry = FindAlloc(buf, scope);
+          scope.alloc_map[var] = dst_entry;
         }
       }
-      if (it != event_map_.end() && seq[i].scope_pair_offset <= 0) {
+      if (it != scope.event_map.end() && seq[i].scope_pair_offset <= 0) {
         for (const VarNode* var : it->second.kill) {
-          if (is_leaf_alloc(var)) this->Free(var);
+          if (is_leaf_alloc(var)) this->Free(var, scope);
+        }
+      }
+    }
+  }
+
+  /*!
+   * \brief Compute byte offsets for all entries in the scope after PlanMemory.
+   * \param scope the kernel scope whose offset map to fill
+   */
+  void ComputeOffsets(KernelScope& scope) {
+    int max_layer_num = 0;
+    std::vector<const StorageEntry*> all_entry;
+    for (const auto& e : scope.const_free_map) {
+      all_entry.push_back(e.second);
+    }
+    for (const StorageEntry* e : scope.sym_free_list) {
+      all_entry.push_back(e);
+    }
+    for (const StorageEntry* e : all_entry) {
+      max_layer_num = std::max(max_layer_num, 
static_cast<int>(e->allocs.size()));
+    }
+    // calculate align for each layer of each storage entry.
+    std::vector<int> align(max_layer_num, 0);
+    for (const StorageEntry* e : all_entry) {
+      for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
+        for (const VarNode* buffer : e->allocs[i]) {
+          const Buffer& buf = scope.shmem_allocs.at(buffer);
+          align[i] = std::max(align[i], buf->dtype.bytes());
         }
       }
     }
+    // calculate offset for each buffer based on the align of each layer
+    for (const StorageEntry* e : all_entry) {
+      PrimExpr max_inner_offset = 0;
+      for (int i = 0; i < static_cast<int>(e->allocs.size()); i++) {
+        PrimExpr inner_offset = 0;
+        for (const VarNode* buffer : e->allocs[i]) {
+          const Buffer& buf = scope.shmem_allocs.at(buffer);
+          ffi::Array<PrimExpr> alloc_shape = GetBufferAllocationShape(buf);
+          int align_bytes = std::max(align[i], buf->dtype.bytes());
+          if (buf->data_alignment > 0) {
+            TVM_FFI_ICHECK(buf->data_alignment % align_bytes == 0)
+                << "The alignment of the buffer is not a multiple of the data 
type size.";
+            align_bytes = buf->data_alignment;
+          }
+          PrimExpr buffer_bytes = alloc_shape[0] * buf->dtype.bytes();
+          inner_offset +=
+              indexmod(align_bytes - indexmod(scope.merged_alloc_size + 
inner_offset, align_bytes),
+                       align_bytes);
+          scope.buffer_byte_offsets[buffer] = scope.merged_alloc_size + 
inner_offset;
+          inner_offset += buffer_bytes;
+        }
+        max_inner_offset = max(max_inner_offset, inner_offset);
+      }
+      scope.merged_alloc_size = scope.merged_alloc_size + max_inner_offset;
+    }
   }
+
   /*!
    * \brief Allocate new storage entry.
    * \param buf the buffer object
@@ -590,12 +688,14 @@ class SharedMemoryRewriter : public StmtExprMutator {
     entry->const_nbits = const_nbits;
     return entry;
   }
+
   /*!
    * \brief find the storage entry in the free list for the buffer
    * \param buf the buffer object
+   * \param scope the kernel scope whose free lists to search
    * \return the storage entry
    */
-  StorageEntry* FindAlloc(const Buffer& buf) {
+  StorageEntry* FindAlloc(const Buffer& buf, KernelScope& scope) {
     // skip plan for local variable,
     // compiler can do a better job with register allocation.
     const uint64_t match_range = 16;
@@ -611,17 +711,17 @@ class SharedMemoryRewriter : public StmtExprMutator {
 
     if (const_nbits != 0) {
       // constant allocation.
-      auto begin = const_free_map_.lower_bound(0);
-      auto mid = const_free_map_.lower_bound(const_nbits);
-      auto end = const_free_map_.upper_bound(const_nbits * match_range);
+      auto begin = scope.const_free_map.lower_bound(0);
+      auto mid = scope.const_free_map.lower_bound(const_nbits);
+      auto end = scope.const_free_map.upper_bound(const_nbits * match_range);
       // Start looking at the buffer that is bigger than the required size 
first.
       // If we find one, directly allocate the buffer in its location and 
remove its entry in the
       // free list
       for (auto it = mid; it != end; ++it) {
         StorageEntry* e = it->second;
         e->const_nbits = std::max(const_nbits, e->const_nbits);
-        const_free_map_.erase(it);
-        it->second->allocs.push_back({buf->data.get()});
+        scope.const_free_map.erase(it);
+        e->allocs.push_back({buf->data.get()});
         return e;
       }
       // Then start looking at smaller buffers.
@@ -654,16 +754,16 @@ class SharedMemoryRewriter : public StmtExprMutator {
         e->const_nbits = std::max(const_nbits, mem_ct);
         e->allocs = reuse_allocs;
         for (auto it : delete_it) {
-          const_free_map_.erase(it);
+          scope.const_free_map.erase(it);
         }
         return e;
       }
     } else {
       // if its symbolic allocation, just arbitrarily choose one entry to fit 
in because we don't
       // know its actual size
-      for (auto it = sym_free_list_.begin(); it != sym_free_list_.end(); ++it) 
{
+      for (auto it = scope.sym_free_list.begin(); it != 
scope.sym_free_list.end(); ++it) {
         StorageEntry* e = *it;
-        sym_free_list_.erase(it);
+        scope.sym_free_list.erase(it);
         return e;
       }
     }
@@ -673,10 +773,11 @@ class SharedMemoryRewriter : public StmtExprMutator {
   /*!
    * \brief add the storage entry to the buffer var into the free list.
    * \param var the buffer var
+   * \param scope the kernel scope whose free lists to update
    */
-  void Free(const VarNode* var) {
-    auto it = alloc_map_.find(var);
-    TVM_FFI_ICHECK(it != alloc_map_.end());
+  void Free(const VarNode* var, KernelScope& scope) {
+    auto it = scope.alloc_map.find(var);
+    TVM_FFI_ICHECK(it != scope.alloc_map.end());
     StorageEntry* e = it->second;
     TVM_FFI_ICHECK_NE(e->allocs.size(), 0U);
 
@@ -685,51 +786,41 @@ class SharedMemoryRewriter : public StmtExprMutator {
 
     // normal free.
     if (e->const_nbits != 0) {
-      const_free_map_.insert({e->const_nbits, e});
+      scope.const_free_map.insert({e->const_nbits, e});
     } else {
-      sym_free_list_.push_back(e);
+      scope.sym_free_list.push_back(e);
     }
   }
+
   // Whether enable dynamic analysis.
   bool is_dynamic_{true};
-  // The var for the merged buffer
-  Var merged_buf_var_{"buf_dyn_shmem", 
PointerType(PrimType(DataType::UInt(8)), "shared.dyn")};
-  // The mapping from the original buffer var to its Buffer
-  std::unordered_map<const VarNode*, Buffer> shmem_allocs_;
-  // The size of the merged buffer
-  PrimExpr merged_alloc_size_{0};
-  // The mapping from the original buffer var to its offset in the merged 
buffer
-  std::unordered_map<const VarNode*, PrimExpr> buffer_byte_offsets_;
-  // The mapping from the original buffer objects to their location in the 
merged buffer.
-  std::unordered_map<const BufferNode*, Buffer> buffer_remap_;
-  // The flag indicating whether the merged buffer has been allocated
-  bool allocated_{false};
-  // Whether any original shared memory allocation had the volatile annotation
-  bool has_volatile_alloc_{false};
-  // Locations of free ops.
-  std::unordered_map<const ffi::Object*, EventEntry> event_map_;
-  // constant size free map.
-  std::multimap<uint64_t, StorageEntry*> const_free_map_;
-  // symbolic free list, for non constant items.
-  std::list<StorageEntry*> sym_free_list_;
-  // The allocation assign map
-  std::unordered_map<const VarNode*, StorageEntry*> alloc_map_;
-  /*! \brief allocator of all the StorageEntry*/
+  // Whether already inside a thread_extent (outermost only).
+  bool in_thread_env_{false};
+  // Stack of per-kernel-launch scopes. Pushed on thread_extent entry, popped 
on exit.
+  std::vector<KernelScope> scope_stack_;
+  /*! \brief allocator of all the StorageEntry (shared across all scopes) */
   support::Arena arena_;
 };
 
 Stmt MergeSharedMemoryAllocations(Stmt stmt, bool merge_static_smem) {
-  AllocateCollector collector;
-  collector(stmt);
-  if (collector.dyn_shmem_allocs_.size() > 1) {
-    SharedMemoryRewriter rewriter(collector.dyn_shmem_allocs_);
-    rewriter.PlanReuse(stmt);
-    stmt = rewriter(std::move(stmt));
+  // Function-level early-out: skip the dynamic rewriter entirely when the whole
+  // body has at most one dynamic shared-memory allocation (nothing to merge).
+  {
+    AllocateCollector dyn_probe(/*is_dynamic=*/true);
+    dyn_probe(stmt);
+    if (dyn_probe.shmem_allocs_.size() > 1) {
+      SharedMemoryRewriter dyn_rewriter(/*is_dynamic=*/true);
+      stmt = dyn_rewriter(std::move(stmt));
+    }
   }
-  if (merge_static_smem && collector.static_shmem_allocs_.size() > 1) {
-    SharedMemoryRewriter rewriter(collector.static_shmem_allocs_, false);
-    rewriter.PlanReuse(stmt, false);
-    stmt = rewriter(std::move(stmt));
+  if (merge_static_smem) {
+    // Likewise skip the static rewriter when there is at most one static shmem alloc.
+    AllocateCollector static_probe(/*is_dynamic=*/false);
+    static_probe(stmt);
+    if (static_probe.shmem_allocs_.size() > 1) {
+      SharedMemoryRewriter static_rewriter(/*is_dynamic=*/false);
+      stmt = static_rewriter(std::move(stmt));
+    }
   }
   return stmt;
 }
diff --git a/src/tirx/transform/lower_device_kernel_launch.cc 
b/src/tirx/transform/lower_device_kernel_launch.cc
index 9b38c4d629..af30af6bfb 100644
--- a/src/tirx/transform/lower_device_kernel_launch.cc
+++ b/src/tirx/transform/lower_device_kernel_launch.cc
@@ -213,6 +213,21 @@ class DeviceKernelMutator : public StmtExprMutator {
     auto it = device_info_map_.find(gvar.get());
     TVM_FFI_ICHECK(it != device_info_map_.end());
     current_target_ = it->second.target;
+    // Track whether the caller is a host function (i.e. its target
+    // still has a host attached) and capture its host target.  The
+    // same-target shortcut at the call site is only safe when caller
+    // and callee are both device-resident; a host caller must take
+    // the kernel-launch path even if Target::WithoutHost() makes the
+    // strings match.  Conversely, a host caller invoking another host
+    // helper (e.g. a same-target subroutine that SplitHostDevice
+    // emitted on the host side) should compare against the host
+    // target, not the device target stripped by WithoutHost().
+    auto full_target = func->GetAttr<Target>(tvm::attr::kTarget).value();
+    if (full_target->GetHost().defined()) {
+      current_caller_host_target_ = full_target->GetHost().value();
+    } else {
+      current_caller_host_target_ = std::nullopt;
+    }
 
     auto body = VisitStmt(func->body);
     if (!body.same_as(func->body)) {
@@ -220,6 +235,7 @@ class DeviceKernelMutator : public StmtExprMutator {
     }
 
     current_target_ = std::nullopt;
+    current_caller_host_target_ = std::nullopt;
     return func;
   }
 
@@ -272,29 +288,59 @@ class DeviceKernelMutator : public StmtExprMutator {
         << gvar->name_hint << " did not appear within the IRModule";
     const KernelInfo& dev_info = it->second;
 
-    auto caller_target = current_target_.value();
     auto callee_target = dev_info.target;
 
-    bool same_target = caller_target->str() == callee_target->str();
-    if (same_target) {
-      // Calls within the same target may be handled at codegen time
-      // as internal subroutine calls.
-      return node;
-    }
+    // A callee with non-empty launch_params has thread_extent
+    // bindings in its body, i.e. it is a real device kernel that
+    // must be invoked via a kernel-launch ABI.  Conversely a callee
+    // with empty launch_params is a plain subroutine (host helper
+    // or intra-device helper) and is never invoked via kernel launch.
+    bool callee_is_kernel = dev_info.launch_params.size() > 0;
+    bool caller_is_host = current_caller_host_target_.has_value();
+
+    // For host callers, comparisons against the callee target must
+    // use the caller's *host* target, not the device target stripped
+    // by WithoutHost().  This handles two cases that the device-side
+    // comparison gets wrong:
+    //   1. A host caller invoking a real device kernel whose
+    //      WithoutHost() target happens to match (e.g. kernel target
+    //      "cuda" matches "cuda+host=c" after stripping host).  Must
+    //      go through kernel launch, not the same-target shortcut.
+    //   2. A host caller invoking another host helper with a
+    //      different host target (e.g. SplitHostDevice emits an
+    //      "add_host" with target "c" while the host body still
+    //      carries "cuda+host=c").  Must go through call_extern (or
+    //      same-target subroutine), not kernel launch.
+    auto caller_target =
+        caller_is_host ? current_caller_host_target_.value() : 
current_target_.value();
+
+    // A host caller invoking a real device kernel must always go
+    // through the kernel-launch ABI, regardless of any same-target /
+    // same-device-type coincidence.
+    bool force_kernel_launch = callee_is_kernel && caller_is_host;
+
+    if (!force_kernel_launch) {
+      bool same_target = caller_target->str() == callee_target->str();
+      if (same_target) {
+        // Calls within the same target may be handled at codegen time
+        // as internal subroutine calls.
+        return node;
+      }
 
-    bool same_device_type =
-        caller_target->GetTargetDeviceType() == 
callee_target->GetTargetDeviceType();
-    if (same_device_type) {
-      // Calls to another target using the same device (e.g. LLVM
-      // calling a custom TIRToRuntime target) do not require a kernel
-      // launch, but need to be replaced with call_extern.
-      extern_function_call_.insert(gvar);
-      ffi::Array<PrimExpr> args;
-      args.push_back(StringImm(gvar->name_hint));
-      for (const auto& arg : node->args) {
-        args.push_back(arg);
+      bool same_device_type =
+          caller_target->GetTargetDeviceType() == 
callee_target->GetTargetDeviceType();
+      if (same_device_type) {
+        // Calls to another target using the same device (e.g. LLVM
+        // calling a custom TIRToRuntime target) do not require a kernel
+        // launch, but need to be replaced with call_extern.
+        extern_function_call_.insert(gvar);
+        ffi::Array<PrimExpr> args;
+        args.push_back(StringImm(gvar->name_hint));
+        for (const auto& arg : node->args) {
+          args.push_back(arg);
+        }
+        return Call(node->dtype, builtin::call_extern(), args);
       }
-      return Call(node->dtype, builtin::call_extern(), args);
     }
 
     TVM_FFI_ICHECK(dev_info.launch_params.defined())
@@ -336,6 +382,13 @@ class DeviceKernelMutator : public StmtExprMutator {
   }
 
   ffi::Optional<Target> current_target_;
+  // The host target of the caller currently being rewritten, if the
+  // caller is a host function (its kTarget has a host attached).
+  // Used both to detect that the caller is a host function and to
+  // compare against the callee target on the host side, so that
+  // host-to-host subroutine calls are not misrouted through the
+  // device kernel-launch ABI.
+  ffi::Optional<Target> current_caller_host_target_;
   std::unordered_map<const GlobalVarNode*, KernelInfo> device_info_map_;
   std::unordered_set<const GlobalVarNode*> device_kernel_launch_;
   std::unordered_set<const GlobalVarNode*> extern_function_call_;
diff --git 
a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
 
b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
index ca7d1de7c4..b09c1fd796 100644
--- 
a/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
+++ 
b/tests/python/s_tir/transform/test_s_tir_transform_merge_dynamic_shared_memory_allocations.py
@@ -254,23 +254,100 @@ def test_async_copy():
     class Before:
         @T.prim_func(s_tir=True)
         def main(A: T.Buffer((128,), "float32"), B: T.Buffer((128,), 
"float32")):
+            threadIdx_x = T.launch_thread("threadIdx.x", 128)
             A_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
             B_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
-            threadIdx_x = T.launch_thread("threadIdx.x", 128)
             T.ptx.cp_async("float32", A_sh.data, threadIdx_x, A.data, 
threadIdx_x, 512)
             T.ptx.cp_async("float32", B_sh.data, threadIdx_x, B.data, 
threadIdx_x, 512)
 
     After = transform(Before)
-    # The pass merges shared.dyn allocations but DeclBuffer nodes from the 
original
-    # allocations remain with remapped data vars. The output can't be precisely
-    # represented in TVMScript due to same-name var constraints, so we verify
-    # key properties instead of exact structural equality.
+    # The pass merges shared.dyn allocations. A_sh and B_sh are accessed
+    # sequentially inside the thread_extent with non-overlapping lifetimes,
+    # so the liveness analysis allows reuse — both fit in 512 bytes
+    # (= 128 elements * 4 bytes).
     script = After["main"].script()
-    # Verify merged allocation (1024 bytes = 128*4 + 128*4)
-    assert '"uint8"' in script and '"shared.dyn"' in script and "(1024,)" in 
script
-    # Verify cp_async uses correct byte offsets
+    # Verify merged allocation (512 bytes - A_sh and B_sh can be reused)
+    assert '"uint8"' in script and '"shared.dyn"' in script and "(512,)" in 
script
+    # Verify cp_async uses the merged buffer
+    assert "buf_dyn_shmem" in script
     assert "threadIdx_x * 4" in script
-    assert "(128 + threadIdx_x) * 4" in script
+
+
+def test_multi_thread_extent_blocks():
+    """Each thread_extent block must get its own merged buffer.
+
+    Reproduces the scoping bug from PR #19605: a single PrimFunc
+    with two sibling thread_extent regions, each containing its
+    own shared.dyn allocations. The merged buffer must be allocated
+    inside each kernel body — not just the first.
+    """
+    transform = tvm.s_tir.transform.MergeSharedMemoryAllocations()
+
+    @I.ir_module(check_well_formed=False)
+    class Before:
+        @T.prim_func(s_tir=True, check_well_formed=False)
+        def main(
+            X: T.Buffer((128,), "float32"),
+            Y: T.Buffer((128,), "float32"),
+        ):
+            X_flat = T.decl_buffer(128, data=X.data)
+            Y_flat = T.decl_buffer(128, data=Y.data)
+
+            # First kernel launch
+            tx0 = T.env_thread("threadIdx.x")
+            with T.attr(tx0, "thread_extent", 128):
+                A_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
+                B_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
+                A_sh[tx0] = X_flat[tx0]
+                B_sh[tx0] = A_sh[tx0]
+                X_flat[tx0] = B_sh[tx0]
+
+            # Second kernel launch — must NOT see kernel #0's merged buffer.
+            tx1 = T.env_thread("threadIdx.x")
+            with T.attr(tx1, "thread_extent", 128):
+                C_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
+                D_sh = T.alloc_buffer((128,), "float32", scope="shared.dyn")
+                C_sh[tx1] = Y_flat[tx1]
+                D_sh[tx1] = C_sh[tx1]
+                Y_flat[tx1] = D_sh[tx1]
+
+    After = transform(Before)
+    script = After["main"].script()
+
+    # Two merged allocations — one per thread_extent body.
+    # Each of the four original 128-float32 buffers (A_sh, B_sh, C_sh, D_sh)
+    # gets merged within its own kernel scope.
+    assert script.count("shared.dyn") >= 2, (
+        "Expected at least two shared.dyn allocations (one per kernel)"
+    )
+    assert script.count("alloc_buffer") >= 2, (
+        "Expected at least two alloc_buffer nodes (one merged buf per kernel)"
+    )
+
+    # Both thread_extent blocks must contain their own merged buffer —
+    # they must NOT share the same buf_dyn_shmem variable.
+    # Structurally verify that the first kernel's body accesses are
+    # not rewritten to the second kernel's buf_dyn_shmem (and vice versa).
+    first_block = script.split("with T.attr(tx1")[0]
+    second_block = script.split("with T.attr(tx1")[1] if "tx1" in script else 
""
+    assert "buf_dyn_shmem" in first_block, "Kernel 1 must have a merged buffer"
+    if second_block:
+        assert "buf_dyn_shmem" in second_block, "Kernel 2 must have a merged 
buffer"
+
+    # End-to-end: post-merge IR must remain well-formed through
+    # the host/device split — this is the exact ordering from
+    # PR #19605 that triggers the scoping bug.
+    target = tvm.target.Target("llvm")
+    mod_with_target = tvm.IRModule({"main": After["main"].with_attr({"target": 
target})})
+    split = tvm.transform.Sequential(
+        [
+            tvm.tirx.transform.AnnotateDeviceRegions(),
+            tvm.tirx.transform.SplitHostDevice(),
+        ]
+    )
+    # If kernel #1 referenced an undefined buf_dyn_shmem, this
+    # would raise during well-formedness checking inside SplitHostDevice.
+    split(mod_with_target)
 
 
 if __name__ == "__main__":


Reply via email to