This is an automated email from the ASF dual-hosted git repository.

spectrometerHBH pushed a commit to branch tir-bench
in repository https://gitbox.apache.org/repos/asf/tvm.git

commit f715671412df16eeb9ed4455597be6e0b67b5a20
Author: Hongyi Jin <[email protected]>
AuthorDate: Thu May 28 14:57:11 2026 -0400

    fix(arith): gate canonical-simplify LT Case 2 on extra scale == +1 (#651)
    
    CanonicalSimplifier::Impl::VisitExpr_(LTNode) Case 2 rewrites a
    "scaled-by-d sum plus a single leftover split" comparison
    
        S + xn < 0  ⇔  S/d + (xn // d) < 0           where d = gcd(scales)
    
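    into one where the leftover xn = (yn % m) // L (L being the split's
    lower factor) gets replaced by floormod(floordiv(yn, d*L), m/(d*L)).
    The Case 1 derivation that justifies dropping the remainder
    xn % d ∈ [0, d) only works when the leftover is added with scale +1.
    With scale = -1 the comparison is S - xn < 0; writing xn = d*q + r
    with 0 ≤ r < d, it is equivalent to S/d < q when r == 0 but to
    S/d ≤ q when r > 0. The rewrite S/d - q < 0 therefore silently
    strengthens the predicate by dropping the boundary case S/d == xn // d
    whenever xn % d != 0.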
    
    This surfaced as a miscompile in TIRx kernels that mask a per-lane
    write by `row > col`, where `row = (lane_id // 4) + 16 * warp_id` and
    `col = 2 * (lane_id % 4)` are independent projections of the same
    lane id. After CSE+inlining, the comparison reached canonical_simplify
    with the divided projection on the right-hand side of `<` (so it
    carries scale = -1 once the comparison is moved into `... < 0` form),
    and Case 2 folded `2*(tx%4) < 16*warp + (tx%32)//4` into a plain
    `0 < warp_id`, masking off every thread in warp 0 that should have
    written `val`. The same path also folded other configurations (e.g.
    `0 < (tx%32) - 8*warp`) all the way to False.
    
    Gate Case 2 on `extra->args[0]->scale == 1`. The shapes it was written
    for (`yn % m` with scale +1 and lower_factor = 1, plus its
    lower_factor > 1 generalization, also at scale +1) still simplify as
    before; both are covered by the existing `test_simplify_le` cases and
    by the new `test_simplify_le_negative_scale_extra` regression test,
    which also pins the previously miscompiled scale = -1 shapes to their
    unsimplified form and re-asserts that the always-true `r = 2` variant
    still folds to True.
---
 src/arith/canonical_simplify.cc                    | 11 +++++-
 .../python/arith/test_arith_canonical_simplify.py  | 44 ++++++++++++++++++++++
 2 files changed, 53 insertions(+), 2 deletions(-)

diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index ac1b89f97a..0001afbdfe 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -1419,10 +1419,17 @@ PrimExpr CanonicalSimplifier::Impl::VisitExpr_(const 
LTNode* op) {
       // Case 1. 0 <= xn < d
       divisible.CopyOnWrite()->DivideBy(gcd);
       return Rewriter::VisitExpr(divisible->Normalize() < make_zero(dtype));
-    } else if (extra->args.size() == 1 &&
+    } else if (extra->args.size() == 1 && extra->args[0]->scale == 1 &&
                extra->args[0]->upper_factor != ConstIntBoundNode::kPosInf &&
                extra->args[0]->upper_factor % (gcd * 
extra->args[0]->lower_factor) == 0) {
-      // Case 2. xn == yn % m, where m % d == 0
+      // Case 2. xn == ((yn % m) // L), scale = +1, m % (d*L) == 0.
+      // S + xn < 0 with S divisible by d  ⇔  S/d + xn // d < 0, because
+      // xn % d ∈ [0, d) lets us drop the remainder via the Case 1 argument,
+      // and xn // d = (yn // (d*L)) % (m/(d*L)).
+      // The scale must be +1: with scale = -1 the equivalence becomes ≤
+      // rather than <, so the rewrite would strengthen the predicate and
+      // silently drop the boundary S/d == xn // d (e.g. row > col where
+      // row and col are independent projections of the same lane id).
       divisible.CopyOnWrite()->DivideBy(gcd);
       const auto split_expr = extra->args[0];
       int64_t lower_factor = gcd * extra->args[0]->lower_factor;
diff --git a/tests/python/arith/test_arith_canonical_simplify.py 
b/tests/python/arith/test_arith_canonical_simplify.py
index ce89db9c99..49f480bcce 100644
--- a/tests/python/arith/test_arith_canonical_simplify.py
+++ b/tests/python/arith/test_arith_canonical_simplify.py
@@ -488,5 +488,49 @@ def test_simplify_le():
     ck.verify(x * 1024 + y < z * 7168, x - z * 7 < 0)
 
 
+def test_simplify_le_negative_scale_extra():
+    """Regression: Case 2 of the LT-with-divisible-coeffs rewrite must not
+    fire when the leftover split term has a negative scale.
+
+    The rewrite ``S + xn < 0  ⇔  S/d + xn // d < 0`` is only sound when
+    the leftover ``xn`` has scale ``+1``. With scale ``-1`` the equivalence
+    becomes ``≤`` rather than ``<`` and the rewrite silently strengthens
+    the predicate. The original bug surfaced as ``row > col`` masks of
+    ``.16x*b`` tcgen05 readbacks collapsing to plain ``warp_id > k``
+    comparisons (lower-triangle writes were silently dropped on the
+    boundary warp).
+    """
+    ck = CanonicalChecker()
+    tx = tvm.tirx.Var("tx", "int32")
+    warp = tvm.tirx.Var("warp", "int32")
+    ck.analyzer.bind(tx, tvm.ir.Range(0, 128))
+    ck.analyzer.bind(warp, tvm.ir.Range(0, 4))
+
+    # Same-source joint projection: the comparison genuinely depends on tx
+    # at warp == 0 (e.g. tx == 4 ⇒ 0 < 1 = True; tx == 1 ⇒ 2 < 0 = False),
+    # so the simplifier must keep both sides.  Pre-fix this folded to
+    # ``0 < warp`` and dropped every True case in warp 0.
+    expr = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4
+    ck.verify(expr, expr)
+
+    # The simpler ``scale = -1`` with ``lower_factor = 1`` shape.  Pre-fix
+    # this folded to ``False`` (drops all warp >= 1 cases where the rhs
+    # actually exceeds 8*warp).
+    expr = warp * 8 < (tx % 32)
+    ck.verify(expr, expr)
+
+    # The corresponding ``scale = +1`` Case 2 path (the rewrite this guards)
+    # must still optimize — verifies we did not over-restrict.
+    x1 = tvm.tirx.Var("x1", "int32")
+    y1 = tvm.tirx.Var("y1", "int32")
+    ck.verify(x1 * 64 + (y1 % 64) < 120, x1 * 8 + (y1 % 64) // 8 < 15)
+
+    # The truly-always-true comparison that arises from the same kernel
+    # (``r = 2 / va = 1`` in the tcgen05.ld.16x256b readback) must still
+    # fold to True so the masked store can be elided.
+    expr_true = (tx % 4) * 2 < warp * 16 + (tx % 32) // 4 + 8
+    ck.verify(expr_true, tvm.tirx.const(True, "bool"))
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to