This is an automated email from the ASF dual-hosted git repository.

tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 9f907f5cd0 [Arith] Add Analyzer::Clone for deep-copying analyzer state (#19836)
9f907f5cd0 is described below

commit 9f907f5cd099891031a6e8eecb0bb22b963d1dc2
Author: Shushi Hong <[email protected]>
AuthorDate: Fri Jun 19 07:16:19 2026 -0400

    [Arith] Add Analyzer::Clone for deep-copying analyzer state (#19836)
    
    Copying an Analyzer handle shares the same mutable AnalyzerObj, so a
    pass had no way to snapshot accumulated facts (variable bounds, modular
    sets, rewrite/canonical bindings, integer-set domains, literal
    constraints, transitive comparisons) and keep exploring without mutating
    the original.
    
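    This PR adds AnalyzerObj::Clone(), which allocates a fresh AnalyzerObj
    and copies each sub-analyzer's persistent state through a new
    per-sub-analyzer CopyFrom. Parent back-pointers are re-established by
    the fresh constructor rather than copied, and per-query/recursion
    scratch state is left at its default; rewrite-simplifier statistics are
    not copied either, so the clone starts with fresh counters. Clone() must
    not be called inside an active ConstraintContext scope, because the clone
    would inherit the scoped constraints without the recovery functions that
    remove them. The method is exposed to Python as Analyzer.clone().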
---
 include/tvm/arith/analyzer.h                     | 26 ++++++++
 python/tvm/arith/analyzer.py                     | 18 ++++++
 src/arith/analyzer.cc                            | 12 ++++
 src/arith/canonical_simplify.cc                  |  4 ++
 src/arith/const_int_bound.cc                     |  9 +++
 src/arith/int_set.cc                             |  7 +++
 src/arith/modular_set.cc                         |  6 ++
 src/arith/rewrite_simplify.cc                    |  2 +
 src/arith/rewrite_simplify.h                     |  7 +++
 src/arith/transitive_comparison_analyzer.cc      | 11 ++++
 tests/cpp/arith_simplify_test.cc                 | 21 +++++++
 tests/python/arith/test_arith_analyzer_object.py | 79 ++++++++++++++++++++++++
 12 files changed, 202 insertions(+)

diff --git a/include/tvm/arith/analyzer.h b/include/tvm/arith/analyzer.h
index e635315e67..9aca5c1189 100644
--- a/include/tvm/arith/analyzer.h
+++ b/include/tvm/arith/analyzer.h
@@ -181,6 +181,7 @@ class ConstIntBoundAnalyzer {
   friend class ConstraintContext;
   explicit ConstIntBoundAnalyzer(AnalyzerObj* parent);
   TVM_DLL ~ConstIntBoundAnalyzer();
+  void CopyFrom(const ConstIntBoundAnalyzer& other);
   /*!
    * \brief Update the internal state to enter constraint.
    * \param constraint A constraint expression.
@@ -260,6 +261,7 @@ class ModularSetAnalyzer {
   friend class ConstraintContext;
   explicit ModularSetAnalyzer(AnalyzerObj* parent);
   TVM_DLL ~ModularSetAnalyzer();
+  void CopyFrom(const ModularSetAnalyzer& other);
   /*!
    * \brief Update the internal state to enter constraint.
    * \param constraint A constraint expression.
@@ -414,6 +416,7 @@ class RewriteSimplifier {
   friend class CanonicalSimplifier;
   explicit RewriteSimplifier(AnalyzerObj* parent);
   TVM_DLL ~RewriteSimplifier();
+  void CopyFrom(const RewriteSimplifier& other);
   class Impl;
   /*! \brief Internal impl */
   Impl* impl_;
@@ -445,6 +448,7 @@ class CanonicalSimplifier {
   friend class ConstraintContext;
   explicit CanonicalSimplifier(AnalyzerObj* parent);
   TVM_DLL ~CanonicalSimplifier();
+  void CopyFrom(const CanonicalSimplifier& other);
   class Impl;
   /*! \brief Internal impl */
   Impl* impl_;
@@ -530,6 +534,7 @@ class TransitiveComparisonAnalyzer {
   friend class ConstraintContext;
   TransitiveComparisonAnalyzer();
   TVM_DLL ~TransitiveComparisonAnalyzer();
+  void CopyFrom(const TransitiveComparisonAnalyzer& other);
   class Impl;
   /*! \brief Internal impl */
   std::unique_ptr<Impl> impl_;
@@ -584,6 +589,7 @@ class IntSetAnalyzer {
   friend class AnalyzerObj;
   explicit IntSetAnalyzer(AnalyzerObj* parent);
   TVM_DLL ~IntSetAnalyzer();
+  void CopyFrom(const IntSetAnalyzer& other);
   class Impl;
   /*! \brief Internal impl */
   Impl* impl_;
@@ -854,6 +860,26 @@ class TVM_DLL AnalyzerObj : public ffi::Object {
    */
   PrimExpr Simplify(const PrimExpr& expr, int steps = 2);
 
+  /*!
+   * \brief Deep-copy this analyzer into a new, independent Analyzer.
+   *
+   * The returned analyzer carries the same accumulated facts (variable
+   * bounds, modular sets, rewrite/canonical bindings, integer-set domains,
+   * literal constraints and transitive comparisons) as this one, but owns
+   * its own state: binding or simplifying on either analyzer afterwards does
+   * not affect the other. This is the deep copy that handle-copying an
+   * Analyzer does not provide.
+   *
+   * \note Do not call this while a `With<ConstraintContext>` scope is active
+   *       on this analyzer. The clone would inherit the scoped constraints
+   *       but not the recovery functions that pop them on scope exit, so the
+   *       constraints would leak as if they were global facts. Clone at a
+   *       point where no constraint scope is in effect.
+   *
+   * \return A new Analyzer holding an independent copy of the facts.
+   */
+  Analyzer Clone() const;
+
   /*!
    * \brief Analyzer methods update facts, constraints, caches, and stats.
    *
diff --git a/python/tvm/arith/analyzer.py b/python/tvm/arith/analyzer.py
index 78e93395c3..d82cae0129 100644
--- a/python/tvm/arith/analyzer.py
+++ b/python/tvm/arith/analyzer.py
@@ -278,6 +278,24 @@ class Analyzer(Object):
         """
         return _ffi_api.AnalyzerSimplify(self, expr, steps)
 
+    def clone(self) -> "Analyzer":
+        """Return a deep copy of this analyzer with independent state.
+
+        The returned analyzer carries the same accumulated facts (variable
+        bounds, modular sets, bindings, integer-set domains, literal
+        constraints and transitive comparisons) as this one, but owns its own
+        state: binding or simplifying on either analyzer afterwards does not
+        affect the other. Unlike copying the handle, this is a true deep copy.
+
+        Do not call this while a constraint scope is active on this analyzer.
+
+        Returns
+        -------
+        result : Analyzer
+            A new analyzer holding an independent copy of the facts.
+        """
+        return _ffi_api.AnalyzerClone(self)
+
     def rewrite_simplify(self, expr: tirx.PrimExpr) -> tirx.PrimExpr:
         """Simplify expression via rewriting rules.
 
diff --git a/src/arith/analyzer.cc b/src/arith/analyzer.cc
index b66ecb0fd1..fc59f891e1 100644
--- a/src/arith/analyzer.cc
+++ b/src/arith/analyzer.cc
@@ -265,11 +265,23 @@ PrimExpr AnalyzerObj::Simplify(const PrimExpr& expr, int 
steps) {
   return res;
 }
 
+Analyzer AnalyzerObj::Clone() const {
+  Analyzer cloned;
+  cloned->const_int_bound.CopyFrom(this->const_int_bound);
+  cloned->modular_set.CopyFrom(this->modular_set);
+  cloned->rewrite_simplify.CopyFrom(this->rewrite_simplify);
+  cloned->canonical_simplify.CopyFrom(this->canonical_simplify);
+  cloned->int_set.CopyFrom(this->int_set);
+  cloned->transitive_comparisons.CopyFrom(this->transitive_comparisons);
+  return cloned;
+}
+
 TVM_FFI_STATIC_INIT_BLOCK() {
   namespace refl = tvm::ffi::reflection;
   refl::ObjectDef<AnalyzerObj>();
   refl::GlobalDef()
       .def("arith.Analyzer", []() { return Analyzer(); })
+      .def("arith.AnalyzerClone", [](Analyzer analyzer) { return 
analyzer->Clone(); })
       .def("arith.AnalyzerConstIntBound",
            [](Analyzer analyzer, const PrimExpr& expr) { return 
analyzer->const_int_bound(expr); })
       .def("arith.AnalyzerConstIntBoundUpdate",
diff --git a/src/arith/canonical_simplify.cc b/src/arith/canonical_simplify.cc
index f1dd1a63c5..7806c23445 100644
--- a/src/arith/canonical_simplify.cc
+++ b/src/arith/canonical_simplify.cc
@@ -1454,5 +1454,9 @@ CanonicalSimplifier::CanonicalSimplifier(AnalyzerObj* 
parent) : impl_(new Impl(p
 
 CanonicalSimplifier::~CanonicalSimplifier() { delete impl_; }
 
+void CanonicalSimplifier::CopyFrom(const CanonicalSimplifier& other) {
+  impl_->CopyFrom(*other.impl_);
+}
+
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/arith/const_int_bound.cc b/src/arith/const_int_bound.cc
index 8ff1a8b17e..4d700564ea 100644
--- a/src/arith/const_int_bound.cc
+++ b/src/arith/const_int_bound.cc
@@ -498,6 +498,11 @@ class ConstIntBoundAnalyzer::Impl
     return frecover;
   }
 
+  void CopyFrom(const Impl& other) {
+    var_map_ = other.var_map_;
+    additional_info_ = other.additional_info_;
+  }
+
  private:
   friend class ConstIntBoundAnalyzer;
   // parent analyzer
@@ -859,5 +864,9 @@ ConstIntBoundAnalyzer::ConstIntBoundAnalyzer(AnalyzerObj* 
parent) : impl_(new Im
 
 ConstIntBoundAnalyzer::~ConstIntBoundAnalyzer() { delete impl_; }
 
+void ConstIntBoundAnalyzer::CopyFrom(const ConstIntBoundAnalyzer& other) {
+  impl_->CopyFrom(*other.impl_);
+}
+
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/arith/int_set.cc b/src/arith/int_set.cc
index a1e01d3e86..b68042e2af 100644
--- a/src/arith/int_set.cc
+++ b/src/arith/int_set.cc
@@ -658,6 +658,11 @@ class IntSetAnalyzer::Impl {
   void Bind(const Var& var, const PrimExpr& expr, bool override_info);
   std::function<void()> EnterConstraint(const PrimExpr& constraint);
 
+  void CopyFrom(const Impl& other) {
+    dom_map_ = other.dom_map_;
+    dom_constraints_ = other.dom_constraints_;
+  }
+
  private:
   // Utility function to split a boolean condition into the domain
   // bounds implied by that condition.
@@ -681,6 +686,8 @@ IntSetAnalyzer::IntSetAnalyzer(AnalyzerObj* parent) : 
impl_(new Impl(parent)) {}
 
 IntSetAnalyzer::~IntSetAnalyzer() { delete impl_; }
 
+void IntSetAnalyzer::CopyFrom(const IntSetAnalyzer& other) { 
impl_->CopyFrom(*other.impl_); }
+
 IntSet IntSetAnalyzer::operator()(const PrimExpr& expr, const ffi::Map<Var, 
IntSet>& dom_map) {
   return impl_->Eval(expr, dom_map);
 }
diff --git a/src/arith/modular_set.cc b/src/arith/modular_set.cc
index 5f66356e1a..856f5df0b7 100644
--- a/src/arith/modular_set.cc
+++ b/src/arith/modular_set.cc
@@ -310,6 +310,8 @@ class ModularSetAnalyzer::Impl : public 
ExprFunctor<ModularSetAnalyzer::Entry(co
     return Everything();
   }
 
+  void CopyFrom(const Impl& other) { var_map_ = other.var_map_; }
+
  private:
   /*! \brief pointer to parent. */
   AnalyzerObj* parent_{nullptr};
@@ -407,5 +409,9 @@ ModularSetAnalyzer::ModularSetAnalyzer(AnalyzerObj* parent) 
: impl_(new Impl(par
 
 ModularSetAnalyzer::~ModularSetAnalyzer() { delete impl_; }
 
+void ModularSetAnalyzer::CopyFrom(const ModularSetAnalyzer& other) {
+  impl_->CopyFrom(*other.impl_);
+}
+
 }  // namespace arith
 }  // namespace tvm
diff --git a/src/arith/rewrite_simplify.cc b/src/arith/rewrite_simplify.cc
index 6d6ce03016..a1bbce1072 100644
--- a/src/arith/rewrite_simplify.cc
+++ b/src/arith/rewrite_simplify.cc
@@ -2466,6 +2466,8 @@ RewriteSimplifier::RewriteSimplifier(AnalyzerObj* parent) 
: impl_(new Impl(paren
 
 RewriteSimplifier::~RewriteSimplifier() { delete impl_; }
 
+void RewriteSimplifier::CopyFrom(const RewriteSimplifier& other) { 
impl_->CopyFrom(*other.impl_); }
+
 // Pattern A (RM): auto-default repr from reflection.
 
 }  // namespace arith
diff --git a/src/arith/rewrite_simplify.h b/src/arith/rewrite_simplify.h
index b42b73336a..719aa5ec07 100644
--- a/src/arith/rewrite_simplify.h
+++ b/src/arith/rewrite_simplify.h
@@ -135,6 +135,13 @@ class RewriteSimplifier::Impl : public 
IRMutatorWithAnalyzer {
 
   void SetMaximumRewriteSteps(int64_t maximum) { maximum_rewrite_steps_ = 
maximum; }
 
+  void CopyFrom(const Impl& other) {
+    var_map_ = other.var_map_;
+    literal_constraints_ = other.literal_constraints_;
+    enabled_extensions_ = other.enabled_extensions_;
+    maximum_rewrite_steps_ = other.maximum_rewrite_steps_;
+  }
+
  protected:
   int64_t maximum_rewrite_steps_{0};
   RewriteSimplifierStatsNode stats_;
diff --git a/src/arith/transitive_comparison_analyzer.cc 
b/src/arith/transitive_comparison_analyzer.cc
index e7deea4cfd..20fd05169f 100644
--- a/src/arith/transitive_comparison_analyzer.cc
+++ b/src/arith/transitive_comparison_analyzer.cc
@@ -82,6 +82,13 @@ class TransitiveComparisonAnalyzer::Impl {
    */
   std::function<void()> EnterConstraint(const PrimExpr& expr);
 
+  void CopyFrom(const Impl& other) {
+    expr_to_key = other.expr_to_key;
+    prev_bindings_ = other.prev_bindings_;
+    knowns_ = other.knowns_;
+    scoped_knowns_ = other.scoped_knowns_;
+  }
+
  private:
   /* \brief Internal representation of a PrimExpr
    *
@@ -528,6 +535,10 @@ bool 
TransitiveComparisonAnalyzer::Impl::Comparison::Implies(
 TransitiveComparisonAnalyzer::TransitiveComparisonAnalyzer() : 
impl_(std::make_unique<Impl>()) {}
 TransitiveComparisonAnalyzer::~TransitiveComparisonAnalyzer() {}
 
+void TransitiveComparisonAnalyzer::CopyFrom(const 
TransitiveComparisonAnalyzer& other) {
+  impl_->CopyFrom(*other.impl_);
+}
+
 CompareResult TransitiveComparisonAnalyzer::TryCompare(const PrimExpr& lhs, 
const PrimExpr& rhs,
                                                        bool 
propagate_inequalities) {
   return impl_->TryCompare(lhs, rhs, propagate_inequalities);
diff --git a/tests/cpp/arith_simplify_test.cc b/tests/cpp/arith_simplify_test.cc
index ba5305e9dd..d5050446d6 100644
--- a/tests/cpp/arith_simplify_test.cc
+++ b/tests/cpp/arith_simplify_test.cc
@@ -75,6 +75,27 @@ TEST(AnalyzerObjectRef, 
ConstHandleRefCanMutateAnalyzerState) {
   TVM_FFI_ICHECK(analyzer->CanProve(x < 8));
 }
 
+TEST(AnalyzerObjectRef, CloneIsIndependent) {
+  tvm::arith::Analyzer analyzer;
+  auto x = tvm::te::var("x");
+  auto y = tvm::te::var("y");
+
+  analyzer->Bind(x, tvm::Range::FromMinExtent(0, 8));
+  analyzer->modular_set.Update(x, tvm::arith::ModularSet(4, 0));
+
+  tvm::arith::Analyzer clone = analyzer->Clone();
+  TVM_FFI_ICHECK(clone->CanProve(x < 8));
+  TVM_FFI_ICHECK(clone->modular_set(x)->coeff == 4);
+
+  clone->Bind(y, tvm::Range::FromMinExtent(0, 4));
+  clone->modular_set.Update(x, tvm::arith::ModularSet(8, 0), true);
+  TVM_FFI_ICHECK(clone->CanProve(y < 4));
+  TVM_FFI_ICHECK(!analyzer->CanProve(y < 4));
+  TVM_FFI_ICHECK(analyzer->CanProve(x < 8));
+  TVM_FFI_ICHECK(analyzer->modular_set(x)->coeff == 4);
+  TVM_FFI_ICHECK(clone->modular_set(x)->coeff == 8);
+}
+
 TEST(ConstantFold, Broadcast) {
   tvm::ffi::StructuralEqual checker;
   auto i32x4 = tvm::tirx::Broadcast(tvm::IntImm::Int32(10), 4);
diff --git a/tests/python/arith/test_arith_analyzer_object.py 
b/tests/python/arith/test_arith_analyzer_object.py
index 4b4c4134b9..9edd75d7aa 100644
--- a/tests/python/arith/test_arith_analyzer_object.py
+++ b/tests/python/arith/test_arith_analyzer_object.py
@@ -204,5 +204,84 @@ def test_analyzer_object_state_persists_across_ffi_calls():
     tvm.ir.assert_structural_equal(analyzer.simplify(tile), tvm.tirx.const(8, 
"int32"))
 
 
+def test_analyzer_object_clone_is_independent():
+    analyzer = tvm.arith.Analyzer()
+    x = tirx.Var("x", "int64")
+    y = tirx.Var("y", "int64")
+    z = tirx.Var("z", "int64")
+
+    analyzer.bind(x, tvm.ir.Range(0, 8))
+
+    clone = analyzer.clone()
+    assert clone is not analyzer
+    assert clone.can_prove(x < 8)
+
+    clone.bind(y, tvm.ir.Range(0, 4))
+    assert clone.can_prove(y < 4)
+    assert not analyzer.can_prove(y < 4)
+
+    analyzer.bind(z, tvm.ir.Range(0, 4))
+    assert analyzer.can_prove(z < 4)
+    assert not clone.can_prove(z < 4)
+
+    assert analyzer.can_prove(x < 8)
+    assert clone.can_prove(x < 8)
+
+
+def test_analyzer_object_clone_copies_every_sub_analyzer():
+    analyzer = tvm.arith.Analyzer()
+    x = tirx.Var("x", "int64")
+    w = tirx.Var("w", "int64")
+    v = tirx.Var("v", "int64")
+
+    analyzer.bind(x, tvm.ir.Range(0, 8))
+    analyzer.update(x, tvm.arith.ModularSet(4, 0))
+    analyzer.bind(w, tirx.const(4, "int64"))
+    analyzer.update(v, tvm.arith.IntervalSet(2, 9))
+    analyzer.enabled_extensions = Extension.ComparisonOfProductAndSum
+
+    clone = analyzer.clone()
+
+    assert clone.can_prove(x < 8)
+    assert clone.modular_set(x).coeff == 4
+    tvm.ir.assert_structural_equal(clone.simplify(w + 1), tirx.const(5, 
"int64"))
+    assert clone.int_set(v).max_value.value == 9
+    assert clone.enabled_extensions == Extension.ComparisonOfProductAndSum
+    assert clone.try_compare(x, tirx.const(0, "int64")) == CompareResult.GE
+
+    t = tirx.Var("t", "int64")
+    clone.update(x, tvm.arith.ModularSet(8, 0), override=True)
+    clone.update(v, tvm.arith.IntervalSet(0, 3), override=True)
+    clone.bind(w, tirx.const(8, "int64"), allow_override=True)
+    clone.bind(t, tvm.ir.Range(0, 4))
+    clone.enabled_extensions = Extension.NoExtensions
+
+    assert analyzer.modular_set(x).coeff == 4
+    assert clone.modular_set(x).coeff == 8
+    assert analyzer.int_set(v).max_value.value == 9
+    assert clone.int_set(v).max_value.value == 3
+    tvm.ir.assert_structural_equal(analyzer.simplify(w + 1), tirx.const(5, 
"int64"))
+    tvm.ir.assert_structural_equal(clone.simplify(w + 1), tirx.const(9, 
"int64"))
+    assert analyzer.enabled_extensions == Extension.ComparisonOfProductAndSum
+    assert clone.enabled_extensions == Extension.NoExtensions
+    assert clone.try_compare(t, tirx.const(0, "int64")) == CompareResult.GE
+    assert analyzer.try_compare(t, tirx.const(0, "int64")) == 
CompareResult.UNKNOWN
+
+
+def test_analyzer_object_clone_resets_rewrite_stats():
+    analyzer = tvm.arith.Analyzer()
+    x = tirx.Var("x", "int64")
+    y = tirx.Var("y", "int64")
+    analyzer.bind(x, tvm.ir.Range(0, 8))
+    analyzer.bind(y, tvm.ir.Range(0, 8))
+    analyzer.simplify((x + y) * 2 - x - y)
+    source_attempts = analyzer.rewrite_simplify_stats.rewrites_attempted
+    assert source_attempts > 0
+
+    clone = analyzer.clone()
+    assert clone.rewrite_simplify_stats.rewrites_attempted == 0
+    assert analyzer.rewrite_simplify_stats.rewrites_attempted == 
source_attempts
+
+
 if __name__ == "__main__":
     tvm.testing.main()

Reply via email to