This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new c01f898ed0 [TIR] Update symbolic index term order in loop fusion 
(#18406)
c01f898ed0 is described below

commit c01f898ed0e57ea6a880db811ed2bb305b06bd50
Author: wrongtest <[email protected]>
AuthorDate: Mon Apr 27 03:44:21 2026 +0800

    [TIR] Update symbolic index term order in loop fusion (#18406)
    
    This change keeps the order of the stride terms the same as the fused
    loop order in the `fuse` primitive. With symbolic extents, the previous
    form suffers from simplification issues and makes the expression tree
    much more complex in the subsequent lowering steps.
    
    Taking [M, N] tiling as an example, the previous binding form after
    ```python
    i, j = sch.get_loops(block_b)
    i0, i1 = sch.split(i, factors=[None, 64])
    j0, j1 = sch.split(j, factors=[None, 16])
    sch.reorder(i0, j0, i1, j1)
    sch.fuse(i0, j0)
    ```
    
    would be like (with i_0_j_0_fused in `[0, ceildiv(M, 64) * ceildiv(N, 16))`)
    ```
    vi = T.axis.spatial(M, i_0_j_0_fused % ((N + 15) // 16 * ((M + 63) // 64)) 
// ((N + 15) // 16) * 64 + i_1)
    ```
    instead of the simpler version
    ```
    vi = T.axis.spatial(M, i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1)
    ```
    This is because, unfortunately, rule-based simplification cannot prove
    that `ceildiv(N, 16) * ceildiv(M, 64) == ceildiv(M, 64) * ceildiv(N, 16)`.
    As a result, certain analyses (for example, region estimation) may fail
    to give concise estimates because of the complex dynamic expression
    trees.
    
    Co-authored-by: baoxinqi <[email protected]>
---
 src/arith/iter_affine_map.cc                       |  7 +++--
 .../schedule/primitive/loop_transformation.cc      |  9 +++---
 tests/python/s_tir/dlight/test_gpu_fallback.py     |  4 +--
 .../s_tir/dlight/test_gpu_general_reduction.py     |  8 ++---
 .../meta_schedule/test_meta_schedule_space_cuda.py |  2 +-
 .../s_tir/schedule/test_tir_schedule_split_fuse.py | 35 ++++++++++++++++++++++
 .../test_s_tir_transform_default_gpu_schedule.py   |  2 +-
 7 files changed, 51 insertions(+), 16 deletions(-)

diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 0e996485a4..a4d5097167 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2115,7 +2115,8 @@ class IterMapToExprNormalizer : public ExprMutator {
     }
     if (analyzer_->CanProve(expr->extent == expr->source->extent) && 
is_one(expr->lower_factor)) {
       return source * expr->scale;
-    } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor 
* expr->extent)) {
+    } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor 
* expr->extent) ||
+               analyzer_->CanProve(expr->source->extent == expr->extent * 
expr->lower_factor)) {
       // Simplify if `expr` is always 0. The 2nd condition guarantess that we 
do not aggressively
       // simplify trivial iters like `vi \in [0, 1)`, which can be useful for 
subsequent analysis
       // like tensorization.
@@ -2124,8 +2125,8 @@ class IterMapToExprNormalizer : public ExprMutator {
       }
       return floordiv(source, expr->lower_factor) * expr->scale;
     } else {
-      return floordiv(floormod(source, expr->lower_factor * expr->extent), 
expr->lower_factor) *
-             expr->scale;
+      PrimExpr full_extent = analyzer_->canonical_simplify(expr->extent * 
expr->lower_factor);
+      return floordiv(floormod(source, full_extent), expr->lower_factor) * 
expr->scale;
     }
   }
 
diff --git a/src/s_tir/schedule/primitive/loop_transformation.cc 
b/src/s_tir/schedule/primitive/loop_transformation.cc
index ae35d1c91c..f43b1f2a2e 100644
--- a/src/s_tir/schedule/primitive/loop_transformation.cc
+++ b/src/s_tir/schedule/primitive/loop_transformation.cc
@@ -923,15 +923,16 @@ StmtSRef Fuse(ScheduleState self, const 
ffi::Array<StmtSRef>& loop_srefs,
     bits = std::max(bits, loops[i]->loop_var.dtype().bits());
   }
   suffix += "_fused";
+
   Var fused_var = 
loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits));
   ffi::Array<PrimExpr> substitute_value;
   substitute_value.resize(loops.size());
   PrimExpr lower = 1;
   for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
-    substitute_value.Set(i, is_one(loops[i]->extent)
-                                ? 0
-                                : floordiv(floormod(fused_var, lower * 
loops[i]->extent), lower));
-    lower = lower * loops[i]->extent;
+    PrimExpr next_lower = analyzer.canonical_simplify(loops[i]->extent * 
lower);
+    substitute_value.Set(
+        i, is_one(loops[i]->extent) ? 0 : floordiv(floormod(fused_var, 
next_lower), lower));
+    lower = next_lower;
   }
   substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var, 
lower));
   Stmt new_stmt = loops.back()->body;
diff --git a/tests/python/s_tir/dlight/test_gpu_fallback.py 
b/tests/python/s_tir/dlight/test_gpu_fallback.py
index 58f9cc4ad1..72cf06a2ac 100644
--- a/tests/python/s_tir/dlight/test_gpu_fallback.py
+++ b/tests/python/s_tir/dlight/test_gpu_fallback.py
@@ -162,8 +162,8 @@ def test_fallback_irregular_spatial():
         for ax0_ax1_ax2_fused_0 in T.thread_binding((nlayer * nhead * seqlen + 
1023) // 1024, thread="blockIdx.x"):
             for ax0_ax1_ax2_fused_1 in T.thread_binding(1024, 
thread="threadIdx.x"):
                 with T.sblock("block"):
-                    v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + 
ax0_ax1_ax2_fused_1) % (seqlen * nhead * nlayer) // (seqlen * nhead))
-                    v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + 
ax0_ax1_ax2_fused_1) % (seqlen * nhead) // seqlen)
+                    v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 + 
ax0_ax1_ax2_fused_1) // (nhead * seqlen))
+                    v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 + 
ax0_ax1_ax2_fused_1) % (nhead * seqlen) // seqlen)
                     v2 = T.axis.spatial(seqlen, (ax0_ax1_ax2_fused_0 * 1024 + 
ax0_ax1_ax2_fused_1) % seqlen)
                     T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 < 
nlayer * nhead * seqlen)
                     T.reads(pages[page_table_values[page_table_indptr[seq_id] 
+ v2 // page_size], v0, v1, v2 % page_size], 
page_table_values[page_table_indptr[seq_id] + v2 // page_size], 
page_table_indptr[seq_id])
diff --git a/tests/python/s_tir/dlight/test_gpu_general_reduction.py 
b/tests/python/s_tir/dlight/test_gpu_general_reduction.py
index 7c6b995b37..fbdbf1b82b 100644
--- a/tests/python/s_tir/dlight/test_gpu_general_reduction.py
+++ b/tests/python/s_tir/dlight/test_gpu_general_reduction.py
@@ -616,7 +616,7 @@ def test_logsumexp():
                             with T.sblock("max"):
                                 v0 = T.axis.spatial(
                                     batch_size,
-                                    ax0_ax1_fused % (num_chunks * batch_size) 
// num_chunks + ax0,
+                                    ax0_ax1_fused // num_chunks + ax0,
                                 )
                                 v1 = T.axis.spatial(num_chunks, ax0_ax1_fused 
% num_chunks + ax1)
                                 v2 = T.axis.reduce(
@@ -646,7 +646,7 @@ def test_logsumexp():
                             with T.sblock("sum_exp"):
                                 v0 = T.axis.spatial(
                                     batch_size,
-                                    ax0_ax1_fused % (num_chunks * batch_size) 
// num_chunks + ax0,
+                                    ax0_ax1_fused // num_chunks + ax0,
                                 )
                                 v1 = T.axis.spatial(num_chunks, ax0_ax1_fused 
% num_chunks + ax1)
                                 v2 = T.axis.reduce(
@@ -677,9 +677,7 @@ def test_logsumexp():
                         },
                     ):
                         with T.sblock("log"):
-                            v0 = T.axis.spatial(
-                                batch_size, ax0_ax1_fused % (num_chunks * 
batch_size) // num_chunks
-                            )
+                            v0 = T.axis.spatial(batch_size, ax0_ax1_fused // 
num_chunks)
                             v1 = T.axis.spatial(num_chunks, ax0_ax1_fused % 
num_chunks)
                             v2 = T.axis.spatial(T.int64(1), ax2_0 * 
T.int64(256) + ax2_1)
                             T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py 
b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
index d99ce2fdcf..177b10f2c1 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
@@ -724,7 +724,7 @@ def test_cuda_t2d():
                             for ax0_ax1_ax2_ax3_fused in range(rh_0 % 2 * 96 + 
96):
                                 with T.sblock("PadInput_shared"):
                                     v0 = T.axis.spatial(1, 0)
-                                    v1 = T.axis.spatial(6, 
n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (rh_0 
% 2 + 1)) // 96)
+                                    v1 = T.axis.spatial(6, 
n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (rh_0 % 2 * 
96 + 96) // 96)
                                     v2 = T.axis.spatial(6, 
n_0_h_0_w_0_co_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
                                     v3 = T.axis.spatial(512, rc_0 * 32 + 
ax0_ax1_ax2_ax3_fused % 32)
                                     T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py 
b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
index 3fe374be3d..afa28f5ef6 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
@@ -824,5 +824,40 @@ def test_unsupported_target_scalable_split(capfd):
     assert warning_msg in captured
 
 
+def test_fused_symbolic_2D_tiling():
+    @T.prim_func
+    def before(a: T.handle, b: T.handle, M: T.int32, N: T.int32) -> None:
+        A = T.match_buffer(a, (M, N))
+        B = T.match_buffer(b, (M, N))
+        for i, j in T.grid(M, N):
+            with T.sblock("B"):
+                vi, vj = T.axis.remap("SS", [i, j])
+                B[vi, vj] = A[vi, vj] * 2.0
+
+    @T.prim_func
+    def expected(a: T.handle, b: T.handle, M: T.int32, N: T.int32) -> None:
+        A = T.match_buffer(a, (M, N))
+        B = T.match_buffer(b, (M, N))
+        for i_0_j_0_fused, i_1, j_1 in T.grid(((M + 63) // 64) * ((N + 15) // 
16), 64, 16):
+            with T.sblock("B"):
+                vi = T.axis.spatial(M, i_0_j_0_fused // ((N + 15) // 16) * 64 
+ i_1)
+                vj = T.axis.spatial(N, i_0_j_0_fused % ((N + 15) // 16) * 16 + 
j_1)
+                T.where(
+                    i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1 < M
+                    and i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1 < N
+                )
+                B[vi, vj] = A[vi, vj] * T.float32(2.0)
+
+    sch = tvm.s_tir.Schedule(before, debug_mask="all")
+    block_b = sch.get_sblock("B")
+    i, j = sch.get_loops(block_b)
+    i0, i1 = sch.split(i, factors=[None, 64])
+    j0, j1 = sch.split(j, factors=[None, 16])
+    sch.reorder(i0, j0, i1, j1)
+    sch.fuse(i0, j0)
+    assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
+    verify_trace_roundtrip(sch=sch, mod=before)
+
+
 if __name__ == "__main__":
     tvm.testing.main()
diff --git 
a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py 
b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
index 0212585b53..c562a29e87 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
@@ -55,7 +55,7 @@ def test_broadcast_to_symbolic():
                 for ax0_ax1_fused_2 in T.thread_binding(T.int64(1024), 
thread="threadIdx.x"):
                     for ax0_ax1_fused_0 in range((x_0 * x_1 + T.int64(262143)) 
// T.int64(262144)):
                         with T.sblock("T_broadcast_to"):
-                            v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * 
T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % (x_1 * 
x_0) // x_1)
+                            v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 * 
T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) // x_1)
                             v_ax1 = T.axis.spatial(x_1, (ax0_ax1_fused_0 * 
T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % x_1)
                             T.where((ax0_ax1_fused_0 * T.int64(256) + 
ax0_ax1_fused_1) * T.int64(1024) + ax0_ax1_fused_2 < x_0 * x_1)
                             T_broadcast_to[v_ax0, v_ax1] = 
rxplaceholder[v_ax0, T.int64(0)]

Reply via email to