This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new c01f898ed0 [TIR] Update symbolic index term order in loop fusion
(#18406)
c01f898ed0 is described below
commit c01f898ed0e57ea6a880db811ed2bb305b06bd50
Author: wrongtest <[email protected]>
AuthorDate: Mon Apr 27 03:44:21 2026 +0800
[TIR] Update symbolic index term order in loop fusion (#18406)
This change keeps the order of stride terms consistent with the fused loop
order in the `fuse` primitive. With symbolic extents, the previous form
suffered from simplification issues and made the expression tree much more
complex in subsequent lowering steps.
Taking [M, N] tiling as an example, after the following schedule
```python
i, j = sch.get_loops(block_b)
i0, i1 = sch.split(i, factors=[None, 64])
j0, j1 = sch.split(j, factors=[None, 16])
sch.reorder(i0, j0, i1, j1)
sch.fuse(i0, j0)
```
the previous binding form would be (with `i_0_j_0_fused` in `[0, ceildiv(M, 64) * ceildiv(N, 16))`):
```
vi = T.axis.spatial(M, i_0_j_0_fused % ((N + 15) // 16 * ((M + 63) // 64))
// ((N + 15) // 16) * 64 + i_1)
```
instead of the simpler version:
```
vi = T.axis.spatial(M, i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1)
```
This happens because rule-based simplification cannot prove that
`ceildiv(N, 16) * ceildiv(M, 64) == ceildiv(M, 64) * ceildiv(N, 16)`.
As a result, certain analyses (for example, region estimation) may fail
to produce concise estimates because of the complex dynamic expression
trees.
Co-authored-by: baoxinqi <[email protected]>
---
src/arith/iter_affine_map.cc | 7 +++--
.../schedule/primitive/loop_transformation.cc | 9 +++---
tests/python/s_tir/dlight/test_gpu_fallback.py | 4 +--
.../s_tir/dlight/test_gpu_general_reduction.py | 8 ++---
.../meta_schedule/test_meta_schedule_space_cuda.py | 2 +-
.../s_tir/schedule/test_tir_schedule_split_fuse.py | 35 ++++++++++++++++++++++
.../test_s_tir_transform_default_gpu_schedule.py | 2 +-
7 files changed, 51 insertions(+), 16 deletions(-)
diff --git a/src/arith/iter_affine_map.cc b/src/arith/iter_affine_map.cc
index 0e996485a4..a4d5097167 100644
--- a/src/arith/iter_affine_map.cc
+++ b/src/arith/iter_affine_map.cc
@@ -2115,7 +2115,8 @@ class IterMapToExprNormalizer : public ExprMutator {
}
if (analyzer_->CanProve(expr->extent == expr->source->extent) &&
is_one(expr->lower_factor)) {
return source * expr->scale;
- } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor
* expr->extent)) {
+ } else if (analyzer_->CanProve(expr->source->extent == expr->lower_factor
* expr->extent) ||
+ analyzer_->CanProve(expr->source->extent == expr->extent *
expr->lower_factor)) {
// Simplify if `expr` is always 0. The 2nd condition guarantess that we
do not aggressively
// simplify trivial iters like `vi \in [0, 1)`, which can be useful for
subsequent analysis
// like tensorization.
@@ -2124,8 +2125,8 @@ class IterMapToExprNormalizer : public ExprMutator {
}
return floordiv(source, expr->lower_factor) * expr->scale;
} else {
- return floordiv(floormod(source, expr->lower_factor * expr->extent),
expr->lower_factor) *
- expr->scale;
+ PrimExpr full_extent = analyzer_->canonical_simplify(expr->extent *
expr->lower_factor);
+ return floordiv(floormod(source, full_extent), expr->lower_factor) *
expr->scale;
}
}
diff --git a/src/s_tir/schedule/primitive/loop_transformation.cc
b/src/s_tir/schedule/primitive/loop_transformation.cc
index ae35d1c91c..f43b1f2a2e 100644
--- a/src/s_tir/schedule/primitive/loop_transformation.cc
+++ b/src/s_tir/schedule/primitive/loop_transformation.cc
@@ -923,15 +923,16 @@ StmtSRef Fuse(ScheduleState self, const
ffi::Array<StmtSRef>& loop_srefs,
bits = std::max(bits, loops[i]->loop_var.dtype().bits());
}
suffix += "_fused";
+
Var fused_var =
loops[0]->loop_var.copy_with_suffix(suffix).copy_with_dtype(DataType::Int(bits));
ffi::Array<PrimExpr> substitute_value;
substitute_value.resize(loops.size());
PrimExpr lower = 1;
for (int i = static_cast<int>(loops.size()) - 1; i > 0; i--) {
- substitute_value.Set(i, is_one(loops[i]->extent)
- ? 0
- : floordiv(floormod(fused_var, lower *
loops[i]->extent), lower));
- lower = lower * loops[i]->extent;
+ PrimExpr next_lower = analyzer.canonical_simplify(loops[i]->extent *
lower);
+ substitute_value.Set(
+ i, is_one(loops[i]->extent) ? 0 : floordiv(floormod(fused_var,
next_lower), lower));
+ lower = next_lower;
}
substitute_value.Set(0, is_one(loops[0]->extent) ? 0 : floordiv(fused_var,
lower));
Stmt new_stmt = loops.back()->body;
diff --git a/tests/python/s_tir/dlight/test_gpu_fallback.py
b/tests/python/s_tir/dlight/test_gpu_fallback.py
index 58f9cc4ad1..72cf06a2ac 100644
--- a/tests/python/s_tir/dlight/test_gpu_fallback.py
+++ b/tests/python/s_tir/dlight/test_gpu_fallback.py
@@ -162,8 +162,8 @@ def test_fallback_irregular_spatial():
for ax0_ax1_ax2_fused_0 in T.thread_binding((nlayer * nhead * seqlen +
1023) // 1024, thread="blockIdx.x"):
for ax0_ax1_ax2_fused_1 in T.thread_binding(1024,
thread="threadIdx.x"):
with T.sblock("block"):
- v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 +
ax0_ax1_ax2_fused_1) % (seqlen * nhead * nlayer) // (seqlen * nhead))
- v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 +
ax0_ax1_ax2_fused_1) % (seqlen * nhead) // seqlen)
+ v0 = T.axis.spatial(nlayer, (ax0_ax1_ax2_fused_0 * 1024 +
ax0_ax1_ax2_fused_1) // (nhead * seqlen))
+ v1 = T.axis.spatial(nhead, (ax0_ax1_ax2_fused_0 * 1024 +
ax0_ax1_ax2_fused_1) % (nhead * seqlen) // seqlen)
v2 = T.axis.spatial(seqlen, (ax0_ax1_ax2_fused_0 * 1024 +
ax0_ax1_ax2_fused_1) % seqlen)
T.where(ax0_ax1_ax2_fused_0 * 1024 + ax0_ax1_ax2_fused_1 <
nlayer * nhead * seqlen)
T.reads(pages[page_table_values[page_table_indptr[seq_id]
+ v2 // page_size], v0, v1, v2 % page_size],
page_table_values[page_table_indptr[seq_id] + v2 // page_size],
page_table_indptr[seq_id])
diff --git a/tests/python/s_tir/dlight/test_gpu_general_reduction.py
b/tests/python/s_tir/dlight/test_gpu_general_reduction.py
index 7c6b995b37..fbdbf1b82b 100644
--- a/tests/python/s_tir/dlight/test_gpu_general_reduction.py
+++ b/tests/python/s_tir/dlight/test_gpu_general_reduction.py
@@ -616,7 +616,7 @@ def test_logsumexp():
with T.sblock("max"):
v0 = T.axis.spatial(
batch_size,
- ax0_ax1_fused % (num_chunks * batch_size)
// num_chunks + ax0,
+ ax0_ax1_fused // num_chunks + ax0,
)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused
% num_chunks + ax1)
v2 = T.axis.reduce(
@@ -646,7 +646,7 @@ def test_logsumexp():
with T.sblock("sum_exp"):
v0 = T.axis.spatial(
batch_size,
- ax0_ax1_fused % (num_chunks * batch_size)
// num_chunks + ax0,
+ ax0_ax1_fused // num_chunks + ax0,
)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused
% num_chunks + ax1)
v2 = T.axis.reduce(
@@ -677,9 +677,7 @@ def test_logsumexp():
},
):
with T.sblock("log"):
- v0 = T.axis.spatial(
- batch_size, ax0_ax1_fused % (num_chunks *
batch_size) // num_chunks
- )
+ v0 = T.axis.spatial(batch_size, ax0_ax1_fused //
num_chunks)
v1 = T.axis.spatial(num_chunks, ax0_ax1_fused %
num_chunks)
v2 = T.axis.spatial(T.int64(1), ax2_0 *
T.int64(256) + ax2_1)
T.where(ax2_0 * T.int64(256) + ax2_1 < T.int64(1))
diff --git a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
index d99ce2fdcf..177b10f2c1 100644
--- a/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
+++ b/tests/python/s_tir/meta_schedule/test_meta_schedule_space_cuda.py
@@ -724,7 +724,7 @@ def test_cuda_t2d():
for ax0_ax1_ax2_ax3_fused in range(rh_0 % 2 * 96 +
96):
with T.sblock("PadInput_shared"):
v0 = T.axis.spatial(1, 0)
- v1 = T.axis.spatial(6,
n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (96 * (rh_0
% 2 + 1)) // 96)
+ v1 = T.axis.spatial(6,
n_0_h_0_w_0_co_0_fused // 64 + rh_0 // 2 + ax0_ax1_ax2_ax3_fused % (rh_0 % 2 *
96 + 96) // 96)
v2 = T.axis.spatial(6,
n_0_h_0_w_0_co_0_fused % 64 // 16 + ax0_ax1_ax2_ax3_fused % 96 // 32)
v3 = T.axis.spatial(512, rc_0 * 32 +
ax0_ax1_ax2_ax3_fused % 32)
T.reads(inputs[v0, v1 - 1, v2 - 1, v3])
diff --git a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
index 3fe374be3d..afa28f5ef6 100644
--- a/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
+++ b/tests/python/s_tir/schedule/test_tir_schedule_split_fuse.py
@@ -824,5 +824,40 @@ def test_unsupported_target_scalable_split(capfd):
assert warning_msg in captured
+def test_fused_symbolic_2D_tiling():
+ @T.prim_func
+ def before(a: T.handle, b: T.handle, M: T.int32, N: T.int32) -> None:
+ A = T.match_buffer(a, (M, N))
+ B = T.match_buffer(b, (M, N))
+ for i, j in T.grid(M, N):
+ with T.sblock("B"):
+ vi, vj = T.axis.remap("SS", [i, j])
+ B[vi, vj] = A[vi, vj] * 2.0
+
+ @T.prim_func
+ def expected(a: T.handle, b: T.handle, M: T.int32, N: T.int32) -> None:
+ A = T.match_buffer(a, (M, N))
+ B = T.match_buffer(b, (M, N))
+ for i_0_j_0_fused, i_1, j_1 in T.grid(((M + 63) // 64) * ((N + 15) //
16), 64, 16):
+ with T.sblock("B"):
+ vi = T.axis.spatial(M, i_0_j_0_fused // ((N + 15) // 16) * 64
+ i_1)
+ vj = T.axis.spatial(N, i_0_j_0_fused % ((N + 15) // 16) * 16 +
j_1)
+ T.where(
+ i_0_j_0_fused // ((N + 15) // 16) * 64 + i_1 < M
+ and i_0_j_0_fused % ((N + 15) // 16) * 16 + j_1 < N
+ )
+ B[vi, vj] = A[vi, vj] * T.float32(2.0)
+
+ sch = tvm.s_tir.Schedule(before, debug_mask="all")
+ block_b = sch.get_sblock("B")
+ i, j = sch.get_loops(block_b)
+ i0, i1 = sch.split(i, factors=[None, 64])
+ j0, j1 = sch.split(j, factors=[None, 16])
+ sch.reorder(i0, j0, i1, j1)
+ sch.fuse(i0, j0)
+ assert_structural_equal_ignore_global_symbol(expected, sch.mod["main"])
+ verify_trace_roundtrip(sch=sch, mod=before)
+
+
if __name__ == "__main__":
tvm.testing.main()
diff --git
a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
index 0212585b53..c562a29e87 100644
--- a/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
+++ b/tests/python/s_tir/transform/test_s_tir_transform_default_gpu_schedule.py
@@ -55,7 +55,7 @@ def test_broadcast_to_symbolic():
for ax0_ax1_fused_2 in T.thread_binding(T.int64(1024),
thread="threadIdx.x"):
for ax0_ax1_fused_0 in range((x_0 * x_1 + T.int64(262143))
// T.int64(262144)):
with T.sblock("T_broadcast_to"):
- v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 *
T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % (x_1 *
x_0) // x_1)
+ v_ax0 = T.axis.spatial(x_0, (ax0_ax1_fused_0 *
T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) // x_1)
v_ax1 = T.axis.spatial(x_1, (ax0_ax1_fused_0 *
T.int64(262144) + ax0_ax1_fused_1 * T.int64(1024) + ax0_ax1_fused_2) % x_1)
T.where((ax0_ax1_fused_0 * T.int64(256) +
ax0_ax1_fused_1) * T.int64(1024) + ax0_ax1_fused_2 < x_0 * x_1)
T_broadcast_to[v_ax0, v_ax1] =
rxplaceholder[v_ax0, T.int64(0)]