This is an automated email from the ASF dual-hosted git repository. tqchen pushed a commit to branch isolate-module-create-registry in repository https://gitbox.apache.org/repos/asf/tvm.git
commit 15ee4eab6aab81796af4b332470c2a7c04f5b1f7 Author: tqchen <[email protected]> AuthorDate: Sun Apr 26 13:26:25 2026 +0000 [REFACTOR] Isolate backend module creation via ffi.Module.create.<kind> registry This PR introduces FFI registry indirection for all backend module-creation functions, eliminating the hard linker dependency from libtvm_compiler.so on runtime-defined creator symbols. ## Summary Each backend runtime .cc now registers `ffi.Module.create.<kind>` (with FFI-compatible argument types) in a `TVM_FFI_STATIC_INIT_BLOCK`. The corresponding header-level `XxxModuleCreate` becomes an `inline` function that does a `static const auto fcreate = ffi::Function::GetGlobal(...)` lookup and dispatches through FFI — caching the lookup result after the first call (a failed lookup is cached too, so the creator must already be registered when the wrapper is first invoked). `nm -u libtvm_compiler.so | grep ModuleCreate` now returns nothing: the compiler no longer has ModuleCreate symbols as undefined externs. ## Changes - **cuda** (`ffi.Module.create.cuda`): string args → `ffi::String`. - **rocm** (`ffi.Module.create.rocm`): string args → `ffi::String`. - **metal** (`ffi.Module.create.metal`): `unordered_map<string,string>` → `ffi::Map<ffi::String,ffi::String>`. - **hexagon** (`ffi.Module.create.hexagon`): string args → `ffi::String`. - **opencl text** (`ffi.Module.create.opencl`): string args → `ffi::String`. - **opencl SPIRV** (`ffi.Module.create.opencl.spirv`): `unordered_map<string,SPIRVShader>` → `ffi::Map<ffi::String,ffi::Bytes>` (each shader serialised with flag+data via BytesOutStream; rehydrated on receive). - **vulkan** (`ffi.Module.create.vulkan`): same SPIRVShader→Bytes serialisation as opencl.spirv. - **const\_loader** (`ffi.Module.create.const_loader`): `unordered_map` args → `ffi::Map<ffi::String,Tensor>` and `ffi::Map<ffi::String,ffi::Array<ffi::String>>`. - **spirv\_shader.h moved**: `src/runtime/spirv/spirv_shader.h` → `src/runtime/vulkan/spirv_shader.h` (the OpenCL SPIRV path now includes it from there as well); old path becomes a redirect header for backward compatibility. 
- **off-build stubs updated**: cuda/rocm/metal stubs are now empty (the inline wrappers handle the "off" case by throwing a clear registry-not-found error); opencl/hexagon off-stubs register fallback creators in the FFI registry. - **DeviceSourceModuleCreate** (has `std::function<>` arg): left as-is — `std::function` is not FFI-serialisable. Verification: `ninja tvm_runtime tvm_compiler cpptest` clean with HIDE_PRIVATE_SYMBOLS=ON; 128/128 cpptests pass; all-platform-minimal-test (54 pass), runtime/ (81 pass), relax/test_vm_build.py (84 pass); JVM BUILD SUCCESS. --- src/runtime/const_loader_module.cc | 23 ++++++++++-- src/runtime/const_loader_module.h | 31 ++++++++++++++-- src/runtime/cuda/cuda_module.cc | 17 ++++++--- src/runtime/cuda/cuda_module.h | 17 +++++++-- src/runtime/hexagon/hexagon_module.cc | 20 ++++++++-- src/runtime/hexagon/hexagon_module.h | 25 +++++++++---- src/runtime/metal/metal_module.h | 22 ++++++++--- src/runtime/metal/metal_module.mm | 20 +++++++--- src/runtime/opencl/opencl_module.cc | 17 ++++++--- src/runtime/opencl/opencl_module.h | 44 +++++++++++++++++++--- src/runtime/opencl/opencl_module_spirv.cc | 25 +++++++++++-- src/runtime/rocm/rocm_module.cc | 18 ++++++--- src/runtime/rocm/rocm_module.h | 18 +++++++-- src/runtime/spirv/spirv_shader.h | 55 +++------------------------- src/runtime/{spirv => vulkan}/spirv_shader.h | 6 +-- src/runtime/vulkan/vulkan_module.cc | 33 ++++++++++++++--- src/runtime/vulkan/vulkan_module.h | 39 ++++++++++++++++++-- src/target/opt/build_cuda_off.cc | 15 ++------ src/target/opt/build_hexagon_off.cc | 26 ++++++++++--- src/target/opt/build_metal_off.cc | 19 ++-------- src/target/opt/build_opencl_off.cc | 35 ++++++++++++------ src/target/opt/build_rocm_off.cc | 25 ++----------- src/target/source/codegen_metal.cc | 5 ++- 23 files changed, 367 insertions(+), 188 deletions(-) diff --git a/src/runtime/const_loader_module.cc b/src/runtime/const_loader_module.cc index ae0ea73bd0..7f9d45832c 100644 --- 
a/src/runtime/const_loader_module.cc +++ b/src/runtime/const_loader_module.cc @@ -250,7 +250,7 @@ class ConstLoaderModuleObj : public ffi::ModuleObj { std::unordered_map<std::string, std::vector<std::string>> const_vars_by_symbol_; }; -ffi::Module ConstLoaderModuleCreate( +static ffi::Module ConstLoaderModuleCreateInternal( const std::unordered_map<std::string, Tensor>& const_var_tensor, const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol) { auto n = ffi::make_object<ConstLoaderModuleObj>(const_var_tensor, const_vars_by_symbol); @@ -259,8 +259,25 @@ ffi::Module ConstLoaderModuleCreate( TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.Module.load_from_bytes.const_loader", - ConstLoaderModuleObj::LoadFromBytes); + refl::GlobalDef() + .def("ffi.Module.load_from_bytes.const_loader", ConstLoaderModuleObj::LoadFromBytes) + .def("ffi.Module.create.const_loader", + [](ffi::Map<ffi::String, Tensor> const_var_tensor_ffi, + ffi::Map<ffi::String, ffi::Array<ffi::String>> const_vars_by_symbol_ffi) { + std::unordered_map<std::string, Tensor> const_var_tensor; + for (const auto& kv : const_var_tensor_ffi) { + const_var_tensor[std::string(kv.first)] = kv.second; + } + std::unordered_map<std::string, std::vector<std::string>> const_vars_by_symbol; + for (const auto& kv : const_vars_by_symbol_ffi) { + std::vector<std::string> vars; + for (const auto& v : kv.second) { + vars.push_back(std::string(v)); + } + const_vars_by_symbol[std::string(kv.first)] = vars; + } + return ConstLoaderModuleCreateInternal(const_var_tensor, const_vars_by_symbol); + }); } } // namespace runtime diff --git a/src/runtime/const_loader_module.h b/src/runtime/const_loader_module.h index 3bdbc1235c..c97232016d 100644 --- a/src/runtime/const_loader_module.h +++ b/src/runtime/const_loader_module.h @@ -25,7 +25,10 @@ #ifndef TVM_RUNTIME_CONST_LOADER_MODULE_H_ #define TVM_RUNTIME_CONST_LOADER_MODULE_H_ +#include 
<tvm/ffi/container/array.h> +#include <tvm/ffi/container/map.h> #include <tvm/ffi/extra/module.h> +#include <tvm/ffi/function.h> #include <tvm/runtime/base.h> #include <tvm/runtime/tensor.h> @@ -39,15 +42,37 @@ namespace runtime { /*! * \brief Create a ConstLoader module object. * - * \param const_var_tensor Maps consts var name to Tensor containing data for the var. + * \param const_var_tensor Maps const var name to Tensor containing data for the var. * \param const_vars_by_symbol Maps the name of a module init function to a list of names of * const vars whose data will be passed to that init function. * * \return The created ConstLoaderModule. + * + * Dispatches through the FFI registry ("ffi.Module.create.const_loader"). + * The creator is always available (ConstLoaderModule is a runtime-universal module). */ -TVM_RUNTIME_DLL ffi::Module ConstLoaderModuleCreate( +inline ffi::Module ConstLoaderModuleCreate( const std::unordered_map<std::string, Tensor>& const_var_tensor, - const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol); + const std::unordered_map<std::string, std::vector<std::string>>& const_vars_by_symbol) { + static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.const_loader"); + TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) + << "ffi.Module.create.const_loader is not registered in runtime. " + << "Ensure libtvm_runtime is loaded."; + // Convert to FFI-compatible types. 
+ ffi::Map<ffi::String, Tensor> ffi_const_var_tensor; + for (const auto& kv : const_var_tensor) { + ffi_const_var_tensor.Set(kv.first, kv.second); + } + ffi::Map<ffi::String, ffi::Array<ffi::String>> ffi_const_vars_by_symbol; + for (const auto& kv : const_vars_by_symbol) { + ffi::Array<ffi::String> vars; + for (const auto& v : kv.second) { + vars.push_back(ffi::String(v)); + } + ffi_const_vars_by_symbol.Set(kv.first, vars); + } + return (*fcreate)(ffi_const_var_tensor, ffi_const_vars_by_symbol).cast<ffi::Module>(); +} } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc index d06f5a9d5c..42f75c7a14 100644 --- a/src/runtime/cuda/cuda_module.cc +++ b/src/runtime/cuda/cuda_module.cc @@ -308,8 +308,9 @@ ffi::Optional<ffi::Function> CUDAModuleNode::GetFunction(const ffi::String& name return PackFuncVoidAddr(f, info->arg_types, info->arg_extra_tags); } -ffi::Module CUDAModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string cuda_source) { +static ffi::Module CUDAModuleCreateInternal(std::string data, std::string fmt, + ffi::Map<ffi::String, FunctionInfo> fmap, + std::string cuda_source) { auto n = ffi::make_object<CUDAModuleNode>(data, fmt, fmap, cuda_source); return ffi::Module(n); } @@ -322,7 +323,7 @@ ffi::Module CUDAModuleLoadFile(const std::string& file_name, const ffi::String& std::string meta_file = GetMetaFilePath(file_name); LoadBinaryFromFile(file_name, &data); LoadMetaDataFromFile(meta_file, &fmap); - return CUDAModuleCreate(data, fmt, fmap, std::string()); + return CUDAModuleCreateInternal(data, fmt, fmap, std::string()); } ffi::Module CUDAModuleLoadFromBytes(const ffi::Bytes& bytes) { @@ -333,7 +334,7 @@ ffi::Module CUDAModuleLoadFromBytes(const ffi::Bytes& bytes) { stream.Read(&fmt); TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&data); - return CUDAModuleCreate(data, fmt, fmap, std::string()); + return CUDAModuleCreateInternal(data, fmt, 
fmap, std::string()); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -342,7 +343,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("ffi.Module.load_from_file.cuda", CUDAModuleLoadFile) .def("ffi.Module.load_from_file.ptx", CUDAModuleLoadFile) .def("ffi.Module.load_from_file.cubin", CUDAModuleLoadFile) - .def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes); + .def("ffi.Module.load_from_bytes.cuda", CUDAModuleLoadFromBytes) + .def("ffi.Module.create.cuda", + [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String cuda_source) { + return CUDAModuleCreateInternal(std::string(data), std::string(fmt), fmap, + std::string(cuda_source)); + }); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/cuda/cuda_module.h b/src/runtime/cuda/cuda_module.h index 2a2b1068d7..1bd94332ef 100644 --- a/src/runtime/cuda/cuda_module.h +++ b/src/runtime/cuda/cuda_module.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_CUDA_CUDA_MODULE_H_ #include <tvm/ffi/extra/module.h> +#include <tvm/ffi/function.h> #include <tvm/runtime/base.h> #include <memory> @@ -46,10 +47,20 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmt The format of the data, can be "ptx", "cubin" * \param fmap The map function information map of each function. * \param cuda_source Optional, CUDA source file + * + * Dispatches through the FFI registry ("ffi.Module.create.cuda"). + * Requires libtvm_runtime built with USE_CUDA=ON to have registered the creator. */ -TVM_RUNTIME_DLL ffi::Module CUDAModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, - std::string cuda_source); +inline ffi::Module CUDAModuleCreate(ffi::String data, ffi::String fmt, + ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String cuda_source) { + static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.cuda"); + TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) + << "ffi.Module.create.cuda is not registered in runtime. 
" + << "Link or load libtvm_runtime built with USE_CUDA=ON."; + return (*fcreate)(data, fmt, fmap, cuda_source).cast<ffi::Module>(); +} + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_CUDA_CUDA_MODULE_H_ diff --git a/src/runtime/hexagon/hexagon_module.cc b/src/runtime/hexagon/hexagon_module.cc index dd9d74c202..90aebbaf4a 100644 --- a/src/runtime/hexagon/hexagon_module.cc +++ b/src/runtime/hexagon/hexagon_module.cc @@ -25,6 +25,7 @@ #include <tvm/ffi/extra/module.h> #include <tvm/ffi/function.h> +#include <tvm/ffi/reflection/registry.h> #include <tvm/support/io.h> #include <string> @@ -89,12 +90,25 @@ ffi::Bytes HexagonModuleNode::SaveToBytes() const { return ffi::Bytes(std::move(result)); } -ffi::Module HexagonModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str) { +static ffi::Module HexagonModuleCreateInternal(std::string data, std::string fmt, + ffi::Map<ffi::String, FunctionInfo> fmap, + std::string asm_str, std::string obj_str, + std::string ir_str, std::string bc_str) { auto n = ffi::make_object<HexagonModuleNode>(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str); return ffi::Module(n); } +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "ffi.Module.create.hexagon", + [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String asm_str, ffi::String obj_str, ffi::String ir_str, ffi::String bc_str) { + return HexagonModuleCreateInternal(std::string(data), std::string(fmt), fmap, + std::string(asm_str), std::string(obj_str), + std::string(ir_str), std::string(bc_str)); + }); +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/hexagon/hexagon_module.h b/src/runtime/hexagon/hexagon_module.h index df3fa7b5bb..eeae7b32b5 100644 --- a/src/runtime/hexagon/hexagon_module.h +++ b/src/runtime/hexagon/hexagon_module.h @@ -21,6 +21,7 
@@ #define TVM_RUNTIME_HEXAGON_HEXAGON_MODULE_H_ #include <tvm/ffi/extra/module.h> +#include <tvm/ffi/function.h> #include <tvm/runtime/logging.h> #include <array> @@ -38,14 +39,24 @@ namespace runtime { * \param data The module data. * \param fmt The format of the data, can be "obj". * \param fmap The function information map of each function. - * \param asm_str ffi::String with the generated assembly source. - * \param obj_str ffi::String with the object file data. - * \param ir_str ffi::String with the disassembled LLVM IR source. - * \param bc_str ffi::String with the bitcode LLVM IR. + * \param asm_str String with the generated assembly source. + * \param obj_str String with the object file data. + * \param ir_str String with the disassembled LLVM IR source. + * \param bc_str String with the bitcode LLVM IR. + * + * Dispatches through the FFI registry ("ffi.Module.create.hexagon"). + * Requires libtvm_runtime built with USE_HEXAGON=ON to have registered the creator. */ -ffi::Module HexagonModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str); +inline ffi::Module HexagonModuleCreate(ffi::String data, ffi::String fmt, + ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String asm_str, ffi::String obj_str, ffi::String ir_str, + ffi::String bc_str) { + static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.hexagon"); + TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) + << "ffi.Module.create.hexagon is not registered in runtime. " + << "Link or load libtvm_runtime built with USE_HEXAGON=ON."; + return (*fcreate)(data, fmt, fmap, asm_str, obj_str, ir_str, bc_str).cast<ffi::Module>(); +} /*! \brief Module implementation for compiled Hexagon binaries. 
It is suitable diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h index 4534cede53..fe9454f674 100644 --- a/src/runtime/metal/metal_module.h +++ b/src/runtime/metal/metal_module.h @@ -24,7 +24,9 @@ #ifndef TVM_RUNTIME_METAL_METAL_MODULE_H_ #define TVM_RUNTIME_METAL_METAL_MODULE_H_ +#include <tvm/ffi/container/map.h> #include <tvm/ffi/extra/module.h> +#include <tvm/ffi/function.h> #include <memory> #include <string> @@ -41,14 +43,24 @@ static constexpr const int kMetalMaxNumDevice = 32; /*! * \brief create a metal module from data. * - * \param smap The map from name to each shader kernel. + * \param smap The map from name to each shader kernel (FFI-typed). * \param fmap The map function information map of each function. * \param fmt The format of the source, can be "metal" or "metallib" - * \param source Optional, source file, concatenaed for debug dump + * \param source Optional, source file, concatenated for debug dump + * + * Dispatches through the FFI registry ("ffi.Module.create.metal"). + * Requires libtvm_runtime built with USE_METAL=ON to have registered the creator. */ -ffi::Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string fmt, - std::string source); +inline ffi::Module MetalModuleCreate(ffi::Map<ffi::String, ffi::String> smap, + ffi::Map<ffi::String, FunctionInfo> fmap, ffi::String fmt, + ffi::String source) { + static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.metal"); + TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) + << "ffi.Module.create.metal is not registered in runtime. 
" + << "Link or load libtvm_runtime built with USE_METAL=ON."; + return (*fcreate)(smap, fmap, fmt, source).cast<ffi::Module>(); +} + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_METAL_METAL_MODULE_H_ diff --git a/src/runtime/metal/metal_module.mm b/src/runtime/metal/metal_module.mm index 6837404ad3..17974b89aa 100644 --- a/src/runtime/metal/metal_module.mm +++ b/src/runtime/metal/metal_module.mm @@ -272,9 +272,9 @@ ffi::Optional<ffi::Function> MetalModuleNode::GetFunction(const ffi::String& nam return ret; } -ffi::Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string fmt, - std::string source) { +static ffi::Module MetalModuleCreateInternal(std::unordered_map<std::string, std::string> smap, + ffi::Map<ffi::String, FunctionInfo> fmap, + std::string fmt, std::string source) { ObjectPtr<MetalModuleNode> n; AUTORELEASEPOOL { n = ffi::make_object<MetalModuleNode>(smap, fmap, fmt, source); }; return ffi::Module(n); @@ -295,7 +295,7 @@ TVM_FFI_STATIC_INIT_BLOCK() { fmap.Set(kv.first.cast<ffi::String>(), FunctionInfo(std::move(info_node))); } - return MetalModuleCreate( + return MetalModuleCreateInternal( std::unordered_map<std::string, std::string>(smap.begin(), smap.end()), fmap, fmt, source); }); @@ -315,12 +315,20 @@ ffi::Module MetalModuleLoadFromBytes(const ffi::Bytes& bytes) { TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&fmt); - return MetalModuleCreate(smap, fmap, fmt, ""); + return MetalModuleCreateInternal(smap, fmap, fmt, ""); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; - refl::GlobalDef().def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes); + refl::GlobalDef() + .def("ffi.Module.load_from_bytes.metal", MetalModuleLoadFromBytes) + .def("ffi.Module.create.metal", + [](ffi::Map<ffi::String, ffi::String> smap, ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String fmt, ffi::String source) { + return 
MetalModuleCreateInternal( + std::unordered_map<std::string, std::string>(smap.begin(), smap.end()), fmap, + std::string(fmt), std::string(source)); + }); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.cc b/src/runtime/opencl/opencl_module.cc index c7f873a021..b51fa3b55c 100644 --- a/src/runtime/opencl/opencl_module.cc +++ b/src/runtime/opencl/opencl_module.cc @@ -360,8 +360,9 @@ ffi::Optional<ffi::Function> OpenCLModuleNode::GetFunction(const ffi::String& na return OpenCLModuleNodeBase::GetFunction(name); } -ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string source) { +static ffi::Module OpenCLModuleCreateInternal(std::string data, std::string fmt, + ffi::Map<ffi::String, FunctionInfo> fmap, + std::string source) { auto n = ffi::make_object<OpenCLModuleNode>(data, fmt, fmap, source); n->Init(); return ffi::Module(n); @@ -375,7 +376,7 @@ ffi::Module OpenCLModuleLoadFile(const std::string& file_name, const ffi::String std::string meta_file = GetMetaFilePath(file_name); LoadBinaryFromFile(file_name, &data); LoadMetaDataFromFile(meta_file, &fmap); - return OpenCLModuleCreate(data, fmt, fmap, std::string()); + return OpenCLModuleCreateInternal(data, fmt, fmap, std::string()); } ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes& bytes) { @@ -386,7 +387,7 @@ ffi::Module OpenCLModuleLoadFromBytes(const ffi::Bytes& bytes) { stream.Read(&fmt); TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&data); - return OpenCLModuleCreate(data, fmt, fmap, std::string()); + return OpenCLModuleCreateInternal(data, fmt, fmap, std::string()); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -394,7 +395,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { refl::GlobalDef() .def("ffi.Module.load_from_file.cl", OpenCLModuleLoadFile) .def("ffi.Module.load_from_file.clbin", OpenCLModuleLoadFile) - .def("ffi.Module.load_from_bytes.opencl", OpenCLModuleLoadFromBytes); + .def("ffi.Module.load_from_bytes.opencl", 
OpenCLModuleLoadFromBytes) + .def("ffi.Module.create.opencl", + [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String source) { + return OpenCLModuleCreateInternal(std::string(data), std::string(fmt), fmap, + std::string(source)); + }); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/opencl/opencl_module.h b/src/runtime/opencl/opencl_module.h index 1e9bb88c93..fa8e89396e 100644 --- a/src/runtime/opencl/opencl_module.h +++ b/src/runtime/opencl/opencl_module.h @@ -24,15 +24,18 @@ #ifndef TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ #define TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ +#include <tvm/ffi/container/map.h> #include <tvm/ffi/function.h> #include <tvm/runtime/base.h> #include <memory> #include <string> +#include <unordered_map> #include <vector> +#include "../../support/bytes_io.h" #include "../metadata.h" -#include "../spirv/spirv_shader.h" +#include "../vulkan/spirv_shader.h" namespace tvm { namespace runtime { @@ -43,10 +46,19 @@ namespace runtime { * \param fmt The format of the data, can be "clbin", "cl" * \param fmap The map function information map of each function. * \param source Generated OpenCL kernels. + * + * Dispatches through the FFI registry ("ffi.Module.create.opencl"). + * Requires libtvm_runtime built with USE_OPENCL=ON to have registered the creator. */ -TVM_RUNTIME_DLL ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, - std::string source); +inline ffi::Module OpenCLModuleCreate(ffi::String data, ffi::String fmt, + ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String source) { + static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.opencl"); + TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) + << "ffi.Module.create.opencl is not registered in runtime. " + << "Link or load libtvm_runtime built with USE_OPENCL=ON."; + return (*fcreate)(data, fmt, fmap, source).cast<ffi::Module>(); +} /*! 
* \brief Create a opencl module from SPIRV. @@ -54,10 +66,30 @@ TVM_RUNTIME_DLL ffi::Module OpenCLModuleCreate(std::string data, std::string fmt * \param shaders The map from function names to SPIRV binaries. * \param spirv_text The concatenated text representation of SPIRV modules. * \param fmap The map function information map of each function. + * + * Dispatches through the FFI registry ("ffi.Module.create.opencl.spirv"). + * Each SPIRVShader is serialised to ffi::Bytes before crossing the FFI boundary. + * Requires libtvm_runtime built with USE_OPENCL=ON and TVM_ENABLE_SPIRV to have + * registered the creator. */ -TVM_RUNTIME_DLL ffi::Module OpenCLModuleCreate( +inline ffi::Module OpenCLModuleCreate( const std::unordered_map<std::string, spirv::SPIRVShader>& shaders, - const std::string& spirv_text, ffi::Map<ffi::String, FunctionInfo> fmap); + const std::string& spirv_text, ffi::Map<ffi::String, FunctionInfo> fmap) { + static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.opencl.spirv"); + TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) + << "ffi.Module.create.opencl.spirv is not registered in runtime. " + << "Link or load libtvm_runtime built with USE_OPENCL=ON and TVM_ENABLE_SPIRV."; + // Serialise each SPIRVShader to ffi::Bytes for the FFI boundary. 
+ ffi::Map<ffi::String, ffi::Bytes> shader_bytes; + for (const auto& kv : shaders) { + std::string buf; + support::BytesOutStream strm(&buf); + strm.Write(kv.second.flag); + strm.Write(kv.second.data); + shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf))); + } + return (*fcreate)(shader_bytes, ffi::String(spirv_text), fmap).cast<ffi::Module>(); +} } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_OPENCL_OPENCL_MODULE_H_ diff --git a/src/runtime/opencl/opencl_module_spirv.cc b/src/runtime/opencl/opencl_module_spirv.cc index 0125c4121a..4c6d4d10ff 100644 --- a/src/runtime/opencl/opencl_module_spirv.cc +++ b/src/runtime/opencl/opencl_module_spirv.cc @@ -18,6 +18,7 @@ */ #include <tvm/ffi/function.h> +#include <tvm/ffi/reflection/registry.h> #include <tvm/support/io.h> #include <string> @@ -129,13 +130,31 @@ cl_kernel OpenCLSPIRVModuleNode::InstallKernel(cl::OpenCLWorkspace* w, cl::OpenC return kernel; } -ffi::Module OpenCLModuleCreate(const std::unordered_map<std::string, SPIRVShader>& shaders, - const std::string& spirv_text, - ffi::Map<ffi::String, FunctionInfo> fmap) { +static ffi::Module OpenCLSPIRVModuleCreateInternal( + const std::unordered_map<std::string, SPIRVShader>& shaders, const std::string& spirv_text, + ffi::Map<ffi::String, FunctionInfo> fmap) { auto n = ffi::make_object<OpenCLSPIRVModuleNode>(shaders, spirv_text, fmap); n->Init(); return ffi::Module(n); } +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def("ffi.Module.create.opencl.spirv", + [](ffi::Map<ffi::String, ffi::Bytes> shader_bytes, ffi::String spirv_text, + ffi::Map<ffi::String, FunctionInfo> fmap) { + std::unordered_map<std::string, SPIRVShader> shaders; + for (const auto& kv : shader_bytes) { + support::BytesInStream stream(kv.second); + SPIRVShader shader; + TVM_FFI_ICHECK(stream.Read(&shader.flag)); + TVM_FFI_ICHECK(stream.Read(&shader.data)); + shaders[std::string(kv.first)] = shader; + } + return 
OpenCLSPIRVModuleCreateInternal(shaders, std::string(spirv_text), + fmap); + }); +} + } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.cc b/src/runtime/rocm/rocm_module.cc index 56f929c3c2..6f925063bd 100644 --- a/src/runtime/rocm/rocm_module.cc +++ b/src/runtime/rocm/rocm_module.cc @@ -208,9 +208,9 @@ ffi::Optional<ffi::Function> ROCMModuleNode::GetFunction(const ffi::String& name return PackFuncPackedArgAligned(f, info->arg_types); } -ffi::Module ROCMModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string hip_source, - std::string assembly) { +static ffi::Module ROCMModuleCreateInternal(std::string data, std::string fmt, + ffi::Map<ffi::String, FunctionInfo> fmap, + std::string hip_source, std::string assembly) { auto n = ffi::make_object<ROCMModuleNode>(data, fmt, fmap, hip_source, assembly); return ffi::Module(n); } @@ -222,7 +222,7 @@ ffi::Module ROCMModuleLoadFile(const std::string& file_name, const std::string& std::string meta_file = GetMetaFilePath(file_name); LoadBinaryFromFile(file_name, &data); LoadMetaDataFromFile(meta_file, &fmap); - return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); + return ROCMModuleCreateInternal(data, fmt, fmap, std::string(), std::string()); } ffi::Module ROCMModuleLoadFromBytes(const ffi::Bytes& bytes) { @@ -233,7 +233,7 @@ ffi::Module ROCMModuleLoadFromBytes(const ffi::Bytes& bytes) { stream.Read(&fmt); TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&data); - return ROCMModuleCreate(data, fmt, fmap, std::string(), std::string()); + return ROCMModuleCreateInternal(data, fmt, fmap, std::string(), std::string()); } TVM_FFI_STATIC_INIT_BLOCK() { @@ -242,7 +242,13 @@ TVM_FFI_STATIC_INIT_BLOCK() { .def("ffi.Module.load_from_bytes.hsaco", ROCMModuleLoadFromBytes) .def("ffi.Module.load_from_bytes.hip", ROCMModuleLoadFromBytes) .def("ffi.Module.load_from_file.hsaco", ROCMModuleLoadFile) - .def("ffi.Module.load_from_file.hip", 
ROCMModuleLoadFile); + .def("ffi.Module.load_from_file.hip", ROCMModuleLoadFile) + .def("ffi.Module.create.rocm", + [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String hip_source, ffi::String assembly) { + return ROCMModuleCreateInternal(std::string(data), std::string(fmt), fmap, + std::string(hip_source), std::string(assembly)); + }); } } // namespace runtime } // namespace tvm diff --git a/src/runtime/rocm/rocm_module.h b/src/runtime/rocm/rocm_module.h index 78f6d86d9c..666f73493f 100644 --- a/src/runtime/rocm/rocm_module.h +++ b/src/runtime/rocm/rocm_module.h @@ -25,6 +25,7 @@ #define TVM_RUNTIME_ROCM_ROCM_MODULE_H_ #include <tvm/ffi/extra/module.h> +#include <tvm/ffi/function.h> #include <memory> #include <string> @@ -45,10 +46,21 @@ static constexpr const int kMaxNumGPUs = 32; * \param fmt The format of the data, can be "hsaco" * \param fmap The map function information map of each function. * \param rocm_source Optional, rocm source file + * \param assembly Optional, GCN assembly source + * + * Dispatches through the FFI registry ("ffi.Module.create.rocm"). + * Requires libtvm_runtime built with USE_ROCM=ON to have registered the creator. */ -ffi::Module ROCMModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string rocm_source, - std::string assembly); +inline ffi::Module ROCMModuleCreate(ffi::String data, ffi::String fmt, + ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String rocm_source, ffi::String assembly) { + static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.rocm"); + TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) + << "ffi.Module.create.rocm is not registered in runtime. 
" + << "Link or load libtvm_runtime built with USE_ROCM=ON."; + return (*fcreate)(data, fmt, fmap, rocm_source, assembly).cast<ffi::Module>(); +} + } // namespace runtime } // namespace tvm #endif // TVM_RUNTIME_ROCM_ROCM_MODULE_H_ diff --git a/src/runtime/spirv/spirv_shader.h b/src/runtime/spirv/spirv_shader.h index 202b85b243..11d578c6e2 100644 --- a/src/runtime/spirv/spirv_shader.h +++ b/src/runtime/spirv/spirv_shader.h @@ -17,57 +17,14 @@ * under the License. */ +/*! + * \file src/runtime/spirv/spirv_shader.h + * \brief Deprecated include — SPIRVShader has moved to src/runtime/vulkan/spirv_shader.h. + * This header is kept for backward compatibility; include the new path directly. + */ #ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ #define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ -#include <tvm/ffi/function.h> -#include <tvm/runtime/base.h> -#include <tvm/runtime/device_api.h> -#include <tvm/runtime/logging.h> -#include <tvm/support/io.h> -#include <tvm/support/serializer.h> - -#include <vector> - -namespace tvm { -namespace runtime { -namespace spirv { - -struct SPIRVShader { - /*! \brief header flag */ - uint32_t flag{0}; - /*! 
\brief Data segment */ - std::vector<uint32_t> data; - - void Save(support::Stream* writer) const { - writer->Write(flag); - writer->Write(data); - } - bool Load(support::Stream* reader) { - if (!reader->Read(&flag)) return false; - if (!reader->Read(&data)) return false; - return true; - } -}; - -} // namespace spirv - -using spirv::SPIRVShader; -} // namespace runtime -} // namespace tvm +#include "../vulkan/spirv_shader.h" -namespace tvm { -namespace support { -template <> -struct Serializer<::tvm::runtime::spirv::SPIRVShader> { - static constexpr bool enabled = true; - static void Write(Stream* strm, const ::tvm::runtime::spirv::SPIRVShader& data) { - data.Save(strm); - } - static bool Read(Stream* strm, ::tvm::runtime::spirv::SPIRVShader* data) { - return data->Load(strm); - } -}; -} // namespace support -} // namespace tvm #endif // TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ diff --git a/src/runtime/spirv/spirv_shader.h b/src/runtime/vulkan/spirv_shader.h similarity index 93% copy from src/runtime/spirv/spirv_shader.h copy to src/runtime/vulkan/spirv_shader.h index 202b85b243..f290d0dbd1 100644 --- a/src/runtime/spirv/spirv_shader.h +++ b/src/runtime/vulkan/spirv_shader.h @@ -17,8 +17,8 @@ * under the License. 
*/ -#ifndef TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ -#define TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ +#ifndef TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ +#define TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ #include <tvm/ffi/function.h> #include <tvm/runtime/base.h> @@ -70,4 +70,4 @@ struct Serializer<::tvm::runtime::spirv::SPIRVShader> { }; } // namespace support } // namespace tvm -#endif // TVM_RUNTIME_SPIRV_SPIRV_SHADER_H_ +#endif // TVM_RUNTIME_VULKAN_SPIRV_SHADER_H_ diff --git a/src/runtime/vulkan/vulkan_module.cc b/src/runtime/vulkan/vulkan_module.cc index 9267115351..fd2533d7d0 100644 --- a/src/runtime/vulkan/vulkan_module.cc +++ b/src/runtime/vulkan/vulkan_module.cc @@ -25,14 +25,28 @@ #include "../../support/bytes_io.h" #include "../file_utils.h" +#include "spirv_shader.h" #include "vulkan_wrapped_func.h" namespace tvm { namespace runtime { namespace vulkan { -ffi::Module VulkanModuleCreate(std::unordered_map<std::string, SPIRVShader> smap, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string source) { +/*! + * \brief Deserialize a SPIRVShader from ffi::Bytes. + * Format: flag (uint32_t) followed by data (vector<uint32_t>). 
+ */ +static SPIRVShader DeserializeSPIRVShader(const ffi::Bytes& bytes) { + support::BytesInStream stream(bytes); + SPIRVShader shader; + TVM_FFI_ICHECK(stream.Read(&shader.flag)); + TVM_FFI_ICHECK(stream.Read(&shader.data)); + return shader; +} + +static ffi::Module VulkanModuleCreateInternal(std::unordered_map<std::string, SPIRVShader> smap, + ffi::Map<ffi::String, FunctionInfo> fmap, + std::string source) { auto n = ffi::make_object<VulkanModuleNode>(smap, fmap, source); return ffi::Module(n); } @@ -50,7 +64,7 @@ ffi::Module VulkanModuleLoadFile(const std::string& file_name, const ffi::String stream.Read(&magic); TVM_FFI_ICHECK_EQ(magic, kVulkanModuleMagic) << "VulkanModule Magic mismatch"; stream.Read(&smap); - return VulkanModuleCreate(smap, fmap, ""); + return VulkanModuleCreateInternal(smap, fmap, ""); } ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) { @@ -62,14 +76,23 @@ ffi::Module VulkanModuleLoadFromBytes(const ffi::Bytes& bytes) { ffi::Map<ffi::String, FunctionInfo> fmap; TVM_FFI_ICHECK(stream.Read(&fmap)); stream.Read(&smap); - return VulkanModuleCreate(smap, fmap, ""); + return VulkanModuleCreateInternal(smap, fmap, ""); } TVM_FFI_STATIC_INIT_BLOCK() { namespace refl = tvm::ffi::reflection; refl::GlobalDef() .def("ffi.Module.load_from_file.vulkan", VulkanModuleLoadFile) - .def("ffi.Module.load_from_bytes.vulkan", VulkanModuleLoadFromBytes); + .def("ffi.Module.load_from_bytes.vulkan", VulkanModuleLoadFromBytes) + .def("ffi.Module.create.vulkan", + [](ffi::Map<ffi::String, ffi::Bytes> shader_bytes, + ffi::Map<ffi::String, FunctionInfo> fmap, ffi::String source) { + std::unordered_map<std::string, SPIRVShader> smap; + for (const auto& kv : shader_bytes) { + smap[std::string(kv.first)] = DeserializeSPIRVShader(kv.second); + } + return VulkanModuleCreateInternal(smap, fmap, std::string(source)); + }); } } // namespace vulkan diff --git a/src/runtime/vulkan/vulkan_module.h b/src/runtime/vulkan/vulkan_module.h index 2337f3cc79..26965f9e76 
100644 --- a/src/runtime/vulkan/vulkan_module.h +++ b/src/runtime/vulkan/vulkan_module.h @@ -20,21 +20,52 @@ #ifndef TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_ #define TVM_RUNTIME_VULKAN_VULKAN_MODULE_H_ +#include <tvm/ffi/container/map.h> #include <tvm/ffi/extra/module.h> +#include <tvm/ffi/function.h> #include <tvm/runtime/base.h> #include <string> #include <unordered_map> +#include "../../support/bytes_io.h" #include "../metadata.h" -#include "../spirv/spirv_shader.h" +#include "spirv_shader.h" namespace tvm { namespace runtime { namespace vulkan { -TVM_RUNTIME_DLL ffi::Module VulkanModuleCreate(std::unordered_map<std::string, SPIRVShader> smap, - ffi::Map<ffi::String, FunctionInfo> fmap, - std::string source); + +/*! + * \brief Create a Vulkan module from SPIRV shaders. + * + * \param smap Map from function name to SPIRVShader. + * \param fmap Map from function name to FunctionInfo. + * \param source Optional SPIRV text (for inspection). + * + * Dispatches through the FFI registry ("ffi.Module.create.vulkan"). + * Each SPIRVShader is serialised to ffi::Bytes before crossing the FFI boundary + * and rehydrated on the runtime side. + * Requires libtvm_runtime built with USE_VULKAN=ON to have registered the creator. + */ +inline ffi::Module VulkanModuleCreate(std::unordered_map<std::string, SPIRVShader> smap, + ffi::Map<ffi::String, FunctionInfo> fmap, + std::string source) { + static const auto fcreate = ffi::Function::GetGlobal("ffi.Module.create.vulkan"); + TVM_FFI_CHECK(fcreate.has_value(), RuntimeError) + << "ffi.Module.create.vulkan is not registered in runtime. " + << "Link or load libtvm_runtime built with USE_VULKAN=ON."; + // Serialise each SPIRVShader to ffi::Bytes for the FFI boundary. 
+ ffi::Map<ffi::String, ffi::Bytes> shader_bytes; + for (const auto& kv : smap) { + std::string buf; + support::BytesOutStream strm(&buf); + strm.Write(kv.second.flag); + strm.Write(kv.second.data); + shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf))); + } + return (*fcreate)(shader_bytes, fmap, ffi::String(source)).cast<ffi::Module>(); +} } // namespace vulkan diff --git a/src/target/opt/build_cuda_off.cc b/src/target/opt/build_cuda_off.cc index bf5c5d63d4..e9e4351c89 100644 --- a/src/target/opt/build_cuda_off.cc +++ b/src/target/opt/build_cuda_off.cc @@ -18,16 +18,7 @@ */ /*! - * Optional module when build CUDA is switched to off + * Optional module when build CUDA is switched to off. + * CUDAModuleCreate is now an inline registry-lookup wrapper in cuda_module.h, + * so no out-of-line stub is needed here. */ -#include "../../runtime/cuda/cuda_module.h" -namespace tvm { -namespace runtime { - -ffi::Module CUDAModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string cuda_source) { - TVM_FFI_THROW(InternalError) << "CUDA is not enabled"; - TVM_FFI_UNREACHABLE(); -} -} // namespace runtime -} // namespace tvm diff --git a/src/target/opt/build_hexagon_off.cc b/src/target/opt/build_hexagon_off.cc index 7fcb2b51a4..08450cf171 100644 --- a/src/target/opt/build_hexagon_off.cc +++ b/src/target/opt/build_hexagon_off.cc @@ -17,16 +17,32 @@ * under the License. */ +/*! + * Optional module when Hexagon runtime is switched to off. + * When ffi.Module.create.hexagon is not registered, HexagonModuleCreate (the inline + * wrapper) raises a clear RuntimeError. Fall back to a DeviceSourceModule for + * compilation-only (source inspection) workflows instead. 
+ */ +#include "../../runtime/hexagon/hexagon_module.h" #include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -ffi::Module HexagonModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string asm_str, - std::string obj_str, std::string ir_str, std::string bc_str) { - LOG(WARNING) << "Hexagon runtime is not enabled, return a source module..."; - return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hex"); +// Register a fallback creator so that compiler-side code that calls +// HexagonModuleCreate() when USE_HEXAGON=OFF still gets a usable +// DeviceSourceModule (for source inspection / serialisation) rather than a +// registry-not-found error. +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef().def( + "ffi.Module.create.hexagon", + [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String /*asm_str*/, ffi::String /*obj_str*/, ffi::String /*ir_str*/, + ffi::String /*bc_str*/) -> ffi::Module { + LOG(WARNING) << "Hexagon runtime is not enabled, returning a source module..."; + return codegen::DeviceSourceModuleCreate(std::string(data), std::string(fmt), fmap, "hex"); + }); } } // namespace runtime diff --git a/src/target/opt/build_metal_off.cc b/src/target/opt/build_metal_off.cc index 7f544d92f6..fae5d511c6 100644 --- a/src/target/opt/build_metal_off.cc +++ b/src/target/opt/build_metal_off.cc @@ -18,20 +18,7 @@ */ /*! - * Optional module when build metal is switched to off + * Optional module when build metal is switched to off. + * MetalModuleCreate is now an inline registry-lookup wrapper in metal_module.h, + * so no out-of-line stub is needed here. 
*/ -#include "../../runtime/metal/metal_module.h" -#include "../source/codegen_source_base.h" - -namespace tvm { -namespace runtime { - -ffi::Module MetalModuleCreate(std::unordered_map<std::string, std::string> smap, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string fmt, - std::string source) { - LOG(WARNING) << "Metal runtime not enabled, return a source module..."; - return codegen::DeviceSourceModuleCreate(source, fmt, fmap, "metal"); -} - -} // namespace runtime -} // namespace tvm diff --git a/src/target/opt/build_opencl_off.cc b/src/target/opt/build_opencl_off.cc index 1a27866a4c..e30725f442 100644 --- a/src/target/opt/build_opencl_off.cc +++ b/src/target/opt/build_opencl_off.cc @@ -18,24 +18,35 @@ */ /*! - * Optional module when build opencl is switched to off + * Optional module when build opencl is switched to off. + * Register fallback creators so that compiler-side code (codegen_opencl.cc) + * that calls OpenCLModuleCreate() when USE_OPENCL=OFF still gets a usable + * DeviceSourceModule for source inspection / serialisation workflows. 
*/ -#include "../../runtime/opencl/opencl_module.h" +#include <tvm/ffi/reflection/registry.h> + +#include "../../runtime/metadata.h" #include "../source/codegen_source_base.h" namespace tvm { namespace runtime { -ffi::Module OpenCLModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string source) { - return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "opencl"); -} - -ffi::Module OpenCLModuleCreate(const std::unordered_map<std::string, SPIRVShader>& shaders, - const std::string& spirv_text, - ffi::Map<ffi::String, FunctionInfo> fmap) { - TVM_FFI_THROW(InternalError) << "OpenCLModuleCreate is called but OpenCL is not enabled."; - TVM_FFI_UNREACHABLE(); +TVM_FFI_STATIC_INIT_BLOCK() { + namespace refl = tvm::ffi::reflection; + refl::GlobalDef() + .def("ffi.Module.create.opencl", + [](ffi::String data, ffi::String fmt, ffi::Map<ffi::String, FunctionInfo> fmap, + ffi::String /*source*/) -> ffi::Module { + return codegen::DeviceSourceModuleCreate(std::string(data), std::string(fmt), fmap, + "opencl"); + }) + .def("ffi.Module.create.opencl.spirv", + [](ffi::Map<ffi::String, ffi::Bytes> /*shader_bytes*/, ffi::String /*spirv_text*/, + ffi::Map<ffi::String, FunctionInfo> /*fmap*/) -> ffi::Module { + TVM_FFI_THROW(InternalError) + << "OpenCLModuleCreate (SPIRV) is called but OpenCL is not enabled."; + TVM_FFI_UNREACHABLE(); + }); } } // namespace runtime diff --git a/src/target/opt/build_rocm_off.cc b/src/target/opt/build_rocm_off.cc index ea1265ad29..634c8252c8 100644 --- a/src/target/opt/build_rocm_off.cc +++ b/src/target/opt/build_rocm_off.cc @@ -18,26 +18,7 @@ */ /*! - * Optional module when build rocm is switched to off + * Optional module when build rocm is switched to off. + * ROCMModuleCreate is now an inline registry-lookup wrapper in rocm_module.h, + * so no out-of-line stub is needed here. 
*/ -#include "../../runtime/rocm/rocm_module.h" -#include "../source/codegen_source_base.h" - -namespace tvm { -namespace runtime { - -ffi::Module ROCMModuleCreate(std::string data, std::string fmt, - ffi::Map<ffi::String, FunctionInfo> fmap, std::string rocm_source, - std::string assembly) { - LOG(WARNING) << "ROCM runtime is not enabled, return a source module..."; - auto fget_source = [rocm_source, assembly](const std::string& format) { - if (format.length() == 0) return assembly; - if (format == "ll" || format == "llvm") return rocm_source; - if (format == "asm") return assembly; - return std::string(""); - }; - return codegen::DeviceSourceModuleCreate(data, fmt, fmap, "hsaco", fget_source); -} - -} // namespace runtime -} // namespace tvm diff --git a/src/target/source/codegen_metal.cc b/src/target/source/codegen_metal.cc index 6831596c81..c4734c54bb 100644 --- a/src/target/source/codegen_metal.cc +++ b/src/target/source/codegen_metal.cc @@ -22,6 +22,7 @@ */ #include "codegen_metal.h" +#include <tvm/ffi/container/map.h> #include <tvm/ffi/reflection/registry.h> #include <tvm/tirx/transform.h> @@ -447,7 +448,7 @@ ffi::Module BuildMetal(IRModule mod, Target target) { mod = tirx::transform::PointerValueTypeRewrite()(std::move(mod)); std::ostringstream source_maker; - std::unordered_map<std::string, std::string> smap; + ffi::Map<ffi::String, ffi::String> smap; const auto fmetal_compile = tvm::ffi::Function::GetGlobal("tvm_callback_metal_compile"); std::string fmt = fmetal_compile ? "metallib" : "metal"; @@ -472,7 +473,7 @@ ffi::Module BuildMetal(IRModule mod, Target target) { if (fmetal_compile) { fsource = (*fmetal_compile)(fsource, target).cast<std::string>(); } - smap[func_name] = fsource; + smap.Set(func_name, fsource); } return MetalModuleCreate(smap, ExtractFuncInfo(mod), fmt, source_maker.str());
