While switching to TVMC, I noticed a `virtual_device` property on the top-level Relay module function. It was not properly propagated through my Relay passes, which caused an assertion failure when lowering to TE, with:
Check failed: (!virtual_device->IsFullyUnconstrained()) is false at:

```
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/__main__.py", line 24, in <module>
    tvmc.main.main()
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line 115, in main
    sys.exit(_main(sys.argv[1:]))
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/main.py", line 103, in _main
    return args.func(args)
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 173, in drive_compile
    compile_model(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 337, in compile_model
    graph_module = build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/driver/tvmc/compiler.py", line 410, in build
    return relay.build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line 431, in build
    graph_json, runtime_mod, params = bld_mod.build(
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/relay/build_module.py", line 154, in build
    self._build(mod, raw_targets, executor, runtime, workspace_memory_pools, mod_name)
  File "/home/user1/mlenv/deps/src/tvm/python/tvm/_ffi/_ctypes/packed_func.py", line 237, in __call__
    raise get_last_ffi_error()
tvm._ffi.base.TVMError: Traceback (most recent call last):
  29: TVMFuncCall
  28: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::RelayBuildModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#3}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  27: tvm::relay::backend::RelayBuildModule::BuildRelay(tvm::IRModule, tvm::runtime::String const&)
  26: tvm::runtime::PackedFuncObj::Extractor<tvm::runtime::PackedFuncSubObj<tvm::relay::backend::AOTExecutorCodegenModule::GetFunction(std::__cxx11::basic_string<char, std::char_traits<char>, std::allocator<char> > const&, tvm::runtime::ObjectPtr<tvm::runtime::Object> const&)::{lambda(tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)#2}> >::Call(tvm::runtime::PackedFuncObj const*, tvm::runtime::TVMArgs, tvm::runtime::TVMRetValue*)
  25: tvm::relay::backend::AOTExecutorCodegen::Codegen(tvm::IRModule, tvm::relay::Function, tvm::runtime::String)
  24: tvm::transform::Pass::operator()(tvm::IRModule) const
  23: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  22: tvm::transform::SequentialNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  21: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  20: tvm::transform::ModulePassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  19: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_8IRModuleES5_NS_9transform11PassContextEEE17AssignTypedLambdaIZNS_5relay3tec7LowerTEENS0_6StringENS_17CompilationConfigESt8functionIFvNS_8BaseFuncEEEEUlS5_S7_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SL_SP_
  18: tvm::relay::tec::LowerTE(tvm::IRModule const&, tvm::runtime::String const&, std::function<void (tvm::BaseFunc)>, tvm::CompilationConfig)
  17: tvm::transform::Pass::operator()(tvm::IRModule) const
  16: tvm::transform::Pass::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  15: tvm::relay::transform::FunctionPassNode::operator()(tvm::IRModule, tvm::transform::PassContext const&) const
  14: _ZN3tvm7runtime13PackedFuncObj9ExtractorINS0_16PackedFuncSubObjIZNS0_15TypedPackedFuncIFNS_5relay8FunctionES6_NS_8IRModuleENS_9transform11PassContextEEE17AssignTypedLambdaIZNS5_3tec15LowerTensorExprERKNS0_6StringENSD_10TECompilerESt8functionIFvNS_8BaseFuncEEENS_17CompilationConfigEEUlS6_S7_S9_E_EEvT_EUlRKNS0_7TVMArgsEPNS0_11TVMRetValueEE_EEE4CallEPKS1_SP_ST_
  13: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  12: _ZZN3tvm5relay11ExprFuncto
  11: 
tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  10: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::FunctionNode const*)
  9: _ZN3tvm5relay9tr
  8: tvm::relay::ExprMutator::VisitExpr_(tvm::relay::FunctionNode const*)
  7: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  6: _ZZN3tvm5relay11ExprFuncto
  5: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::LetNode const*)
  4: tvm::relay::tec::LowerTensorExprMutator::PreVisitLetBinding_(tvm::relay::Var const&, tvm::RelayExpr const&)
  3: tvm::relay::ExprMutator::VisitExpr(tvm::RelayExpr const&)
  2: _ZZN3tvm5relay11ExprFuncto
  1: tvm::relay::transform::DeviceAwareExprMutator::VisitExpr_(tvm::relay::CallNode const*)
  0: tvm::relay::tec::LowerTensorExprMutator::DeviceAwareVisitExpr_(tvm::relay::CallNode const*)
  File "/home/user1/mlenv/deps/src/tvm/src/relay/backend/te_compiler.cc", line 885
```

I noticed that this property is sometimes updated manually after creating new copies of a function: https://github.com/apache/tvm/blob/308d320a66f16abf67c5daf4ae58cec3567decdd/src/relay/ir/expr_functor.cc#L492

However, this is not done everywhere, and I had to patch the following cases to get compilation working again:

```
diff --git a/src/relay/ir/transform.cc b/src/relay/ir/transform.cc
index 1a16cc9be..d05a30626 100644
--- a/src/relay/ir/transform.cc
+++ b/src/relay/ir/transform.cc
@@ -131,6 +131,7 @@ IRModule FunctionPassNode::operator()(IRModule mod, const PassContext& pass_ctx)
     // only process optimizable Relay Functions
     if (const auto* function_node = AsOptimizableFunctionNode(kv.second)) {
       Function updated_func = pass_func(GetRef<Function>(function_node), updated_mod, pass_ctx);
+      updated_func->virtual_device_ = GetRef<Function>(function_node)->virtual_device();
       updates.push_back({kv.first, std::move(updated_func)});
     }
   }
diff --git a/python/tvm/relay/expr_functor.py b/python/tvm/relay/expr_functor.py
index b9ca7d0e1..889031ed4 100644
--- a/python/tvm/relay/expr_functor.py
+++ b/python/tvm/relay/expr_functor.py
@@ -204,7 +204,10 @@ class ExprMutator(ExprFunctor):
     def visit_function(self, fn):
         new_params = [self.visit(x) for x in fn.params]
         new_body = self.visit(fn.body)
-        return Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+        func = Function(list(new_params), new_body, fn.ret_type, fn.type_params, fn.attrs)
+        from tvm.relay.function import FunctionCopyVirtualDevice
+        FunctionCopyVirtualDevice(func, fn)
+        return func
 
     def visit_let(self, let):
         new_var = self.visit(let.var)
diff --git a/python/tvm/relay/function.py b/python/tvm/relay/function.py
index f889f1e59..997fd1776 100644
--- a/python/tvm/relay/function.py
+++ b/python/tvm/relay/function.py
@@ -26,6 +26,10 @@ from . import _ffi_api
 
 
+def FunctionCopyVirtualDevice(f1, f2):
+    _ffi_api.FunctionCopyVirtualDevice(f1, f2)
+
+
 @tvm._ffi.register_object("relay.Function")
 class Function(BaseFunc):
     """A function declaration expression.
 
diff --git a/src/relay/ir/function.cc b/src/relay/ir/function.cc
index 63e74144e..bd3906731 100644
--- a/src/relay/ir/function.cc
+++ b/src/relay/ir/function.cc
@@ -127,6 +127,10 @@ TVM_REGISTER_GLOBAL("relay.ir.Function")
                        tvm::Array<TypeVar> ty_params, tvm::DictAttrs attrs) {
       return Function(params, body, ret_type, ty_params, attrs);
     });
+TVM_REGISTER_GLOBAL("relay.ir.FunctionCopyVirtualDevice")
+    .set_body_typed([](Function f1, Function f2) {
+      f1->virtual_device_ = f2->virtual_device_;
+    });
 
 TVM_STATIC_IR_FUNCTOR(ReprPrinter, vtable)
     .set_dispatch<FunctionNode>([](const ObjectRef& ref, ReprPrinter* p) {
```

This does not seem like an elegant solution, and I'm wondering why `virtual_device` is not part of the `Function()` Python interface. Would adding it there be an appropriate solution?

@mbs-octoml @electriclilies

---

[Visit Topic](https://discuss.tvm.apache.org/t/relay-function-virtual-device-property/12958/1) to respond.

You are receiving this because you enabled mailing list mode.
To unsubscribe from these emails, [click here](https://discuss.tvm.apache.org/email/unsubscribe/210ea8fb1ff03f47764a08f3a41ee4fc7c54532725118fa9e53837103d4c1c18).