This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 0a79095b1d  [S-TIR] Fix cache_read/cache_write region when inner block has T.where predicate (#19406)
0a79095b1d is described below

commit 0a79095b1dc6861ac2a162e8fe7b9c9d09e0a435
Author: Andrey Malyshev <[email protected]>
AuthorDate: Thu Apr 16 08:07:53 2026 +0300

    [S-TIR] Fix cache_read/cache_write region when inner block has T.where predicate (#19406)
    
    When the actual buffer access is gated by T.where on a nested (inner)
    sblock, the outer block's own predicate is trivially true. Both
    cache_write and cache_read were computing cache regions based only on
    that outer predicate, producing allocations as large as the full loop
    extent instead of the guarded region.
    
    Fix:
    - Add CollectNestedBlockPredicates(), a single helper parameterised by
    BufferIndexType (kRead / kWrite) that walks the outer block's body,
    finds nested sblocks accessing the target buffer, and OR-combines their
    predicates (so sibling accesses yield the union of their regions) after
    substituting iter-var bindings into the outer scope.
    - Add extra_predicate parameter to RelaxBufferRegion() and AND it with
    the block's own predicate before region relaxation.
    - cache_write: pass the collected nested-write predicate to
    RelaxBufferRegion so the cache allocation is tightened.
    - cache_read (Case 2 — input buffer): when a non-trivial nested-read
    predicate exists, relax the consumer block's declared read region under
    that predicate; otherwise fall back to the original scope_block->reads
    path (preserves non-int32 dtypes in extents).
---
 src/s_tir/schedule/primitive/cache_read_write.cc   | 111 +++++++++++++-
 .../schedule/test_tir_schedule_cache_read_write.py | 170 +++++++++++++++++++++
 2 files changed, 273 insertions(+), 8 deletions(-)

diff --git a/src/s_tir/schedule/primitive/cache_read_write.cc b/src/s_tir/schedule/primitive/cache_read_write.cc
index baa80b38a9..b077386be1 100644
--- a/src/s_tir/schedule/primitive/cache_read_write.cc
+++ b/src/s_tir/schedule/primitive/cache_read_write.cc
@@ -542,6 +542,71 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sre
   return true;
 }
 
+/*!
+ * \brief OR-combine the predicates of all nested SBlockRealize nodes in \p body whose
+ * declared reads (kRead) or writes (kWrite) include \p buffer. Each predicate has the nested
+ * block's own iter vars substituted by their bindings. This is needed when the actual access
+ * is gated by a predicate (T.where) on a nested block while the outer block's predicate is
+ * trivially true; OR-ing keeps the union of sibling accesses instead of their intersection.
+ * NOTE(review): buffer accesses in \p body that are not inside any nested block are ignored,
+ * and predicates are not rewritten through intermediate loops/blocks; confirm this is safe
+ * for bodies that mix direct and predicated nested accesses of \p buffer.
+ * \param body The body statement of the outer block to search within.
+ * \param buffer The buffer being accessed.
+ * \param index_type Whether to look for reads (kRead) or writes (kWrite).
+ * \return The OR of all matching nested block predicates, or true if none access \p buffer.
+ */
+static PrimExpr CollectNestedBlockPredicates(const Stmt& body, const Buffer& buffer,
+                                             BufferIndexType index_type) {
+  struct Collector : public StmtVisitor {
+    Collector(const Buffer& buf, BufferIndexType idx_type)
+        : buffer_(buf), index_type_(idx_type), result_(Bool(false)), found_(false) {}
+
+    void VisitStmt_(const SBlockRealizeNode* realize) final {
+      const SBlockNode* block = realize->block.get();
+      const auto& regions =
+          (index_type_ == BufferIndexType::kRead) ? block->reads : block->writes;
+      bool accesses_buffer = false;
+      for (const BufferRegion& region : regions) {
+        if (region->buffer.same_as(buffer_)) {
+          accesses_buffer = true;
+          break;
+        }
+      }
+      if (accesses_buffer) {
+        // Build substitution: nested block iter vars -> their binding values
+        // (expressed in the scope that directly encloses the nested block).
+        ffi::Map<Var, PrimExpr> subst;
+        for (size_t i = 0; i < block->iter_vars.size(); ++i) {
+          subst.Set(block->iter_vars[i]->var, realize->iter_values[i]);
+        }
+        PrimExpr pred =
+            subst.empty() ? realize->predicate : Substitute(realize->predicate, subst);
+        // OR the predicates across all accessing nested blocks: each such block is an
+        // independent alternative access path (sibling blocks in a SeqStmt), so the
+        // cache must cover the *union* of their access regions, not the intersection.
+        // Using AND (the previous behaviour) underestimates the required region when
+        // sibling blocks have non-overlapping predicates.
+        result_ = found_ ? (result_ || pred) : pred;
+        found_ = true;
+      }
+      // Continue recursing into deeper nested blocks.
+      StmtVisitor::VisitStmt_(realize);
+    }
+
+    const Buffer& buffer_;
+    BufferIndexType index_type_;
+    PrimExpr result_;
+    bool found_;
+  };
+
+  Collector collector(buffer, index_type);
+  collector(body);
+  // If no nested block accessed the buffer, return true (no extra restriction: CacheRead then
+  // uses scope-block reads / FullRegion, CacheWrite relaxes under the block predicate alone).
+  return collector.found_ ? collector.result_ : Bool(true);
+}
+
 /*!
  * \brief Get the buffer region under the sref tree path [dom_low_inclusive, dom_high_exclusive)
  * \param self The state of the schedule.
@@ -549,11 +614,14 @@ bool AllConsumersUnderStmt(ScheduleState self, Buffer buffer, StmtSRef scope_sre
  * \param block_sref The sref of the block related to the region.
  * \param dom_low_inclusive The lowest node in the sref tree path.
  * \param dom_high_exclusive The highest node in the sref tree path.
+ * \param extra_predicate Extra predicate (e.g. from nested blocks), AND-ed with the block's own
+ *        predicate; block iter vars are replaced by their bindings. Defaults to true (no-op).
  * \return The relaxed buffer region.
  */
 BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_region,
                                const StmtSRef& block_sref, const StmtSRef& dom_low_inclusive,
-                               const StmtSRef& dom_high_exclusive) {
+                               const StmtSRef& dom_high_exclusive,
+                               PrimExpr extra_predicate = Bool(true)) {
   SBlockRealize realize = GetSBlockRealize(self, block_sref);
   ffi::Map<Var, PrimExpr> binding = GetBindings(realize);
   const Buffer& buffer = buffer_region->buffer;
@@ -561,7 +629,7 @@ BufferRegion RelaxBufferRegion(ScheduleState self, const BufferRegion& buffer_re
   BufferRegion subst_region = BufferRegion(buffer, Substitute(buffer_region->region, binding));
   ffi::Array<arith::IntSet> int_sets = AnalyzeRegionUpperBound(
       /*region=*/subst_region,
-      /*predicate=*/realize->predicate,
+      /*predicate=*/Substitute(realize->predicate && extra_predicate, binding),
       /*dom_low_inclusive=*/dom_low_inclusive,
       /*dom_high_exclusive=*/dom_high_exclusive,
       /*analyzer=*/&analyzer);
@@ -1703,9 +1771,25 @@ StmtSRef CacheRead(ScheduleState self, const StmtSRef& block_sref, int read_buff
     // Case 2. The buffer is the input block for the scope.
     info.loc_sref = scope_sref;
     info.loc_pos = 0;
-    if (ffi::Optional<BufferRegion> region =
-            GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
-      cache_region = region.value();
+    // When a nested block gates the actual read with T.where, the consumer block's own
+    // predicate is trivially true, so the scope-block read annotation covers the full loop range.
+    // If the collected nested-read predicate is non-trivial, relax this consumer's read region
+    // under it for a tighter cache; otherwise fall back to scope_block->reads (keeps int64 extents).
+    // NOTE(review): only this block's reads are relaxed here; if other consumers in the scope
+    // are also redirected to the cache, confirm their read regions are still covered.
+    ffi::Optional<BufferRegion> read_region_opt =
+        GetBufferRegionFromBuffer(block->reads, read_buffer);
+    PrimExpr nested_pred =
+        read_region_opt
+            ? CollectNestedBlockPredicates(block->body, read_buffer, BufferIndexType::kRead)
+            : Bool(true);
+    if (read_region_opt && !is_one(nested_pred) && block_sref->parent != nullptr) {
+      StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
+      cache_region = RelaxBufferRegion(self, read_region_opt.value(), block_sref, parent_sref,
+                                       scope_sref, nested_pred);
+    } else if (ffi::Optional<BufferRegion> scope_region =
+                   GetBufferRegionFromBuffer(scope_block->reads, read_buffer)) {
+      cache_region = scope_region.value();
     } else {
       cache_region = BufferRegion::FullRegion(read_buffer);
     }
@@ -1782,11 +1866,22 @@ StmtSRef CacheWrite(ScheduleState self, const StmtSRef& block_sref, int write_bu
 
   // Step 4. Find the producing region and insert position
   BufferRegion region = GetBufferRegionFromBuffer(block->writes, write_buffer).value();
-  StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
   // Detect insert position
   CacheLocDetector::Detect</*is_cache_read=*/false>(self, block_sref, scope_sref, &info);
-  BufferRegion cache_region =
-      RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref);
+  // Collect predicates from any nested blocks that gate the actual write (e.g. T.where on an
+  // inner block). The outer block's own predicate may be trivially true even though the write
+  // is restricted; RelaxBufferRegion ANDs the collected predicate with it for a tighter region.
+  PrimExpr nested_write_pred =
+      CollectNestedBlockPredicates(block->body, write_buffer, BufferIndexType::kWrite);
+  BufferRegion cache_region;
+  if (block_sref->parent != nullptr) {
+    StmtSRef parent_sref = ffi::GetRef<StmtSRef>(block_sref->parent);
+    cache_region =
+        RelaxBufferRegion(self, region, block_sref, parent_sref, info.loc_sref, nested_write_pred);
+  } else {
+    // Root block: no enclosing loops to relax over, use the write region directly.
+    cache_region = region;
+  }
 
   bool cache_full_region = info.loc_sref->StmtAs<SBlockNode>() == nullptr ||
                            !AllConsumersUnderStmt(self, write_buffer, scope_sref, info.loc_sref);
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py b/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py
index 5da2c56535..8877044437 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_cache_read_write.py
@@ -1670,5 +1670,175 @@ def test_symbolic_matmul_blocked_cache_write(use_block_name):
     verify_trace_roundtrip(sch=sch, mod=symbolic_matmul_blocked)
 
 
+def test_cache_write_with_nested_block_predicate():
+    @T.prim_func
+    def main(A: T.handle, C: T.handle) -> None:
+        A_buf = T.match_buffer(A, (12, 24), "float32")
+        C_buf = T.match_buffer(C, (10, 20), "float32")
+
+        for i, j in T.grid(12, 24):
+            with T.sblock("compute"):
+                vi, vj = T.axis.remap("SS", [i, j])
+
+                with T.sblock("inner"):
+                    T.where(vi < 10 and vj < 20)
+                    C_buf[vi, vj] = A_buf[vi, vj] * 2.0
+
+    @T.prim_func
+    def expected(A_buf: T.Buffer((12, 24), "float32"), C_buf: T.Buffer((10, 20), "float32")):
+        with T.sblock("root"):
+            C_buf_local = T.sblock_alloc_buffer((10, 20), scope="local")
+            for i, j in T.grid(12, 24):
+                with T.sblock("compute"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.reads(A_buf[vi, vj])
+                    T.writes(C_buf_local[vi, vj])
+                    with T.sblock("inner"):
+                        T.where(vi < 10 and vj < 20)
+                        T.reads(A_buf[vi, vj])
+                        T.writes(C_buf_local[vi, vj])
+                        C_buf_local[vi, vj] = A_buf[vi, vj] * T.float32(2)
+            for ax0, ax1 in T.grid(10, 20):
+                with T.sblock("C_buf_local"):
+                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(C_buf_local[v0, v1])
+                    T.writes(C_buf[v0, v1])
+                    C_buf[v0, v1] = C_buf_local[v0, v1]
+
+    sch = tvm.s_tir.Schedule(main)
+    block = sch.get_sblock("compute")
+    sch.cache_write(block, 0, "local")  # inner T.where should bound the cache to (10, 20)
+    assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
+
+
+def test_cache_read_with_nested_block_predicate():
+    @T.prim_func
+    def main(A: T.handle, C: T.handle) -> None:
+        A_buf = T.match_buffer(A, (12, 24), "float32")
+        C_buf = T.match_buffer(C, (10, 20), "float32")
+
+        for i, j in T.grid(12, 24):
+            with T.sblock("compute"):
+                vi, vj = T.axis.remap("SS", [i, j])
+
+                with T.sblock("inner"):
+                    T.where(vi < 10 and vj < 20)
+                    C_buf[vi, vj] = A_buf[vi, vj] * 2.0
+
+    @T.prim_func
+    def expected(A_buf: T.Buffer((12, 24), "float32"), C_buf: T.Buffer((10, 20), "float32")):
+        with T.sblock("root"):
+            A_buf_local = T.sblock_alloc_buffer((10, 20), scope="local")
+            for ax0, ax1 in T.grid(10, 20):
+                with T.sblock("A_buf_local"):
+                    v0, v1 = T.axis.remap("SS", [ax0, ax1])
+                    T.reads(A_buf[v0, v1])
+                    T.writes(A_buf_local[v0, v1])
+                    A_buf_local[v0, v1] = A_buf[v0, v1]
+            for i, j in T.grid(12, 24):
+                with T.sblock("compute"):
+                    vi, vj = T.axis.remap("SS", [i, j])
+                    T.reads(A_buf_local[vi, vj])
+                    T.writes(C_buf[vi, vj])
+                    with T.sblock("inner"):
+                        T.where(vi < 10 and vj < 20)
+                        T.reads(A_buf_local[vi, vj])
+                        T.writes(C_buf[vi, vj])
+                        C_buf[vi, vj] = A_buf_local[vi, vj] * T.float32(2)
+
+    sch = tvm.s_tir.Schedule(main)
+    block = sch.get_sblock("compute")
+    sch.cache_read(block, 0, "local")  # inner T.where should bound the cache to (10, 20)
+    assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
+
+
+def test_cache_write_sibling_nested_block_predicates_use_union():
+    """Regression: cache_write with sibling nested blocks must union their predicates.
+
+    Two sibling nested sblocks access the same buffer under *different* predicates:
+      left  block: T.where(vi < 8)   — writes rows 0-7, all columns
+      top   block: T.where(vj < 16)  — writes all rows, columns 0-15
+
+    The cache must cover the UNION of both access sets.  The bounding box of that
+    union is (12, 24) — the full buffer shape.
+
+    Previously CollectNestedBlockPredicates AND-ed the nested predicates, giving
+    (vi < 8) AND (vj < 16).  RelaxBufferRegion under that intersection predicate
+    yields the bounding box of the *intersection* instead: (8, 16), which is too small.
+    The "left" block would then write columns [16, 24) and the "top" block rows [8, 12)
+    of C_buf_local — outside the cache allocation — producing incorrect results.
+    """
+
+    @T.prim_func
+    def main(A: T.handle, C: T.handle) -> None:
+        A_buf = T.match_buffer(A, (12, 24), "float32")
+        C_buf = T.match_buffer(C, (12, 24), "float32")
+        for i, j in T.grid(12, 24):
+            with T.sblock("compute"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                with T.sblock("left"):
+                    T.where(vi < 8)
+                    C_buf[vi, vj] = A_buf[vi, vj] * 2.0
+                with T.sblock("top"):
+                    T.where(vj < 16)
+                    C_buf[vi, vj] = A_buf[vi, vj] * 3.0
+
+    sch = tvm.s_tir.Schedule(main)
+    block = sch.get_sblock("compute")
+    sch.cache_write(block, 0, "local")
+
+    # Render the scheduled IR as TVMScript and check the cache allocation shape textually.
+    result_script = sch.mod["main"].script()
+    # The cache must be large enough to hold the union of both write regions.
+    # Union bounding box = full (12, 24).  The old AND-combination gave (8, 16).
+    assert "sblock_alloc_buffer((12, 24)" in result_script, (
+        f"Expected cache shape (12, 24) covering the union of both write regions, "
+        f"but got a smaller shape. Full IR:\n{result_script}"
+    )
+
+
+def test_cache_read_sibling_nested_block_predicates_use_union():
+    """Regression: cache_read with sibling nested blocks must union their predicates.
+
+    Two sibling nested sblocks read the same input buffer under different predicates:
+      left  block: T.where(vi < 8)   — reads rows 0-7, all columns
+      top   block: T.where(vj < 16)  — reads all rows, columns 0-15
+
+    The cache must cover the UNION of both read sets.  The bounding box of that
+    union is (12, 24) — the full buffer shape.
+
+    Previously CollectNestedBlockPredicates AND-ed the two predicates, giving (vi < 8) AND
+    (vj < 16).  Case 2 of CacheRead calls RelaxBufferRegion under that intersection
+    predicate, producing a cache of shape (8, 16).  The "left" block would then read
+    A_buf_local columns [16, 24) and the "top" block rows [8, 12) — indices outside the
+    cache — which is incorrect.
+    """
+
+    @T.prim_func
+    def main(A: T.handle, C: T.handle) -> None:
+        A_buf = T.match_buffer(A, (12, 24), "float32")
+        C_buf = T.match_buffer(C, (12, 24), "float32")
+        for i, j in T.grid(12, 24):
+            with T.sblock("compute"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                with T.sblock("left"):
+                    T.where(vi < 8)
+                    C_buf[vi, vj] = A_buf[vi, vj] * 2.0
+                with T.sblock("top"):
+                    T.where(vj < 16)
+                    C_buf[vi, vj] = A_buf[vi, vj] * 3.0
+
+    sch = tvm.s_tir.Schedule(main)
+    block = sch.get_sblock("compute")
+    sch.cache_read(block, 0, "local")
+
+    result_script = sch.mod["main"].script()
+    # Cache must cover the union bounding box (12, 24).  The old AND-combination gave (8, 16).
+    assert "sblock_alloc_buffer((12, 24)" in result_script, (
+        f"Expected cache shape (12, 24) covering the union of both read regions, "
+        f"but got a smaller shape. Full IR:\n{result_script}"
+    )
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to