This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 7bd73e5ad2 [Arith] Restrict floormod coefficient reduction to keep
DetectIterMap stable (#19832)
7bd73e5ad2 is described below
commit 7bd73e5ad22e359aff273165eddc47ff9c5373ee
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 18 17:26:48 2026 -0400
[Arith] Restrict floormod coefficient reduction to keep DetectIterMap stable
(#19832)
This PR fixes #19825 by restricting the rewrites
```
floormod(x * c1 + y, c2) -> floormod(x * floormod(c1, c2) + y, c2)
```
and
```
floormod(x + y * c1, c2) -> floormod(x + y * floormod(c1, c2), c2)
```
While algebraically valid in isolation, these transformations rewrite
only the `floormod` side of a matching `floordiv`/`floormod` pair. As a
result, the two expressions no longer share a visible fused index
expression, causing `DetectIterMap` to reject otherwise bijective splits
such as:
```
lane = flat % 128
reg = flat // 128
```
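where both expressions originate from the same fused index. For example,
with `flat = i * 192 + j`, the modulo was rewritten to
`floormod(i * 64 + j, 128)` while the quotient stayed
`floordiv(i * 192 + j, 128)`, so the pair no longer shared `flat`.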
### Context
Per the suggestion in #19825, the two rewrites are guarded with
`c1 % c2 == 0` rather than dropped outright. The multiplied term is still
eliminated when its coefficient is a multiple of the divisor (e.g.
`(x*10 + y) % 2 -> y % 2`), which is safe for `DetectIterMap`; only the
coefficient-shrinking case (`c1` not a multiple of `c2`) is disabled.
Both operand orderings are covered, and the PR adds a rewrite-simplify
regression test plus an end-to-end `DetectIterMap` regression test. It also
includes an unrelated, formatting-only change to
`src/relax/op/image/resize.cc`.
---
src/arith/rewrite_simplify.cc | 4 ++--
src/relax/op/image/resize.cc | 7 +++----
tests/python/arith/test_arith_iter_affine_map.py | 15 +++++++++++++++
tests/python/arith/test_arith_rewrite_simplify.py | 7 +++++--
4 files changed, 25 insertions(+), 8 deletions(-)
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index b5b0cc604e..6d6ce03016 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1249,7 +1249,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorModNode* op) {
CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));
TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2)
+ y, c2),
- c2.Eval()->value > 0);
+ c2.Eval()->value > 0 && c1.Eval()->value %
c2.Eval()->value == 0);
// (x + 5) % 2 -> (x + 1) %2, (x + 3) % 3 => x
TVM_TRY_REWRITE_IF(
@@ -1257,7 +1257,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const
FloorModNode* op) {
c2.Eval()->value > 0 && (c1.Eval()->value >= c2.Eval()->value ||
c1.Eval()->value < 0));
TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1,
c2), c2),
- c2.Eval()->value > 0);
+ c2.Eval()->value > 0 && c1.Eval()->value %
c2.Eval()->value == 0);
TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2),
c2.Eval()->value != 0);
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 653ea04c63..eab2f68499 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -311,10 +311,9 @@ StructInfo InferStructInfoGridSample(const Call& call,
const BlockBuilder& ctx)
// treated as the 2D NCHW path so existing behavior is preserved.
const bool is_ncdhw = (attrs->layout == "NCDHW");
- auto [data_layout, data2tgt] =
- CheckTensorLayout(call, ctx, attrs->layout,
- /*tgt_layout=*/is_ncdhw ? "NCDHW" : "NCHW",
- /*tensor_name=*/"data");
+ auto [data_layout, data2tgt] = CheckTensorLayout(call, ctx, attrs->layout,
+ /*tgt_layout=*/is_ncdhw ?
"NCDHW" : "NCHW",
+ /*tensor_name=*/"data");
DataType out_dtype = data_sinfo->dtype;
diff --git a/tests/python/arith/test_arith_iter_affine_map.py
b/tests/python/arith/test_arith_iter_affine_map.py
index fdbb65a0bd..ef1c9f132c 100644
--- a/tests/python/arith/test_arith_iter_affine_map.py
+++ b/tests/python/arith/test_arith_iter_affine_map.py
@@ -201,6 +201,21 @@ def test_split():
)
+def test_split_simplified_modulo():
+ # regression for #19825: simplifying the modulo must not break the fused
split
+ i = tvm.tirx.Var("i", "int32")
+ j = tvm.tirx.Var("j", "int32")
+ dom = var_dom([(i, 64), (j, 192)])
+ analyzer = tvm.arith.Analyzer()
+
+ for flat in [i * 192 + j, j + i * 192]:
+ lane = analyzer.simplify(floormod(flat, 128))
+ quotient = floordiv(flat, 128)
+ res = tvm.arith.detect_iter_map([lane, quotient], dom,
check_level="bijective")
+ assert len(res.errors) == 0, res.errors
+ assert len(res.indices) == 2
+
+
def test_compound():
x = tvm.tirx.Var("x", "int32")
y = tvm.tirx.Var("y", "int32")
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py
b/tests/python/arith/test_arith_rewrite_simplify.py
index e0ef9da822..a6887db8be 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -675,10 +675,13 @@ class TestFloormodIndex(BaseCompare):
TestCase(flm(x * 10, 2), 0),
TestCase(flm(x * 9600, 6400), flm(x * 3200, 6400)),
TestCase(flm(x * 10 + y, 2), flm(y, 2)),
- TestCase(flm(x * 360 + y, 16), flm(x * 8 + y, 16)),
+ # coefficient is not shrunk unless it is a multiple of the divisor
(#19825)
+ TestCase(flm(x * 360 + y, 16), flm(x * 360 + y, 16)),
+ TestCase(flm(x * 192 + y, 128), flm(x * 192 + y, 128)),
TestCase(flm(x + 10, 2), flm(x, 2)),
TestCase(flm(x + y * 10, 2), flm(x, 2)),
- TestCase(flm(x + y * 360, 16), flm(x + y * 8, 16)),
+ TestCase(flm(x + y * 360, 16), flm(x + y * 360, 16)),
+ TestCase(flm(x + y * 192, 128), flm(x + y * 192, 128)),
TestCase(flm(x * (-10), 2), 0),
TestCase(flm(x * (-10) + y, 2), flm(y, 2)),
TestCase(flm(x + (-10), 2), flm(x, 2)),