While switching to TVMC, I noticed a `virtual_device` property on the top-level
Relay module function. It was not properly propagated through my Relay passes,
which caused an assertion failure while lowering to TE:

    Check failed: (!virtual_device->IsFullyUnconstrained()) is false

at:

```
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/__main__.py", 
line 24, in <module>
    tvmc.main.main()
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line 
115, in main
    sys.exit(_main(sys.argv[1:]))
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line 
103, in _main
    return args.func(args)
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", 
line 173, in drive_compile
    compile_model(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", 
line 337, in compile_model
    graph_module = build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", 
line 410, in build
    return relay.build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line 
431, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line 
154, in build
    self._build(mod, raw_targets, executor, runtime, workspace_memory_pools, 
mod_name)
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/_ffi/_ctypes/packed_func.py", 
line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  29: TVMFuncCall
  28: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char,
 std::char_traits<char>, std::allocator<char> > const&, 
tvm::runtime::ObjectPtr<tvm::runtime::Object> 
const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> 
>::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)
  27: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, 
tvm::runtime::String const&)
  26: 
tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char,
 std::char_traits<char>, std::allocator<char> > const&, 
tvm::runtime::ObjectPtr<tvm::runtime::Object> 
const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> 
>::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, 
tvm::runtime::TVMRetValue*)
  25: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule, 
tvm::relay::Function, tvm::runtime::String)
  24: tvm::transform::Pass::operator()(tvm::IRModule) const
  23: tvm::transform::Pass::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
  22: tvm::transform::SequentialNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
  20: tvm::transform::ModulePassNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
  19: 
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec7LowerTEENS0_6StringENS_17CompilationConfigESt8functionIFvNS_8BaseFuncEEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
  18: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String 
const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
  17: tvm::transform::Pass::operator()(tvm::IRModule) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
  15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, 
tvm::transform::PassContext const&) const
  14: 
_ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprERKNS0_6StringENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
  13: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relay11ExprFuncto
  11: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode
 const*)
  10: 
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode
 const*)
  9: _ZN3tvm5relay9tr
  8: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  7: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relay11ExprFuncto
  5: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode 
const*)
  4: 
tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var 
const&, tvm::RelayExpr const&)
  3: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relay11ExprFuncto
  1: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode 
const*)
  0: 
tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode
 const*)
  File "/home/user1/mlenv/deps/src/tvm/src/relay/backend/te_compiler.cc", line 
885
```

I noticed that this property is sometimes copied over manually after a new copy
of a function has been created, for example:

https://github.com/apache/tvm/blob/308d320a66f16abf67c5daf4ae58cec3567decdd/src/relay/ir/expr_functor.cc#L492

However, this is not done consistently, and I had to patch the following places
to get compilation working again:

```
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index 1a16cc9be..d05a30626 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -131,6 +131,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const 
PassContext& pass_ctx)
     // only process optimizable Relay Functions
     if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
       Function updated_func = pass_func(GetRef<Function>(function_node), 
updated_mod, pass_ctx);
+      updated_func->virtual_device_ = 
GetRef<Function>(function_node)->virtual_device();
       updates.push_back({kv.first, std::move(updated_func)});
     }
   }
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index b9ca7d0e1..889031ed4 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -204,7 +204,10 @@ class ExprMutator(ExprFunctor):
     def visit_function(self, fn):
         new_params = [self.visit(x) for x in fn.params]
         new_body = self.visit(fn.body)
-        return Function(list(new_params), new_body, fn.ret_type, 
fn.type_params, fn.attrs)
+        func = Function(list(new_params), new_body, fn.ret_type, 
fn.type_params, fn.attrs)
+        from tvm.relay.function import FunctionCopyVirtualDevice
+        FunctionCopyVirtualDevice(func, fn)
+        return func
 
     def visit_let(self, let):
         new_var = self.visit(let.var)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index f889f1e59..997fd1776 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -26,6 +26,10 @@
 from . import _ffi_api
 
 
+def FunctionCopyVirtualDevice(f1, f2):
+    _ffi_api.FunctionCopyVirtualDevice(f1, f2)
+
+
 @tvm._ffi.register_object("relay.Function")
 class Function(BaseFunc):
     """A function declaration expression.
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 63e74144e..bd3906731 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -127,6 +127,10 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
                        tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
       return Function(params, body, ret_type, ty_params, attrs);
     });
+TVM_REGISTER_GLOBAL("relay.ir.FunctionCopyVirtualDevice")
+    .set_body_typed([](Function f1, Function f2) {
+      f1->virtual_device_ = f2->virtual_device_;
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
```

This does not seem like an elegant solution, and I'm wondering why
`virtual_device` is not part of the Python `Function()` constructor interface.
Would adding it there be an appropriate solution?

@mbs-octoml  @electriclilies





---
[Visit 
Topic](https://discuss.tvm.apache.org/t/relay-function-virtual-device-property/12958/1)
 to respond.

You are receiving this because you enabled mailing list mode.

To unsubscribe from these emails, [click 
here](https://discuss.tvm.apache.org/email/unsubscribe/210ea8fb1ff03f47764a08f3a41ee4fc7c54532725118fa9e53837103d4c1c18).

Reply via email to