This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 9d13fc04d2 [S-TIR] Fix Segfault when applying Parallel during TIR
schedule rewriting (#19403)
9d13fc04d2 is described below
commit 9d13fc04d2cd0429b2c4fc2bd0b4785d2cfce904
Author: Neo Chien <[email protected]>
AuthorDate: Thu Apr 16 08:18:36 2026 +0800
[S-TIR] Fix Segfault when applying Parallel during TIR schedule rewriting
(#19403)
Hi Committers,
This PR aims to fix the issue
https://github.com/apache/tvm/issues/18424. Any suggestions would be
appreciated.
### Root Cause
Unsafe dynamic-shape dereferences in `AdjustParallelVectorize`: the code
assumed that buffer shapes and loop extents were `IntImm` and dereferenced
the result of `as<IntImmNode>()` directly. With dynamic shapes,
`as<IntImmNode>()` can return null, which segfaults before any try/catch
can handle it.
### Solution
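Replace the unsafe `IntImm` assumptions with null checks and
`GetLoopIntExtent(...)`. If a buffer shape is not a constant, contiguous-access
analysis is skipped and `max_fusible` is conservatively set to 0; if a loop
extent is not a constant, the fusible-loop scan stops instead of
dereferencing null.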
---------
Co-authored-by: cchung100m <[email protected]>
---
.../postproc/rewrite_parallel_vectorize_unroll.cc | 23 +++++++++++++++++----
...e_postproc_rewrite_parallel_vectorize_unroll.py | 24 ++++++++++++++++++++++
2 files changed, 43 insertions(+), 4 deletions(-)
diff --git
a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
index ebaa58660e..d05c9a32cb 100644
--- a/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
+++ b/src/s_tir/meta_schedule/postproc/rewrite_parallel_vectorize_unroll.cc
@@ -213,6 +213,7 @@ void AdjustParallelVectorize(const Schedule& sch, const
SBlockRV& block_rv,
// (vectorizable) axes
for (const BufferRegion& access : buffer_access) {
int fusible = 0;
+ bool can_analyze_contiguous_access = true;
std::vector<int64_t> strides;
// get strides for each loop var
for (const StmtSRef& loop_sref : loop_srefs) {
@@ -226,10 +227,22 @@ void AdjustParallelVectorize(const Schedule& sch, const
SBlockRV& block_rv,
stride = coef * buffer_stride;
break;
}
- buffer_stride *= access->buffer->shape[i].as<IntImmNode>()->value;
+ const auto* shape = access->buffer->shape[i].as<IntImmNode>();
+ if (shape == nullptr) {
+ can_analyze_contiguous_access = false;
+ break;
+ }
+ buffer_stride *= shape->value;
+ }
+ if (!can_analyze_contiguous_access) {
+ break;
}
strides.push_back(stride);
}
+ if (!can_analyze_contiguous_access) {
+ max_fusible = 0;
+ break;
+ }
int prev_used_iter = -1;
// check the number of fusible loops
for (int i = strides.size() - 1; i >= 0; i--) {
@@ -246,9 +259,11 @@ void AdjustParallelVectorize(const Schedule& sch, const
SBlockRV& block_rv,
prev_used_iter = i;
} else {
// contiguous memory access
- const auto* prev_loop = loop_srefs[prev_used_iter]->StmtAs<ForNode>();
- int64_t prev_used_iter_extent =
prev_loop->extent.as<IntImmNode>()->value;
- if (strides[i] == strides[prev_used_iter] * prev_used_iter_extent) {
+ const int64_t* prev_used_iter_extent =
GetLoopIntExtent(loop_srefs[prev_used_iter]);
+ if (prev_used_iter_extent == nullptr) {
+ break;
+ }
+ if (strides[i] == strides[prev_used_iter] * (*prev_used_iter_extent)) {
fusible++;
prev_used_iter = i;
} else {
diff --git
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
index f70f16ea6c..e7baabb1e6 100644
---
a/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
+++
b/tests/python/s_tir/meta_schedule/test_meta_schedule_postproc_rewrite_parallel_vectorize_unroll.py
@@ -181,6 +181,24 @@ def after_postproc_add(
add_compute[v0, v1, v2, v3, v4] = lhs[v0, v1, v2, v3, v4]
+ rhs[v0, v1, v2, v3, v4]
+@T.prim_func
+def before_postproc_dynamic_shape_vectorize(
+ a: T.handle,
+ b: T.handle,
+) -> None:
+ n = T.int64()
+ A = T.match_buffer(a, (n,), dtype="float32")
+ B = T.match_buffer(b, (n,), dtype="float32")
+ with T.block("root"):
+ T.block_attr({"meta_schedule.vectorize": 64})
+ for i in T.serial(0, n):
+ with T.block("copy"):
+ vi = T.axis.spatial(n, i)
+ T.reads(A[vi])
+ T.writes(B[vi])
+ B[vi] = A[vi]
+
+
# fmt: on
# pylint:
enable=invalid-name,no-member,line-too-long,too-many-nested-blocks,no-self-argument,not-callable
@@ -269,5 +287,11 @@ def test_no_unroll_for_spatial_block():
assert_structural_equal_ignore_global_symbol(mod["main"], expected)
+def test_rewrite_parallel_vectorize_unroll_dynamic_shape_no_crash():
+ sch = Schedule(before_postproc_dynamic_shape_vectorize)
+ rule = RewriteParallelVectorizeUnroll()
+ assert rule.apply(sch)
+
+
if __name__ == "__main__":
tvm.testing.main()