This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 7bd73e5ad2 [Arith] Restrict floormod coefficient reduction to keep 
DetectIterMap stable (#19832)
7bd73e5ad2 is described below

commit 7bd73e5ad22e359aff273165eddc47ff9c5373ee
Author: Shushi Hong <[email protected]>
AuthorDate: Thu Jun 18 17:26:48 2026 -0400

    [Arith] Restrict floormod coefficient reduction to keep DetectIterMap stable 
(#19832)
    
    This PR fixes #19825 by restricting the rewrites
    ```
    floormod(x * c1 + y, c2) -> floormod(x * floormod(c1, c2) + y, c2)
    ```
    and
    
    ```
    floormod(x + y * c1, c2) -> floormod(x + y * floormod(c1, c2), c2)
    ```
    While algebraically valid in isolation, these transformations rewrite
    only the `floormod` side of a matching `floordiv`/`floormod` pair. As a
    result, the two expressions no longer share a visible fused index
    expression, causing `DetectIterMap` to reject otherwise bijective splits
    such as:
    
    ```
    lane = flat % 128
    reg  = flat // 128
    ```
    where both expressions originate from the same fused index.
    
    ### Context
    
    Per the suggestion in #19825, the two rewrites are guarded with `c1 % c2
    == 0` rather than dropped outright. The multiplied term is still
    eliminated when it is a multiple of the divisor (e.g. `(x*10 + y) % 2 ->
    y % 2`), which is safe for `DetectIterMap`; only the
    coefficient-shrinking case (`c1` not a multiple of `c2`) is disabled.
    Both operand orderings are covered, and the PR adds a rewrite-simplify
    regression plus an end-to-end `DetectIterMap` regression test.
---
 src/arith/rewrite_simplify.cc                     |  4 ++--
 src/relax/op/image/resize.cc                      |  7 +++----
 tests/python/arith/test_arith_iter_affine_map.py  | 15 +++++++++++++++
 tests/python/arith/test_arith_rewrite_simplify.py |  7 +++++--
 4 files changed, 25 insertions(+), 8 deletions(-)

diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index b5b0cc604e..6d6ce03016 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -1249,7 +1249,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
FloorModNode* op) {
                            CanProveEqual(floordiv(y.Eval(), c1.Eval()), 0));
 
     TVM_TRY_REWRITE_IF(floormod(x * c1 + y, c2), floormod(x * floormod(c1, c2) 
+ y, c2),
-                       c2.Eval()->value > 0);
+                       c2.Eval()->value > 0 && c1.Eval()->value % 
c2.Eval()->value == 0);
 
     // (x + 5) % 2 -> (x + 1) %2,  (x + 3) % 3 => x
     TVM_TRY_REWRITE_IF(
@@ -1257,7 +1257,7 @@ PrimExpr RewriteSimplifier::Impl::VisitExpr_(const 
FloorModNode* op) {
         c2.Eval()->value > 0 && (c1.Eval()->value >= c2.Eval()->value || 
c1.Eval()->value < 0));
 
     TVM_TRY_REWRITE_IF(floormod(x + y * c1, c2), floormod(x + y * floormod(c1, 
c2), c2),
-                       c2.Eval()->value > 0);
+                       c2.Eval()->value > 0 && c1.Eval()->value % 
c2.Eval()->value == 0);
 
     TVM_TRY_REWRITE_IF(floormod(x * c1, x * c2), x * floormod(c1, c2), 
c2.Eval()->value != 0);
 
diff --git a/src/relax/op/image/resize.cc b/src/relax/op/image/resize.cc
index 653ea04c63..eab2f68499 100644
--- a/src/relax/op/image/resize.cc
+++ b/src/relax/op/image/resize.cc
@@ -311,10 +311,9 @@ StructInfo InferStructInfoGridSample(const Call& call, 
const BlockBuilder& ctx)
   // treated as the 2D NCHW path so existing behavior is preserved.
   const bool is_ncdhw = (attrs->layout == "NCDHW");
 
-  auto [data_layout, data2tgt] =
-      CheckTensorLayout(call, ctx, attrs->layout,
-                        /*tgt_layout=*/is_ncdhw ? "NCDHW" : "NCHW",
-                        /*tensor_name=*/"data");
+  auto [data_layout, data2tgt] = CheckTensorLayout(call, ctx, attrs->layout,
+                                                   /*tgt_layout=*/is_ncdhw ? 
"NCDHW" : "NCHW",
+                                                   /*tensor_name=*/"data");
 
   DataType out_dtype = data_sinfo->dtype;
 
diff --git a/tests/python/arith/test_arith_iter_affine_map.py 
b/tests/python/arith/test_arith_iter_affine_map.py
index fdbb65a0bd..ef1c9f132c 100644
--- a/tests/python/arith/test_arith_iter_affine_map.py
+++ b/tests/python/arith/test_arith_iter_affine_map.py
@@ -201,6 +201,21 @@ def test_split():
     )
 
 
+def test_split_simplified_modulo():
+    # regression for #19825: simplifying the modulo must not break the fused 
split
+    i = tvm.tirx.Var("i", "int32")
+    j = tvm.tirx.Var("j", "int32")
+    dom = var_dom([(i, 64), (j, 192)])
+    analyzer = tvm.arith.Analyzer()
+
+    for flat in [i * 192 + j, j + i * 192]:
+        lane = analyzer.simplify(floormod(flat, 128))
+        quotient = floordiv(flat, 128)
+        res = tvm.arith.detect_iter_map([lane, quotient], dom, 
check_level="bijective")
+        assert len(res.errors) == 0, res.errors
+        assert len(res.indices) == 2
+
+
 def test_compound():
     x = tvm.tirx.Var("x", "int32")
     y = tvm.tirx.Var("y", "int32")
diff --git a/tests/python/arith/test_arith_rewrite_simplify.py 
b/tests/python/arith/test_arith_rewrite_simplify.py
index e0ef9da822..a6887db8be 100644
--- a/tests/python/arith/test_arith_rewrite_simplify.py
+++ b/tests/python/arith/test_arith_rewrite_simplify.py
@@ -675,10 +675,13 @@ class TestFloormodIndex(BaseCompare):
         TestCase(flm(x * 10, 2), 0),
         TestCase(flm(x * 9600, 6400), flm(x * 3200, 6400)),
         TestCase(flm(x * 10 + y, 2), flm(y, 2)),
-        TestCase(flm(x * 360 + y, 16), flm(x * 8 + y, 16)),
+        # coefficient is not shrunk unless it is a multiple of the divisor 
(#19825)
+        TestCase(flm(x * 360 + y, 16), flm(x * 360 + y, 16)),
+        TestCase(flm(x * 192 + y, 128), flm(x * 192 + y, 128)),
         TestCase(flm(x + 10, 2), flm(x, 2)),
         TestCase(flm(x + y * 10, 2), flm(x, 2)),
-        TestCase(flm(x + y * 360, 16), flm(x + y * 8, 16)),
+        TestCase(flm(x + y * 360, 16), flm(x + y * 360, 16)),
+        TestCase(flm(x + y * 192, 128), flm(x + y * 192, 128)),
         TestCase(flm(x * (-10), 2), 0),
         TestCase(flm(x * (-10) + y, 2), flm(y, 2)),
         TestCase(flm(x + (-10), 2), flm(x, 2)),

Reply via email to