This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new 4bcf694cbf [REFACTOR][IR] Inline ReplaceGlobalVars into AttachGlobalSymbol (#19625)
4bcf694cbf is described below

commit 4bcf694cbf211121a600435bca48967eabef360a
Author: Tianqi Chen <[email protected]>
AuthorDate: Wed May 27 15:34:24 2026 -0400

    [REFACTOR][IR] Inline ReplaceGlobalVars into AttachGlobalSymbol (#19625)
    
    ## Summary
    
    `ReplaceGlobalVars` was a public IR-layer API with only one in-tree C++
    caller (`relax::AttachGlobalSymbol`). The mechanism used a NodeFunctor
    vtable populated at static-init time by per-dialect `.cc` files in
    relax and tirx, which made the IR layer logically depend on its
    dialects even though the include graph did not show it.
    
    Move the dispatch logic into the consumer as file-local mutators and
    a private helper. Delete the public header, the IR-layer driver, both
    per-dialect dispatch registrations, the `IRModule.replace_global_vars`
    Python method, and its dedicated test file. The behavior is still
    covered by `tests/python/relax/test_transform_attach_global_symbol.py`
    and by the pipelines that include the `AttachGlobalSymbol` pass.
---
 include/tvm/ir/replace_global_vars.h               |  57 ----
 python/tvm/ir/module.py                            |  27 --
 src/ir/replace_global_vars.cc                      | 110 --------
 src/relax/transform/attach_global_symbol.cc        | 106 ++++++-
 src/relax/transform/replace_global_vars.cc         |  83 ------
 src/tirx/transform/replace_global_vars.cc          |  84 ------
 .../python/ir/test_transform_replace_global_var.py | 308 ---------------------
 7 files changed, 104 insertions(+), 671 deletions(-)

diff --git a/include/tvm/ir/replace_global_vars.h b/include/tvm/ir/replace_global_vars.h
deleted file mode 100644
index 0a9b385296..0000000000
--- a/include/tvm/ir/replace_global_vars.h
+++ /dev/null
@@ -1,57 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file tvm/ir/replace_global_vars.h
- *
- * \brief A utility to replace GlobalVar instances across all TVM IR
- * types in an IRMdoule.
- */
-#ifndef TVM_IR_REPLACE_GLOBAL_VARS_H_
-#define TVM_IR_REPLACE_GLOBAL_VARS_H_
-
-#include <tvm/ir/module.h>
-
-namespace tvm {
-namespace transform {
-
-/*!
- * \brief Replace GlobalVar instances across any IR type.
- *
- * \param mod The module to update
- *
- * \param replacements The map, where each entry maps from an old
- * `GlobalVar` to the new `GlobalVar` that should replace it.
- *
- * \return The updated IRModule
- */
-TVM_DLL IRModule ReplaceGlobalVars(IRModule mod, ffi::Map<GlobalVar, 
GlobalVar> replacements);
-
-struct GlobalVarReplacer {
-  using FType = NodeFunctor<BaseFunc(const ffi::ObjectRef&, 
ffi::Map<GlobalVar, GlobalVar>)>;
-  TVM_DLL static FType& vtable() {
-    static FType inst;
-    return inst;
-  }
-};
-
-}  // namespace transform
-}  // namespace tvm
-
-#endif  // TVM_IR_REPLACE_GLOBAL_VARS_H_
diff --git a/python/tvm/ir/module.py b/python/tvm/ir/module.py
index a9f43e09bd..95b9d940ec 100644
--- a/python/tvm/ir/module.py
+++ b/python/tvm/ir/module.py
@@ -195,33 +195,6 @@ class IRModule(Node, Scriptable):
         """
         return _ffi_api.Module_GetGlobalVars(self)
 
-    def replace_global_vars(
-        self,
-        replacements: dict[str | _expr.GlobalVar, str | _expr.GlobalVar],
-    ) -> "IRModule":
-        """Replace GlobalVar instances within the module
-
-        Replace GlobalVars within the IRModule.  Since the IRModule
-        may contain internal references to a GlobalVar, either in TIR
-        or in Relax, this method should be used whenever replacing or
-        renaming a GlobalVar.
-
-        Parameters
-        ----------
-        replacements: Dict[Union[str, _expr.GlobalVar], Union[str, 
_expr.GlobalVar]]
-
-            A dictionary where each key is a GlobalVar to be replaced,
-            and the corresponding value is the GlobalVar with which to
-            replace it.
-
-        Returns
-        -------
-        IRModule
-            The updated module
-
-        """
-        return _ffi_api.Module_ReplaceGlobalVars(self, replacements)
-
     @staticmethod
     def from_expr(expr, functions=None):
         """Construct a module from a standalone expression.
diff --git a/src/ir/replace_global_vars.cc b/src/ir/replace_global_vars.cc
deleted file mode 100644
index 2a3517b4d8..0000000000
--- a/src/ir/replace_global_vars.cc
+++ /dev/null
@@ -1,110 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- * \file src/ir/replace_global_vars.cc
- * \brief IRModule transform to replace GlobalVar instances across any IR type.
- */
-
-#include <tvm/ffi/container/variant.h>
-#include <tvm/ffi/reflection/registry.h>
-#include <tvm/ir/replace_global_vars.h>
-
-#include <vector>
-
-namespace tvm {
-namespace transform {
-
-IRModule ReplaceGlobalVars(IRModule mod, ffi::Map<GlobalVar, GlobalVar> 
replacements) {
-  if (replacements.empty()) {
-    return mod;
-  }
-
-  std::vector<GlobalVar> to_remove;
-  IRModule updates;
-
-  const auto& vtable = GlobalVarReplacer::vtable();
-
-  for (const auto& [old_gvar, old_func] : mod->functions) {
-    auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar);
-    auto new_func = vtable(old_func, replacements);
-
-    if (!new_gvar.same_as(old_gvar)) {
-      to_remove.push_back(old_gvar);
-    }
-    if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) {
-      updates->Add(new_gvar, new_func);
-    }
-  }
-
-  if (to_remove.size() || updates->functions.size()) {
-    auto write_ptr = mod.CopyOnWrite();
-    for (const auto& old_gvar : to_remove) {
-      write_ptr->Remove(old_gvar);
-    }
-    write_ptr->Update(updates);
-  }
-  return mod;
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
-  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("transform.ReplaceGlobalVars", ReplaceGlobalVars);
-}
-
-IRModule ModuleReplaceGlobalVars(
-    IRModule mod,
-    ffi::Map<ffi::Variant<ffi::String, GlobalVar>, ffi::Variant<ffi::String, 
GlobalVar>>
-        replacements) {
-  ffi::Map<GlobalVar, GlobalVar> gvar_replacements;
-  for (const auto& [before, after] : replacements) {
-    GlobalVar gvar_before;
-    if (auto gvar = before.as<GlobalVar>()) {
-      gvar_before = gvar.value();
-    } else if (auto str = before.as<ffi::String>()) {
-      gvar_before = mod->GetGlobalVar(str.value());
-    } else {
-      TVM_FFI_THROW(InternalError)
-          << "ffi::Variant<ffi::String,GlobalVar> must contain either 
ffi::String or GlobalVar";
-    }
-
-    GlobalVar gvar_after;
-    if (auto gvar = after.as<GlobalVar>()) {
-      gvar_after = gvar.value();
-    } else if (auto str = after.as<ffi::String>()) {
-      gvar_after = gvar_before;
-      gvar_after.CopyOnWrite()->name_hint = str.value();
-    } else {
-      TVM_FFI_THROW(InternalError)
-          << "ffi::Variant<ffi::String,GlobalVar> must contain either 
ffi::String or GlobalVar";
-    }
-
-    gvar_replacements.Set(gvar_before, gvar_after);
-  }
-
-  return ReplaceGlobalVars(mod, gvar_replacements);
-}
-
-TVM_FFI_STATIC_INIT_BLOCK() {
-  namespace refl = tvm::ffi::reflection;
-  refl::GlobalDef().def("ir.Module_ReplaceGlobalVars", 
ModuleReplaceGlobalVars);
-}
-
-}  // namespace transform
-}  // namespace tvm
diff --git a/src/relax/transform/attach_global_symbol.cc b/src/relax/transform/attach_global_symbol.cc
index d22b6eb40a..0e8cd722c1 100644
--- a/src/relax/transform/attach_global_symbol.cc
+++ b/src/relax/transform/attach_global_symbol.cc
@@ -24,15 +24,117 @@
 #include <tvm/ffi/cast.h>
 #include <tvm/ffi/reflection/registry.h>
 #include <tvm/ir/module.h>
-#include <tvm/ir/replace_global_vars.h>
+#include <tvm/relax/expr_functor.h>
 #include <tvm/relax/struct_info.h>
 #include <tvm/relax/transform.h>
 #include <tvm/tirx/function.h>
+#include <tvm/tirx/stmt_functor.h>
+
+#include <vector>
 
 namespace tvm {
 namespace relax {
 namespace transform {
 
+namespace {
+
+// File-local mutator: replace GlobalVar references inside a relax::Function.
+struct RelaxGvarMutator : ExprMutator {
+  ffi::Map<GlobalVar, GlobalVar> replacements;
+  explicit RelaxGvarMutator(ffi::Map<GlobalVar, GlobalVar> replacements)
+      : replacements(replacements) {}
+
+  using ExprMutator::VisitExpr_;
+  Expr VisitExpr_(const GlobalVarNode* node) override {
+    auto gvar = ffi::GetRef<GlobalVar>(node);
+    return replacements.Get(gvar).value_or(gvar);
+  }
+};
+
+// File-local mutator: replace GlobalVar references inside a tirx::PrimFunc.
+struct TirxGvarMutator : tirx::StmtExprMutator {
+  ffi::Map<GlobalVar, GlobalVar> replacements;
+  explicit TirxGvarMutator(ffi::Map<GlobalVar, GlobalVar> replacements)
+      : replacements(replacements) {}
+
+  PrimExpr VisitExpr_(const tirx::CallNode* node) override {
+    auto call = Downcast<tirx::Call>(tirx::StmtExprMutator::VisitExpr_(node));
+    if (auto old_gvar = call->op.as<GlobalVar>()) {
+      if (auto new_gvar = replacements.Get(old_gvar.value())) {
+        call.CopyOnWrite()->op = new_gvar.value();
+      }
+    }
+    return call;
+  }
+};
+
+// Replace GlobalVar references across all functions in the module.
+// Direct dispatch on function type — no NodeFunctor indirection needed
+// since this file already includes the relax + tirx headers.
+IRModule ReplaceGlobalVarsInModule(IRModule mod, ffi::Map<GlobalVar, GlobalVar> replacements) {
+  if (replacements.empty()) {
+    return mod;
+  }
+
+  std::vector<GlobalVar> to_remove;
+  IRModule updates;
+
+  for (const auto& [old_gvar, old_func] : mod->functions) {
+    auto new_gvar = replacements.Get(old_gvar).value_or(old_gvar);
+    BaseFunc new_func;
+
+    if (auto* prim_func_node = old_func.as<tirx::PrimFuncNode>()) {
+      auto func = ffi::GetRef<tirx::PrimFunc>(prim_func_node);
+      TirxGvarMutator mutator(replacements);
+      auto new_body = mutator(func->body);
+      if (!new_body.same_as(func->body)) {
+        func.CopyOnWrite()->body = new_body;
+      }
+      // Update kGlobalSymbol if the function is externally exposed and being renamed.
+      if (func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
+        if (new_gvar->name_hint != old_gvar->name_hint) {
+          func = WithAttr(func, tvm::attr::kGlobalSymbol, new_gvar->name_hint);
+        }
+      }
+      new_func = func;
+    } else if (auto* relax_func_node = old_func.as<FunctionNode>()) {
+      RelaxGvarMutator mutator(replacements);
+      auto new_relax_func =
+          Downcast<Function>(mutator(Downcast<Function>(ffi::GetRef<Function>(relax_func_node))));
+      // Update kGlobalSymbol if the function is externally exposed and being renamed.
+      if (new_relax_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
+        if (new_gvar->name_hint != old_gvar->name_hint) {
+          new_relax_func = WithAttr(new_relax_func, tvm::attr::kGlobalSymbol, new_gvar->name_hint);
+        }
+      }
+      new_func = new_relax_func;
+    } else if (old_func.as<ExternFuncNode>()) {
+      // ExternFunc: no internal GlobalVar references to update.
+      new_func = old_func;
+    } else {
+      new_func = old_func;
+    }
+
+    if (!new_gvar.same_as(old_gvar)) {
+      to_remove.push_back(old_gvar);
+    }
+    if (!old_gvar.same_as(new_gvar) || !old_func.same_as(new_func)) {
+      updates->Add(new_gvar, new_func);
+    }
+  }
+
+  if (to_remove.size() || updates->functions.size()) {
+    auto write_ptr = mod.CopyOnWrite();
+    for (const auto& old_gvar : to_remove) {
+      write_ptr->Remove(old_gvar);
+    }
+    write_ptr->Update(updates);
+  }
+  return mod;
+}
+
+}  // namespace
+
 Pass AttachGlobalSymbol() {
   auto pass_func = [=](IRModule mod, PassContext pc) {
     ffi::String c_prefix = mod->GetAttr<ffi::String>(tvm::attr::kSystemLibPrefix).value_or("");
@@ -74,7 +176,7 @@ Pass AttachGlobalSymbol() {
       mod.CopyOnWrite()->Update(updates);
 
       if (gvar_updates.size()) {
-        mod = tvm::transform::ReplaceGlobalVars(mod, gvar_updates);
+        mod = ReplaceGlobalVarsInModule(mod, gvar_updates);
       }
     }
     return mod;
diff --git a/src/relax/transform/replace_global_vars.cc b/src/relax/transform/replace_global_vars.cc
deleted file mode 100644
index f895cd50eb..0000000000
--- a/src/relax/transform/replace_global_vars.cc
+++ /dev/null
@@ -1,83 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *
- * \file src/relax/transform/replace_global_vars.cc
- *
- * \brief GlobalVar replacement across IR types
- */
-
-#include <tvm/ffi/cast.h>
-#include <tvm/ir/replace_global_vars.h>
-#include <tvm/relax/analysis.h>
-#include <tvm/relax/expr_functor.h>
-#include <tvm/tirx/expr_functor.h>
-
-namespace tvm {
-namespace relax {
-
-namespace {
-using tvm::transform::GlobalVarReplacer;
-
-struct Mutator : ExprMutator {
-  ffi::Map<GlobalVar, GlobalVar> replacements;
-  explicit Mutator(ffi::Map<GlobalVar, GlobalVar> replacements) : 
replacements(replacements) {}
-
-  using ExprMutator::VisitExpr_;
-  Expr VisitExpr_(const GlobalVarNode* node) override {
-    auto gvar = ffi::GetRef<GlobalVar>(node);
-    return replacements.Get(gvar).value_or(gvar);
-  }
-};
-
-}  // namespace
-
-TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
-    .set_dispatch<relax::FunctionNode>([](const ffi::ObjectRef& func,
-                                          ffi::Map<GlobalVar, GlobalVar> 
replacements) -> BaseFunc {
-      Mutator mutator(replacements);
-      auto new_func = Downcast<Function>(mutator(Downcast<Function>(func)));
-
-      // If the function is externally exposed, and is being replaced
-      // by a GlobalVar with a new name, then the function's
-      // kGlobalSymbol must be updated to match.
-      if (auto opt = new_func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) 
{
-        auto name = opt.value();
-        for (const auto& [before, after] : replacements) {
-          if (before->name_hint == name) {
-            if (after->name_hint != name) {
-              new_func = WithAttr(new_func, tvm::attr::kGlobalSymbol, 
after->name_hint);
-            }
-            break;
-          }
-        }
-      }
-
-      return new_func;
-    });
-
-TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
-    .set_dispatch<relax::ExternFuncNode>([](const ffi::ObjectRef& func,
-                                            ffi::Map<GlobalVar, GlobalVar>) -> 
BaseFunc {
-      return Downcast<ExternFunc>(func);
-    });
-
-}  // namespace relax
-}  // namespace tvm
diff --git a/src/tirx/transform/replace_global_vars.cc b/src/tirx/transform/replace_global_vars.cc
deleted file mode 100644
index 289d219b6b..0000000000
--- a/src/tirx/transform/replace_global_vars.cc
+++ /dev/null
@@ -1,84 +0,0 @@
-/*
- * Licensed to the Apache Software Foundation (ASF) under one
- * or more contributor license agreements.  See the NOTICE file
- * distributed with this work for additional information
- * regarding copyright ownership.  The ASF licenses this file
- * to you under the Apache License, Version 2.0 (the
- * "License"); you may not use this file except in compliance
- * with the License.  You may obtain a copy of the License at
- *
- *   http://www.apache.org/licenses/LICENSE-2.0
- *
- * Unless required by applicable law or agreed to in writing,
- * software distributed under the License is distributed on an
- * "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
- * KIND, either express or implied.  See the License for the
- * specific language governing permissions and limitations
- * under the License.
- */
-
-/*!
- *
- * \file src/tirx/transform/replace_global_vars.cc
- *
- * \brief GlobalVar replacement across IR types
- */
-
-#include <tvm/ir/replace_global_vars.h>
-#include <tvm/tirx/function.h>
-#include <tvm/tirx/stmt_functor.h>
-
-namespace tvm {
-namespace tirx {
-
-namespace {
-using tvm::transform::GlobalVarReplacer;
-
-struct Mutator : StmtExprMutator {
-  ffi::Map<GlobalVar, GlobalVar> replacements;
-  explicit Mutator(ffi::Map<GlobalVar, GlobalVar> replacements) : 
replacements(replacements) {}
-
-  PrimExpr VisitExpr_(const CallNode* node) override {
-    auto call = Downcast<Call>(StmtExprMutator::VisitExpr_(node));
-    if (auto old_gvar = call->op.as<GlobalVar>()) {
-      if (auto new_gvar = replacements.Get(old_gvar.value())) {
-        call.CopyOnWrite()->op = new_gvar.value();
-      }
-    }
-    return call;
-  }
-};
-
-}  // namespace
-
-TVM_STATIC_IR_FUNCTOR(GlobalVarReplacer, vtable)
-    .set_dispatch<tirx::PrimFuncNode>([](const ffi::ObjectRef& obj,
-                                         ffi::Map<GlobalVar, GlobalVar> 
replacements) -> BaseFunc {
-      Mutator mutator(replacements);
-      auto func = Downcast<PrimFunc>(obj);
-      auto new_body = mutator(func->body);
-
-      if (!new_body.same_as(func->body)) {
-        func.CopyOnWrite()->body = new_body;
-      }
-
-      // If the function is externally exposed, and is being replaced
-      // by a GlobalVar with a new name, then the function's
-      // kGlobalSymbol must be updated to match.
-      if (auto opt = func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
-        auto name = opt.value();
-        for (const auto& [before, after] : replacements) {
-          if (before->name_hint == name) {
-            if (after->name_hint != name) {
-              func = WithAttr(func, tvm::attr::kGlobalSymbol, 
after->name_hint);
-            }
-            break;
-          }
-        }
-      }
-
-      return func;
-    });
-
-}  // namespace tirx
-}  // namespace tvm
diff --git a/tests/python/ir/test_transform_replace_global_var.py b/tests/python/ir/test_transform_replace_global_var.py
deleted file mode 100644
index 70a693c06e..0000000000
--- a/tests/python/ir/test_transform_replace_global_var.py
+++ /dev/null
@@ -1,308 +0,0 @@
-# Licensed to the Apache Software Foundation (ASF) under one
-# or more contributor license agreements.  See the NOTICE file
-# distributed with this work for additional information
-# regarding copyright ownership.  The ASF licenses this file
-# to you under the Apache License, Version 2.0 (the
-# "License"); you may not use this file except in compliance
-# with the License.  You may obtain a copy of the License at
-#
-#   http://www.apache.org/licenses/LICENSE-2.0
-#
-# Unless required by applicable law or agreed to in writing,
-# software distributed under the License is distributed on an
-# "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
-# KIND, either express or implied.  See the License for the
-# specific language governing permissions and limitations
-# under the License.
-
-import tvm.testing
-from tvm.script import ir as I
-from tvm.script import relax as R
-from tvm.script import tirx as T
-
-
-def _get_before_module():
-    @I.ir_module
-    class Module:
-        @R.function
-        def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], 
"float32"):
-            R.func_attr({"relax.force_pure": True})
-
-            B = Module.relax_subroutine(A)
-            C = R.call_tir(Module.tir_main, B, out_sinfo=R.Tensor([16], 
"float32"))
-
-            D = R.builtin.alloc_tensor(R.shape([16]), "float32", 
runtime_device_index=0)
-            Module.tir_main(C, D)
-
-            return D
-
-        @R.function(private=True)
-        def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], 
"float32"):
-            B = R.add(A, R.prim_value(T.float32(1.0)))
-            return B
-
-        @T.prim_func(s_tir=True)
-        def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
-            Module.tir_subroutine(A.data, B.data)
-
-        @T.prim_func(private=True, s_tir=True)
-        def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
-            A = T.decl_buffer(16, "float32", data=A_data)
-            B = T.decl_buffer(16, "float32", data=B_data)
-            for i in range(16):
-                B[i] = A[i] + 1.0
-
-    return Module
-
-
-def test_no_op_if_no_replacements():
-    """If no replacements are performed, the IRModule is unmodified"""
-
-    before = _get_before_module()
-    expected = before
-
-    after = before.replace_global_vars({})
-
-    tvm.ir.assert_structural_equal(expected, after)
-    assert before.same_as(after)
-
-
-def test_replace_relax_main():
-    """An externally-exposed Relax function may be replaced
-
-    In this example, the "relax_main" function is renamed.  This
-    requires changing both the GlobalVar used to refer to the
-    function, and the "global_symbol" attribute of the
-    externally-exposed function.
-
-    """
-
-    before = _get_before_module()
-    after = before.replace_global_vars({"relax_main": 
"relax_main_with_new_name"})
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> 
R.Tensor([16], "float32"):
-            R.func_attr({"relax.force_pure": True})
-
-            B = Expected.relax_subroutine(A)
-            C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], 
"float32"))
-
-            D = R.builtin.alloc_tensor(R.shape([16]), "float32", 
runtime_device_index=0)
-            Expected.tir_main(C, D)
-
-            return D
-
-        @R.function(private=True)
-        def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], 
"float32"):
-            B = R.add(A, R.prim_value(T.float32(1.0)))
-            return B
-
-        @T.prim_func(s_tir=True)
-        def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
-            Expected.tir_subroutine(A.data, B.data)
-
-        @T.prim_func(private=True, s_tir=True)
-        def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
-            A = T.decl_buffer(16, "float32", data=A_data)
-            B = T.decl_buffer(16, "float32", data=B_data)
-            for i in range(16):
-                B[i] = A[i] + 1.0
-
-    tvm.ir.assert_structural_equal(Expected, after)
-
-
-def test_replace_relax_subroutine():
-    """An internal Relax function may be replaced
-
-    In this example, the "relax_subroutine" function is renamed.  This
-    requires changing both the GlobalVar used to refer to the
-    function, and the GlobalVar used to call the subroutine within
-    "relax_main".  The "global_symbol" attribute does not need to be
-    updated, because internal functions do not have this attribute.
-
-    """
-
-    before = _get_before_module()
-    after = before.replace_global_vars({"relax_subroutine": 
"relax_subroutine_with_new_name"})
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], 
"float32"):
-            R.func_attr({"relax.force_pure": True})
-
-            B = Expected.relax_subroutine_with_new_name(A)
-            C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], 
"float32"))
-
-            D = R.builtin.alloc_tensor(R.shape([16]), "float32", 
runtime_device_index=0)
-            Expected.tir_main(C, D)
-
-            return D
-
-        @R.function(private=True)
-        def relax_subroutine_with_new_name(
-            A: R.Tensor([16], "float32"),
-        ) -> R.Tensor([16], "float32"):
-            B = R.add(A, R.prim_value(T.float32(1.0)))
-            return B
-
-        @T.prim_func(s_tir=True)
-        def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
-            Expected.tir_subroutine(A.data, B.data)
-
-        @T.prim_func(private=True, s_tir=True)
-        def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
-            A = T.decl_buffer(16, "float32", data=A_data)
-            B = T.decl_buffer(16, "float32", data=B_data)
-            for i in range(16):
-                B[i] = A[i] + 1.0
-
-    tvm.ir.assert_structural_equal(Expected, after)
-
-
-def test_replace_tir_main():
-    """An externally-exposed TIR function may be replaced
-
-    In this example, the "tir_main" function is renamed.  This
-    requires changing both the GlobalVar used to refer to the
-    function, the "global_symbol" attribute of the externally-exposed
-    function.  In addition, calls to the TIR function should be
-    updated to use the new GlobalVar.
-
-    """
-
-    before = _get_before_module()
-    after = before.replace_global_vars({"tir_main": "tir_main_with_new_name"})
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], 
"float32"):
-            R.func_attr({"relax.force_pure": True})
-
-            B = Expected.relax_subroutine(A)
-            C = R.call_tir(Expected.tir_main_with_new_name, B, 
out_sinfo=R.Tensor([16], "float32"))
-
-            D = R.builtin.alloc_tensor(R.shape([16]), "float32", 
runtime_device_index=0)
-            Expected.tir_main_with_new_name(C, D)
-
-            return D
-
-        @R.function(private=True)
-        def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], 
"float32"):
-            B = R.add(A, R.prim_value(T.float32(1.0)))
-            return B
-
-        @T.prim_func(s_tir=True)
-        def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, 
"float32")):
-            Expected.tir_subroutine(A.data, B.data)
-
-        @T.prim_func(private=True, s_tir=True)
-        def tir_subroutine(A_data: T.ptr("float32"), B_data: T.ptr("float32")):
-            A = T.decl_buffer(16, "float32", data=A_data)
-            B = T.decl_buffer(16, "float32", data=B_data)
-            for i in range(16):
-                B[i] = A[i] + 1.0
-
-    tvm.ir.assert_structural_equal(Expected, after)
-
-
-def test_replace_tir_subroutine():
-    """An internally-exposed TIR function may be replaced
-
-    In this example, the "tir_subroutine" function is renamed.  This
-    requires changing both the GlobalVar used to refer to the
-    function, and the GlobalVar used to refer to it.  Internal
-    functions do not have the "global_symbol" attribute, so it does
-    not need to be updated.
-
-    """
-
-    before = _get_before_module()
-    after = before.replace_global_vars({"tir_subroutine": 
"tir_subroutine_with_new_name"})
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def relax_main(A: R.Tensor([16], "float32")) -> R.Tensor([16], 
"float32"):
-            R.func_attr({"relax.force_pure": True})
-
-            B = Expected.relax_subroutine(A)
-            C = R.call_tir(Expected.tir_main, B, out_sinfo=R.Tensor([16], 
"float32"))
-
-            D = R.builtin.alloc_tensor(R.shape([16]), "float32", 
runtime_device_index=0)
-            Expected.tir_main(C, D)
-
-            return D
-
-        @R.function(private=True)
-        def relax_subroutine(A: R.Tensor([16], "float32")) -> R.Tensor([16], 
"float32"):
-            B = R.add(A, R.prim_value(T.float32(1.0)))
-            return B
-
-        @T.prim_func(s_tir=True)
-        def tir_main(A: T.Buffer(16, "float32"), B: T.Buffer(16, "float32")):
-            Expected.tir_subroutine_with_new_name(A.data, B.data)
-
-        @T.prim_func(private=True, s_tir=True)
-        def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: 
T.ptr("float32")):
-            A = T.decl_buffer(16, "float32", data=A_data)
-            B = T.decl_buffer(16, "float32", data=B_data)
-            for i in range(16):
-                B[i] = A[i] + 1.0
-
-    tvm.ir.assert_structural_equal(Expected, after)
-
-
-def test_simultaneous_replacements():
-    """Multiple replacements may be performed simultaneously"""
-
-    before = _get_before_module()
-    after = before.replace_global_vars(
-        {
-            "relax_main": "relax_main_with_new_name",
-            "relax_subroutine": "relax_subroutine_with_new_name",
-            "tir_main": "tir_main_with_new_name",
-            "tir_subroutine": "tir_subroutine_with_new_name",
-        }
-    )
-
-    @I.ir_module
-    class Expected:
-        @R.function
-        def relax_main_with_new_name(A: R.Tensor([16], "float32")) -> 
R.Tensor([16], "float32"):
-            R.func_attr({"relax.force_pure": True})
-
-            B = Expected.relax_subroutine_with_new_name(A)
-            C = R.call_tir(Expected.tir_main_with_new_name, B, 
out_sinfo=R.Tensor([16], "float32"))
-
-            D = R.builtin.alloc_tensor(R.shape([16]), "float32", 
runtime_device_index=0)
-            Expected.tir_main_with_new_name(C, D)
-
-            return D
-
-        @R.function(private=True)
-        def relax_subroutine_with_new_name(
-            A: R.Tensor([16], "float32"),
-        ) -> R.Tensor([16], "float32"):
-            B = R.add(A, R.prim_value(T.float32(1.0)))
-            return B
-
-        @T.prim_func(s_tir=True)
-        def tir_main_with_new_name(A: T.Buffer(16, "float32"), B: T.Buffer(16, 
"float32")):
-            Expected.tir_subroutine_with_new_name(A.data, B.data)
-
-        @T.prim_func(private=True, s_tir=True)
-        def tir_subroutine_with_new_name(A_data: T.ptr("float32"), B_data: 
T.ptr("float32")):
-            A = T.decl_buffer(16, "float32", data=A_data)
-            B = T.decl_buffer(16, "float32", data=B_data)
-            for i in range(16):
-                B[i] = A[i] + 1.0
-
-    tvm.ir.assert_structural_equal(Expected, after)
-
-
-if __name__ == "__main__":
-    tvm.testing.main()

Reply via email to