This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new e4c5b7ca60 [BUGFIX][TIR] Skip bool-typed expressions in CSE (#19502)
e4c5b7ca60 is described below
commit e4c5b7ca6057aa834b8f5f428b9f2b9437ab4105
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun May 3 21:11:30 2026 -0400
[BUGFIX][TIR] Skip bool-typed expressions in CSE (#19502)
## Summary
The TIR CSE pass currently lifts bool-typed sub-expressions like `i < n`
or `a && b` into `cse_v: bool = ...` bindings whenever they appear more
than once. Boolean expressions are almost always predicates feeding `if` /
`Select` / `assert`, where reading the condition inline is clearer than
going through a boolean temporary, and where downstream simplification
(ProveCondition, branch elimination) benefits from seeing the predicate
directly.
- Extend `CSEPlanner::IsEligible` in
`src/tirx/transform/common_subexpr_elim.cc` to reject any compound
expression whose result dtype is `bool`.
- Update the file-level `Eligibility rules` doc-comment and the
per-function `IsEligible` docstring to document the new rule.
- Add two regression tests (`test_no_lift_bool_predicate`,
`test_no_lift_bool_logical`) covering comparison predicates and
logical-And predicates respectively.
---
src/tirx/transform/common_subexpr_elim.cc | 14 +++++++
.../test_tir_transform_common_subexpr_elim.py | 43 ++++++++++++++++++++++
2 files changed, 57 insertions(+)
diff --git a/src/tirx/transform/common_subexpr_elim.cc b/src/tirx/transform/common_subexpr_elim.cc
index 38925dc25a..9e7b2b1fb7 100644
--- a/src/tirx/transform/common_subexpr_elim.cc
+++ b/src/tirx/transform/common_subexpr_elim.cc
@@ -49,6 +49,10 @@
* - It is not a leaf (Var, IntImm, FloatImm, StringImm).
* - It does not contain Call or BufferLoad (side-effects / memory dependence).
* - It is not Ramp or Broadcast (hardware-specific vector ops).
+ * - It is not bool-typed. Boolean predicates are kept inline because the
+ * consumer (if / Select / assert) reads more clearly with the condition
+ * spelled out, and downstream simplification benefits from seeing the
+ * predicate directly.
*
* Scope tree
* ----------
@@ -263,6 +267,8 @@ class CSEPlanner : public StmtExprVisitor {
* - Not a Call or BufferLoad (side effects / memory dependence).
* - Not Ramp or Broadcast (hardware-specific vector construction).
* - Does not transitively contain any forbidden node.
+ * - Is not bool-typed (predicates are kept inline for readability and
+ * downstream simplification).
*
* \param expr The expression to check.
* \return true if the expression can participate in CSE.
@@ -274,6 +280,14 @@ class CSEPlanner : public StmtExprVisitor {
}
if (IsForbiddenNode(expr)) return false;
if (expr.as<RampNode>() || expr.as<BroadcastNode>()) return false;
+ // Reject bool-typed expressions. Boolean predicates almost always feed an
+ // if / Select / assert, where reading the condition inline is clearer than
+ // going through a `cse_v: bool = (a < b)` temporary, and where downstream
+ // simplification (ProveCondition, branch elimination) benefits from seeing
+ // the predicate directly. BoolImm is already filtered above as an IntImm
+ // leaf, so this rule only affects compound bool expressions
+ // (LT/LE/GT/GE/EQ/NE/And/Or/Not/Cast-to-bool/Select-of-bool).
+ if (expr.dtype().is_bool()) return false;
if (CheckContains::ExprContains(expr, IsForbiddenNode)) return false;
return true;
}
diff --git a/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py
index 8786720a25..e025ae88a9 100644
--- a/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py
+++ b/tests/python/tirx-transform/test_tir_transform_common_subexpr_elim.py
@@ -713,6 +713,47 @@ def test_let_floordiv_pattern():
assert "cse_v" not in script, f"CSE incorrectly extracted from Let body:\n{script}"
+# =====================================================================
+# T22: No lifting of bool predicate (comparison expression)
+# A duplicated `i < n` feeds two if-statements. CSE must leave it
+# inline rather than hoisting a `cse_v: bool = (i < n)` binding.
+# =====================================================================
+def test_no_lift_bool_predicate():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(B: T.Buffer((50,), "int32"), n: T.int32, x: T.int32):
+ for i in range(50):
+ if i < n:
+ B[i] = x
+ if i < n:
+ B[i] = x + 1
+
+ after = tvm.tirx.transform.CommonSubexprElim()(Before)
+ tvm.ir.assert_structural_equal(after, Before)
+ assert "cse_v" not in after["main"].script()
+
+
+# =====================================================================
+# T23: No lifting of bool logical expression (And)
+# A duplicated `a && b` feeds two if-statements. CSE must leave it
+# inline rather than hoisting a `cse_v: bool = T.And(a, b)` binding.
+# =====================================================================
+def test_no_lift_bool_logical():
+ @tvm.script.ir_module
+ class Before:
+ @T.prim_func
+ def main(B: T.Buffer((50,), "int32"), a: T.bool, b: T.bool, x: T.int32):
+ if T.And(a, b):
+ B[0] = x
+ if T.And(a, b):
+ B[1] = x + 1
+
+ after = tvm.tirx.transform.CommonSubexprElim()(Before)
+ tvm.ir.assert_structural_equal(after, Before)
+ assert "cse_v" not in after["main"].script()
+
+
if __name__ == "__main__":
test_basic()
test_if_single_branch()
@@ -735,3 +776,5 @@ if __name__ == "__main__":
test_let_value_cse()
test_nested_let_no_extraction()
test_let_floordiv_pattern()
+ test_no_lift_bool_predicate()
+ test_no_lift_bool_logical()