This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 2b87313c98 [ARITH] Expose allow_override parameter in Python Analyzer.bind() (#19417)
2b87313c98 is described below
commit 2b87313c981e9a9c5695a96cc56e094e311657b9
Author: Fabian Peddinghaus <[email protected]>
AuthorDate: Fri Apr 24 02:31:13 2026 +0200
[ARITH] Expose allow_override parameter in Python Analyzer.bind() (#19417)
The C++ Analyzer::Bind() already supports allow_override, but the FFI
bridge always used the default (false). This change threads the optional
argument through the FFI layer and the Python wrapper so callers can
rebind variables without triggering an error.
---
python/tvm/arith/analyzer.py | 12 ++++++++++--
src/arith/analyzer.cc | 5 +++--
tests/python/arith/test_arith_simplify.py | 14 ++++++++++++++
3 files changed, 27 insertions(+), 4 deletions(-)
diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index fc5d3c9aea..ea70c4de3d 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -266,7 +266,12 @@ class Analyzer:
"""
return self._can_prove(expr, strength)
- def bind(self, var: tirx.Var, expr: tirx.PrimExpr | ir.Range) -> None:
+ def bind(
+ self,
+ var: tirx.Var,
+ expr: tirx.PrimExpr | ir.Range,
+ allow_override: bool = False,
+ ) -> None:
"""Bind a variable to the expression.
Parameters
@@ -276,8 +281,11 @@ class Analyzer:
expr : Union[tirx.PrimExpr, ir.Range]
The expression or the range to bind to.
+
+ allow_override : bool
+ Whether to allow overriding an existing binding for the variable.
"""
- return self._bind(var, expr)
+ return self._bind(var, expr, allow_override)
def constraint_scope(self, constraint: tirx.PrimExpr) -> ConstraintScope:
"""Create a constraint scope.
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index f823e9efca..7f3734266b 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -326,10 +326,11 @@ TVM_FFI_STATIC_INIT_BLOCK() {
});
} else if (name == "bind") {
return ffi::Function([self](ffi::PackedArgs args, ffi::Any* ret) {
+ bool allow_override = args.size() >= 3 && args[2].cast<bool>();
if (auto opt_range = args[1].try_cast<Range>()) {
- self->Bind(args[0].cast<Var>(), opt_range.value());
+ self->Bind(args[0].cast<Var>(), opt_range.value(), allow_override);
} else {
- self->Bind(args[0].cast<Var>(), args[1].cast<PrimExpr>());
+ self->Bind(args[0].cast<Var>(), args[1].cast<PrimExpr>(),
allow_override);
}
});
} else if (name == "can_prove") {
diff --git a/tests/python/arith/test_arith_simplify.py b/tests/python/arith/test_arith_simplify.py
index b367735c1f..d30109fc44 100644
--- a/tests/python/arith/test_arith_simplify.py
+++ b/tests/python/arith/test_arith_simplify.py
@@ -134,6 +134,20 @@ def test_regression_simplify_inf_recursion():
ana.rewrite_simplify(res)
+def test_bind_allow_override():
+ ana = tvm.arith.Analyzer()
+ x = tirx.Var("x", "int64")
+
+ ana.bind(x, tvm.ir.Range(0, 10))
+ ana.bind(x, tvm.ir.Range(0, 5), allow_override=True)
+ assert ana.can_prove(x < 5)
+
+ with pytest.raises(
+ tvm.error.TVMError, match="Trying to update var 'x' with a different
const bound"
+ ):
+ ana.bind(x, tvm.ir.Range(0, 3))
+
+
def test_simplify_floor_mod_with_linear_offset():
"""
Test that the floor_mod is simplified correctly when the offset is linear.