This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 0a79095b1d [S-TIR] Fix cache_read/cache_write region when inner block has T.where predicate (#19406)
0a79095b1d is described below
commit 0a79095b1dc6861ac2a162e8fe7b9c9d09e0a435
Author: Andrey Malyshev <[email protected]>
AuthorDate: Thu Apr 16 08:07:53 2026 +0300
[S-TIR] Fix cache_read/cache_write region when inner block has T.where predicate
(#19406)
When the actual buffer access is gated by T.where on a nested (inner)
sblock, the outer block's own predicate is trivially true. Both
cache_write and cache_read computed cache regions from that outer
predicate alone, producing allocations as large as the full loop
extent instead of the guarded region.
Fix:
- Add CollectNestedBlockPredicates(), a single helper parameterised by
BufferIndexType (kRead / kWrite) that walks the outer block's body,
finds nested sblocks accessing the target buffer, rewrites their
predicates into the outer scope by substituting iter-var bindings, and
OR-combines them so the cache covers the union of the blocks' accesses.
- Add extra_predicate parameter to RelaxBufferRegion() and AND it with
the block's own predicate before region relaxation.
- cache_write: pass the collected nested-write predicate to
RelaxBufferRegion so the cache allocation is tightened.
- cache_read (Case 2 — input buffer): when a non-trivial nested-read
predicate exists, relax the consumer block's declared read region under
that predicate; otherwise fall back to the original scope_block->reads
path (preserves non-int32 dtypes in extents).
---
src/s_tir/schedule/primitive/cache_read_write.cc | 111 +++++++++++++-
.../schedule/test_tir_schedule_cache_read_write.py | 170 +++++++++++++++++++++
2 files changed, 273 insertions(+), 8 deletions(-)
diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc
b/src/s_tir/schedule/primitive/cache_read_write.cc
index baa80b38a9..b077386be1 100644
--- a/src/s_tir/schedule/primitive/cache_read_write.cc
+++ b/src/s_tir/schedule/primitive/cache_read_write.cc
@@ -542,6 +542,71 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer
buffer, StmtSRef scope_sre
return true;
}
+/*!
+ * \brief Collect the predicate under which nested SBlockRealize nodes in \p body access
+ *        \p buffer, expressed in the scope of the enclosing (outer) block.
+ *
+ * A nested block "accesses" the buffer when its reads (\p index_type == kRead) or writes
+ * (\p index_type == kWrite) mention it. For every accessing block that is not itself inside
+ * another accessing nested block, its realize predicate is rewritten into the outer block's
+ * scope by substituting the iter-var bindings of the block and of every enclosing nested
+ * block. The predicates are OR-combined: sibling blocks are independent access paths, so the
+ * cache must cover the union of their access regions, not the intersection.
+ *
+ * This is needed when the actual access is gated by a predicate (T.where) on a nested block
+ * while the outer block's own predicate is trivially true.
+ *
+ * If \p body also accesses the buffer directly, i.e. through a BufferLoad (kRead) or a
+ * BufferStore (kWrite) that is not inside any nested block, that access is not gated by any
+ * nested predicate. No restriction can be derived, so `true` is returned.
+ *
+ * NOTE(review): only \p body is scanned. Accesses made by the outer block's `init`, or through
+ * the buffer's data pointer instead of BufferLoad/BufferStore, are not detected here. A nested
+ * predicate may also reference loop vars defined inside \p body, and those stay free in the
+ * result. Confirm that callers / RelaxBufferRegion handle these cases conservatively.
+ *
+ * \param body The body statement of the outer block to search within.
+ * \param buffer The buffer being accessed.
+ * \param index_type Whether to look for reads (kRead) or writes (kWrite).
+ * \return The OR-combination of the collected predicates, or `true` (no restriction) when no
+ *         nested block accesses the buffer or when the buffer is also accessed directly.
+ */
+static PrimExpr CollectNestedBlockPredicates(const Stmt& body, const Buffer& buffer,
+                                             BufferIndexType index_type) {
+  struct Collector : public StmtExprVisitor {
+    Collector(const Buffer& buf, BufferIndexType idx_type)
+        : buffer_(buf), index_type_(idx_type), result_(Bool(false)) {}
+
+    void VisitStmt_(const SBlockRealizeNode* realize) final {
+      const SBlockNode* block = realize->block.get();
+      // This block's iter vars mapped to their bindings. The bindings are rewritten into the
+      // outer block's scope with the substitution accumulated from enclosing nested blocks.
+      ffi::Map<Var, PrimExpr> block_subst = subst_;
+      for (size_t i = 0; i < block->iter_vars.size(); ++i) {
+        block_subst.Set(block->iter_vars[i]->var, Substitute(realize->iter_values[i], subst_));
+      }
+      if (AccessesBuffer(block)) {
+        // Nothing inside this block executes unless its predicate holds, so the predicate
+        // alone bounds every access in the subtree. Descending further would only OR in
+        // weaker terms.
+        PrimExpr pred = Substitute(realize->predicate, block_subst);
+        result_ = found_ ? (result_ || pred) : pred;
+        found_ = true;
+        return;
+      }
+      // The block does not declare the buffer itself. Search deeper, with its iter vars
+      // added to the substitution so deeper predicates land in the outer scope.
+      ffi::Map<Var, PrimExpr> saved_subst = subst_;
+      subst_ = block_subst;
+      ++depth_;
+      StmtExprVisitor::VisitStmt_(realize);
+      --depth_;
+      subst_ = saved_subst;
+    }
+
+    void VisitExpr_(const BufferLoadNode* load) final {
+      // A read issued directly in the outer body is not gated by any nested predicate.
+      if (depth_ == 0 && index_type_ == BufferIndexType::kRead &&
+          load->buffer.same_as(buffer_)) {
+        direct_access_ = true;
+      }
+      StmtExprVisitor::VisitExpr_(load);
+    }
+
+    void VisitStmt_(const BufferStoreNode* store) final {
+      // A write issued directly in the outer body is not gated by any nested predicate.
+      if (depth_ == 0 && index_type_ == BufferIndexType::kWrite &&
+          store->buffer.same_as(buffer_)) {
+        direct_access_ = true;
+      }
+      StmtExprVisitor::VisitStmt_(store);
+    }
+
+    /*! \brief Whether \p block declares an access of the requested kind to the buffer. */
+    bool AccessesBuffer(const SBlockNode* block) const {
+      const auto& regions =
+          (index_type_ == BufferIndexType::kRead) ? block->reads : block->writes;
+      for (const BufferRegion& region : regions) {
+        if (region->buffer.same_as(buffer_)) {
+          return true;
+        }
+      }
+      return false;
+    }
+
+    const Buffer& buffer_;
+    BufferIndexType index_type_;
+    /*! \brief Maps iter vars of the enclosing nested blocks to outer-scope expressions. */
+    ffi::Map<Var, PrimExpr> subst_;
+    /*! \brief Number of nested blocks entered; 0 means directly in the outer block's body. */
+    int depth_ = 0;
+    /*! \brief OR-combination of the predicates collected so far. */
+    PrimExpr result_;
+    /*! \brief Whether any nested block accessing the buffer was found. */
+    bool found_ = false;
+    /*! \brief Whether the buffer is accessed outside every nested block. */
+    bool direct_access_ = false;
+  };
+
+  Collector collector(buffer, index_type);
+  collector(body);
+  // With no nested access, or with an access that bypasses the nested blocks, return true
+  // (no restriction). The caller then falls back to the original scope-block reads /
+  // FullRegion path.
+  if (!collector.found_ || collector.direct_access_) {
+    return Bool(true);
+  }
+  return collector.result_;
+}
+
/*!
* \brief Get the buffer region under the sref tree path [dom_low_inclusive,
dom_high_exclusive)
* \param self The state of the schedule.
@@ -549,11 +614,14 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer
buffer, StmtSRef scope_sre
* \param block_sref The sref of the block related to the region.
* \param dom_low_inclusive The lowest node in the sref tree path.
* \param dom_high_exclusive The highest node in the sref tree path.
+ * \param extra_predicate An additional predicate (e.g. collected from nested
blocks) to AND
+ * with the block's own predicate before relaxation. Defaults to true
(no effect).
* \return The relaxed buffer region.
*/
BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion&
buffer_region,
const StmtSRef& block_sref, const StmtSRef&
dom_low_inclusive,
- const StmtSRef& dom_high_exclusive) {
+ const StmtSRef& dom_high_exclusive,
+ PrimExpr extra_predicate = Bool(true)) {
SBlockRealize realize = GetSBlockRealize(self, block_sref);
ffi::Map<Var, PrimExpr> binding = GetBindings(realize);
const Buffer& buffer = buffer_region->buffer;
@@ -561,7 +629,7 @@ BufferRegion RelaxBufferRegion(ScheduleState self, const
BufferRegion& buffer_re
BufferRegion subst_region = BufferRegion(buffer,
Substitute(buffer_region->region, binding));
ffi::Array<arith::IntSet> int_sets = AnalyzeRegionUpperBound(
/*region=*/subst_region,
- /*predicate=*/realize->predicate,
+ /*predicate=*/Substitute(realize->predicate && extra_predicate, binding),
/*dom_low_inclusive=*/dom_low_inclusive,
/*dom_high_exclusive=*/dom_high_exclusive,
/*analyzer=*/&analyzer);
@@ -1703,9 +1771,25 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef&
block_sref, int read_buff
// Case 2. The buffer is the input block for the scope.
info.loc_sref = scope_sref;
info.loc_pos = 0;
- if (ffi::Optional<BufferRegion> region =
- GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
- cache_region = region.value();
+ // When a nested block gates the actual read with T.where, the consumer
block's own
+ // predicate is trivially true, so the scope-block read annotation covers
the full loop
+ // range. Collect nested-read predicates and, if any are non-trivial,
relax the consumer
+ // block's read region under that predicate to get a tighter cache
allocation.
+ // Without a nested predicate we fall back to scope_block->reads (which
preserves the
+ // original buffer's dtype in its extents, e.g. int64 shapes).
+ ffi::Optional<BufferRegion> read_region_opt =
+ GetBufferRegionFromBuffer(block->reads, read_buffer);
+ PrimExpr nested_pred =
+ read_region_opt
+ ? CollectNestedBlockPredicates(block->body, read_buffer,
BufferIndexType::kRead)
+ : Bool(true);
+ if (read_region_opt && !is_one(nested_pred) && block_sref->parent !=
nullptr) {
+ StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
+ cache_region = RelaxBufferRegion(self, read_region_opt.value(),
block_sref, parent_sref,
+ scope_sref, nested_pred);
+ } else if (ffi::Optional<BufferRegion> scope_region =
+ GetBufferRegionFromBuffer(scope_block->reads, read_buffer))
{
+ cache_region = scope_region.value();
} else {
cache_region = BufferRegion::FullRegion(read_buffer);
}
@@ -1782,11 +1866,22 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef&
block_sref, int write_bu
// Step 4. Find the producing region and insert position
BufferRegion region = GetBufferRegionFromBuffer(block->writes,
write_buffer).value();
- StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
// Detect insert position
CacheLocDetector::Detect</*is_cache_read=*/false>(self, block_sref,
scope_sref, &info);
- BufferRegion cache_region =
- RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
+ // Collect predicates from any nested blocks that gate the actual write
(e.g. T.where on an
+ // inner block). The outer block's own predicate may be trivially true even
though the write
+ // is restricted by a nested predicate, so we OR them together for a tighter
region estimate.
+ PrimExpr nested_write_pred =
+ CollectNestedBlockPredicates(block->body, write_buffer,
BufferIndexType::kWrite);
+ BufferRegion cache_region;
+ if (block_sref->parent != nullptr) {
+ StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
+ cache_region =
+ RelaxBufferRegion(self, region, block_sref, parent_sref,
info.loc_sref, nested_write_pred);
+ } else {
+ // Root block: no enclosing loops to relax over, use the write region
directly.
+ cache_region = region;
+ }
bool cache_full_region = info.loc_sref->StmtAs<SBlockNode>() == nullptr ||
!AllConsumersUnderStmt(self, write_buffer,
scope_sref, info.loc_sref);
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py
b/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py
index 5da2c56535..8877044437 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py
@@ -1670,5 +1670,175 @@ def
test_symbolic_matmul_blocked_cache_write(use_block_name):
verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
+def test_cache_write_with_nested_block_predicate():
+ """cache_write must size the cache by the T.where on the nested "inner" block.
+
+ The outer "compute" block iterates over (12, 24) and has no predicate of its own;
+ the only write to C_buf happens inside "inner" under T.where(vi < 10 and vj < 20).
+ The expected cache buffer and its write-back loop are therefore (10, 20).
+ """
+ @T.prim_func
+ def main(A: T.handle, C: T.handle) -> None:
+ A_buf = T.match_buffer(A, (12, 24), "float32")
+ C_buf = T.match_buffer(C, (10, 20), "float32")
+
+ for i, j in T.grid(12, 24):
+ with T.sblock("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+
+ with T.sblock("inner"):
+ T.where(vi < 10 and vj < 20)
+ C_buf[vi, vj] = A_buf[vi, vj] * 2.0
+
+ # The cache covers only the guarded (10, 20) region, not the (12, 24) loop extent.
+ @T.prim_func
+ def expected(A_buf: T.Buffer((12, 24), "float32"), C_buf: T.Buffer((10,
20), "float32")):
+ with T.sblock("root"):
+ C_buf_local = T.sblock_alloc_buffer((10, 20), scope="local")
+ for i, j in T.grid(12, 24):
+ with T.sblock("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A_buf[vi, vj])
+ T.writes(C_buf_local[vi, vj])
+ with T.sblock("inner"):
+ T.where(vi < 10 and vj < 20)
+ T.reads(A_buf[vi, vj])
+ T.writes(C_buf_local[vi, vj])
+ C_buf_local[vi, vj] = A_buf[vi, vj] * T.float32(2)
+ for ax0, ax1 in T.grid(10, 20):
+ with T.sblock("C_buf_local"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(C_buf_local[v0, v1])
+ T.writes(C_buf[v0, v1])
+ C_buf[v0, v1] = C_buf_local[v0, v1]
+
+ sch = tvm.s_tir.Schedule(main)
+ block = sch.get_sblock("compute")
+ sch.cache_write(block, 0, "local")
+ assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
+
+
+def test_cache_read_with_nested_block_predicate():
+ """cache_read must size the cache by the T.where on the nested "inner" block.
+
+ The outer "compute" block iterates over (12, 24) and has no predicate of its own;
+ the only read of A_buf happens inside "inner" under T.where(vi < 10 and vj < 20).
+ The expected cache buffer and its load loop are therefore (10, 20).
+ """
+ @T.prim_func
+ def main(A: T.handle, C: T.handle) -> None:
+ A_buf = T.match_buffer(A, (12, 24), "float32")
+ C_buf = T.match_buffer(C, (10, 20), "float32")
+
+ for i, j in T.grid(12, 24):
+ with T.sblock("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+
+ with T.sblock("inner"):
+ T.where(vi < 10 and vj < 20)
+ C_buf[vi, vj] = A_buf[vi, vj] * 2.0
+
+ # Only the guarded (10, 20) region of A_buf is copied into the cache.
+ @T.prim_func
+ def expected(A_buf: T.Buffer((12, 24), "float32"), C_buf: T.Buffer((10,
20), "float32")):
+ with T.sblock("root"):
+ A_buf_local = T.sblock_alloc_buffer((10, 20), scope="local")
+ for ax0, ax1 in T.grid(10, 20):
+ with T.sblock("A_buf_local"):
+ v0, v1 = T.axis.remap("SS", [ax0, ax1])
+ T.reads(A_buf[v0, v1])
+ T.writes(A_buf_local[v0, v1])
+ A_buf_local[v0, v1] = A_buf[v0, v1]
+ for i, j in T.grid(12, 24):
+ with T.sblock("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ T.reads(A_buf_local[vi, vj])
+ T.writes(C_buf[vi, vj])
+ with T.sblock("inner"):
+ T.where(vi < 10 and vj < 20)
+ T.reads(A_buf_local[vi, vj])
+ T.writes(C_buf[vi, vj])
+ C_buf[vi, vj] = A_buf_local[vi, vj] * T.float32(2)
+
+ sch = tvm.s_tir.Schedule(main)
+ block = sch.get_sblock("compute")
+ sch.cache_read(block, 0, "local")
+ assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
+
+
+def test_cache_write_sibling_nested_block_predicates_use_union():
+ """Regression: cache_write with sibling nested blocks must union their predicates.
+
+ Two sibling nested sblocks write the same buffer under *different* predicates:
+ left block: T.where(vi < 8) — writes rows 0-7, all columns
+ top block: T.where(vj < 16) — writes all rows, columns 0-15
+
+ The cache must cover the UNION of both write sets. The bounding box of that
+ union is (12, 24), the full buffer shape.
+
+ An earlier version of CollectNestedBlockPredicates AND-ed the predicates of all
+ nested blocks it found, giving (vi < 8) AND (vj < 16). Relaxing the write region
+ under that intersection yields the bounding box (8, 16), which is too small: the
+ "left" block would then write C_buf_local[vi, vj] for vi in [8, 12), outside the
+ cache allocation, producing incorrect output.
+ """
+
+ @T.prim_func
+ def main(A: T.handle, C: T.handle) -> None:
+ A_buf = T.match_buffer(A, (12, 24), "float32")
+ C_buf = T.match_buffer(C, (12, 24), "float32")
+ for i, j in T.grid(12, 24):
+ with T.sblock("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ with T.sblock("left"):
+ T.where(vi < 8)
+ C_buf[vi, vj] = A_buf[vi, vj] * 2.0
+ with T.sblock("top"):
+ T.where(vj < 16)
+ C_buf[vi, vj] = A_buf[vi, vj] * 3.0
+
+ sch = tvm.s_tir.Schedule(main)
+ block = sch.get_sblock("compute")
+ sch.cache_write(block, 0, "local")
+
+ # Check the cache allocation shape in the printed IR of the scheduled function.
+ result_script = sch.mod["main"].script()
+ # The cache must hold the union of both write regions, i.e. the full (12, 24).
+ # AND-ing the two predicates would have produced (8, 16).
+ assert "sblock_alloc_buffer((12, 24)" in result_script, (
+ f"Expected cache shape (12, 24) covering the union of both write
regions, "
+ f"but got a smaller shape. Full IR:\n{result_script}"
+ )
+
+
+def test_cache_read_sibling_nested_block_predicates_use_union():
+ """Regression: cache_read with sibling nested blocks must union their predicates.
+
+ Two sibling nested sblocks read the same input buffer under different predicates:
+ left block: T.where(vi < 8) — reads rows 0-7, all columns
+ top block: T.where(vj < 16) — reads all rows, columns 0-15
+
+ The cache must cover the UNION of both read sets. The bounding box of that
+ union is (12, 24), the full buffer shape.
+
+ An earlier version of CollectNestedBlockPredicates AND-ed the two predicates,
+ giving (vi < 8) AND (vj < 16). Case 2 of CacheRead relaxes the consumer's read
+ region under that predicate, which produced a cache of shape (8, 16). The "left"
+ block would then read A_buf_local[vi, vj] for vi in [8, 12), indices outside the
+ cache, which is incorrect.
+ """
+
+ @T.prim_func
+ def main(A: T.handle, C: T.handle) -> None:
+ A_buf = T.match_buffer(A, (12, 24), "float32")
+ C_buf = T.match_buffer(C, (12, 24), "float32")
+ for i, j in T.grid(12, 24):
+ with T.sblock("compute"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ with T.sblock("left"):
+ T.where(vi < 8)
+ C_buf[vi, vj] = A_buf[vi, vj] * 2.0
+ with T.sblock("top"):
+ T.where(vj < 16)
+ C_buf[vi, vj] = A_buf[vi, vj] * 3.0
+
+ sch = tvm.s_tir.Schedule(main)
+ block = sch.get_sblock("compute")
+ sch.cache_read(block, 0, "local")
+
+ result_script = sch.mod["main"].script()
+ # The cache must cover the union bounding box (12, 24); AND-ing gave (8, 16).
+ assert "sblock_alloc_buffer((12, 24)" in result_script, (
+ f"Expected cache shape (12, 24) covering the union of both read
regions, "
+ f"but got a smaller shape. Full IR:\n{result_script}"
+ )
+
+
if __name__ == "__main__":
tvm.testing.main()