This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9d13fc04d2 [S-TIR] Fix Segfault when applying Parallel during TIR 
schedule rewriting (#19403)
9d13fc04d2 is described below

commit 9d13fc04d2cd0429b2c4fc2bd0b4785d2cfce904
Author: Neo Chien <[email protected]>
AuthorDate: Thu Apr 16 08:18:36 2026 +0800

    [S-TIR] Fix Segfault when applying Parallel during TIR schedule rewriting 
(#19403)
    
    Hi Committers,
    
    This PR attempts to fix issue
    https://github.com/apache/tvm/issues/18424. Any suggestions would be
    appreciated.
    
    ### Root Cause
    Unsafe dynamic-shape dereferences in `AdjustParallelVectorize`: the code
    assumed that buffer shapes and loop extents are always `IntImm` and
    dereferenced the result of `as<IntImmNode>()` directly. With dynamic
    shapes, `as<IntImmNode>()` can return null, which can segfault before
    any try/catch handles it.
    
    ### Solution
    Replaced the unsafe `IntImm` assumptions with null checks and
    `GetLoopIntExtent(...)`; if contiguous-access analysis is not possible,
    the code conservatively disables that path instead of dereferencing null.
    
    ---------
    
    Co-authored-by: cchung100m <[email protected]>
---
 .../postproc/rewrite_parallel_vectorize_unroll.cc  | 23 +++++++++++++++++----
 ...e_postproc_rewrite_parallel_vectorize_unroll.py | 24 ++++++++++++++++++++++
 2 files changed, 43 insertions(+), 4 deletions(-)

diff --git 
a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc 
b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index ebaa58660e..d05c9a32cb 100644
--- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -213,6 +213,7 @@ void AdjustParallelVectorize(const Schedule& sch, const 
SBlockRV& block_rv,
   // (vectorizable) axes
   for (const BufferRegion& access : buffer_access) {
     int fusible = 0;
+    bool can_analyze_contiguous_access = true;
     std::vector<int64_t> strides;
     // get strides for each loop var
     for (const StmtSRef& loop_sref : loop_srefs) {
@@ -226,10 +227,22 @@ void AdjustParallelVectorize(const Schedule& sch, const 
SBlockRV& block_rv,
           stride = coef * buffer_stride;
           break;
         }
-        buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value;
+        const auto* shape = access->buffer->shape[i].as<IntImmNode>();
+        if (shape == nullptr) {
+          can_analyze_contiguous_access = false;
+          break;
+        }
+        buffer_stride *= shape->value;
+      }
+      if (!can_analyze_contiguous_access) {
+        break;
       }
       strides.push_back(stride);
     }
+    if (!can_analyze_contiguous_access) {
+      max_fusible = 0;
+      break;
+    }
     int prev_used_iter = -1;
     // check the number of fusible loops
     for (int i = strides.size() - 1; i >= 0; i--) {
@@ -246,9 +259,11 @@ void AdjustParallelVectorize(const Schedule& sch, const 
SBlockRV& block_rv,
         prev_used_iter = i;
       } else {
         // contiguous memory access
-        const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>();
-        int64_t prev_used_iter_extent = 
prev_loop->extent.as<IntImmNode>()->value;
-        if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) {
+        const int64_t* prev_used_iter_extent = 
GetLoopIntExtent(loop_srefs[prev_used_iter]);
+        if (prev_used_iter_extent == nullptr) {
+          break;
+        }
+        if (strides[i] == strides[prev_used_iter] * (*prev_used_iter_extent)) {
           fusible++;
           prev_used_iter = i;
         } else {
diff --git 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index f70f16ea6c..e7baabb1e6 100644
--- 
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++ 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -181,6 +181,24 @@ def after_postproc_add(
                     add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4] 
+ rhs[v0, v1, v2, v3, v4]
 
 
+@T.prim_func
+def before_postproc_dynamic_shape_vectorize(
+    a: T.handle,
+    b: T.handle,
+) -> None:
+    n = T.int64()
+    A = T.match_buffer(a, (n,), dtype="float32")
+    B = T.match_buffer(b, (n,), dtype="float32")
+    with T.block("root"):
+        T.block_attr({"meta_schedule.vectorize": 64})
+        for i in T.serial(0, n):
+            with T.block("copy"):
+                vi = T.axis.spatial(n, i)
+                T.reads(A[vi])
+                T.writes(B[vi])
+                B[vi] = A[vi]
+
+
 # fmt: on
 # pylint: 
enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable
 
@@ -269,5 +287,11 @@ def test_no_unroll_for_spatial_block():
     assert_structural_equal_ignore_global_symbol(mod["main"], expected)
 
 
+def test_rewrite_parallel_vectorize_unroll_dynamic_shape_no_crash():
+    sch = Schedule(before_postproc_dynamic_shape_vectorize)
+    rule = RewriteParallelVectorizeUnroll()
+    assert rule.apply(sch)
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to