This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new e4c5b7ca60 [BUGFIX][TIR] Skip bool-typed expressions in CSE (#19502)
e4c5b7ca60 is described below

commit e4c5b7ca6057aa834b8f5f428b9f2b9437ab4105
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun May 3 21:11:30 2026 -0400

    [BUGFIX][TIR] Skip bool-typed expressions in CSE (#19502)
    
    ## Summary
    
    The TIR CSE pass currently lifts bool-typed sub-expressions like `i < n`
    or `a && b` into `cse_v: bool = ...` bindings whenever they appear
    twice. Boolean expressions are almost always predicates feeding `if` /
    `Select` / `assert`, where reading the condition inline is clearer than
    going through a boolean temporary, and where downstream simplification
    (ProveCondition, branch elimination) benefits from seeing the predicate
    directly.
    
    - Extend `CSEPlanner::IsEligible` in
    `src/tirx/transform/common_subexpr_elim.cc` to reject any compound
    expression whose result dtype is `bool`.
    - Update the file-level `Eligibility rules` doc-comment and the
    per-function `IsEligible` docstring to document the new rule.
    - Add two regression tests (`test_no_lift_bool_predicate`,
    `test_no_lift_bool_logical`) covering comparison predicates and
    logical-And predicates respectively.
---
 src/tirx/transform/common_subexpr_elim.cc          | 14 +++++++
 .../test_tir_transform_common_subexpr_elim.py      | 43 ++++++++++++++++++++++
 2 files changed, 57 insertions(+)

diff --git a/src/tirx/transform/common_subexpr_elim.cc b/src/tirx/transform/common_subexpr_elim.cc
index 38925dc25a..9e7b2b1fb7 100644
--- a/src/tirx/transform/common_subexpr_elim.cc
+++ b/src/tirx/transform/common_subexpr_elim.cc
@@ -49,6 +49,10 @@
  *   - It is not a leaf (Var, IntImm, FloatImm, StringImm).
  *   - It does not contain Call or BufferLoad (side-effects / memory dependence).
  *   - It is not Ramp or Broadcast (hardware-specific vector ops).
+ *   - It is not bool-typed. Boolean predicates are kept inline because the
+ *     consumer (if / Select / assert) reads more clearly with the condition
+ *     spelled out, and downstream simplification benefits from seeing the
+ *     predicate directly.
  *
  * Scope tree
  * ----------
@@ -263,6 +267,8 @@ class CSEPlanner : public StmtExprVisitor {
    *   - Not a Call or BufferLoad (side effects / memory dependence).
    *   - Not Ramp or Broadcast (hardware-specific vector construction).
    *   - Does not transitively contain any forbidden node.
+   *   - Is not bool-typed (predicates are kept inline for readability and
+   *     downstream simplification).
    *
    * \param expr The expression to check.
    * \return true if the expression can participate in CSE.
@@ -274,6 +280,14 @@ class CSEPlanner : public StmtExprVisitor {
     }
     if (IsForbiddenNode(expr)) return false;
     if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false;
+    // Reject bool-typed expressions. Boolean predicates almost always feed an
+    // if / Select / assert, where reading the condition inline is clearer than
+    // going through a `cse_v: bool = (a < b)` temporary, and where downstream
+    // simplification (ProveCondition, branch elimination) benefits from seeing
+    // the predicate directly. BoolImm is already filtered above as an IntImm
+    // leaf, so this rule only affects compound bool expressions
+    // (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool).
+    if (expr.dtype().is_bool()) return false;
     if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false;
     return true;
   }
diff --git a/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py
index 8786720a25..e025ae88a9 100644
--- a/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py
@@ -713,6 +713,47 @@ def test_let_floordiv_pattern():
     assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}"
 
 
+# =====================================================================
+# T22: No lifting of bool predicate (comparison expression)
+# A duplicated `i < n` feeds two if-statements.  CSE must leave it
+# inline rather than hoisting a `cse_v: bool = (i < n)` binding.
+# =====================================================================
+def test_no_lift_bool_predicate():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(B: T.Buffer((50,), "int32"), n: T.int32, x: T.int32):
+            for i in range(50):
+                if i < n:
+                    B[i] = x
+                if i < n:
+                    B[i] = x + 1
+
+    after = tvm.tirx.transform.CommonSubexprElim()(Before)
+    tvm.ir.assert_structural_equal(after, Before)
+    assert "cse_v" not in after["main"].script()
+
+
+# =====================================================================
+# T23: No lifting of bool logical expression (And)
+# A duplicated `a && b` feeds two if-statements.  CSE must leave it
+# inline rather than hoisting a `cse_v: bool = T.And(a, b)` binding.
+# =====================================================================
+def test_no_lift_bool_logical():
+    @tvm.script.ir_module
+    class Before:
+        @T.prim_func
+        def main(B: T.Buffer((50,), "int32"), a: T.bool, b: T.bool, x: T.int32):
+            if T.And(a, b):
+                B[0] = x
+            if T.And(a, b):
+                B[1] = x + 1
+
+    after = tvm.tirx.transform.CommonSubexprElim()(Before)
+    tvm.ir.assert_structural_equal(after, Before)
+    assert "cse_v" not in after["main"].script()
+
+
 if __name__ == "__main__":
     test_basic()
     test_if_single_branch()
@@ -735,3 +776,5 @@ if __name__ == "__main__":
     test_let_value_cse()
     test_nested_let_no_extraction()
     test_let_floordiv_pattern()
+    test_no_lift_bool_predicate()
+    test_no_lift_bool_logical()

Reply via email to