This is an automated email from the ASF dual-hosted git repository.
tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new b7c85e3a53 [REFACTOR] Use FFI types in runtime inline module-create
wrapper signatures (#19449)
b7c85e3a53 is described below
commit b7c85e3a53a0eda71c4358edabfbfc23f45ac274
Author: Tianqi Chen <[email protected]>
AuthorDate: Sun Apr 26 22:03:22 2026 -0400
[REFACTOR] Use FFI types in runtime inline module-create wrapper signatures
(#19449)
## Summary
This PR cleans up the public inline wrapper API for runtime backend
module creators introduced in #19447. The wrappers previously used a mix
of `std::unordered_map` / `std::vector` / `std::string` and `ffi::*`
types with conversion glue inside the wrapper body.
## Changes
- `ConstLoaderModuleCreate`: both parameters change from
`std::unordered_map<std::string, T>` to `ffi::Map<ffi::String, T>`; the
conversion loops in the wrapper body are removed (net −12 lines of glue
code). The call site in `codegen_vm.cc` is updated to build `ffi::Map`
directly and to pass the existing `ffi::Map<ffi::String, Tensor>` `params`
argument through without a copy.
- `VulkanModuleCreate`: `source` parameter changes from `std::string` to
`const ffi::String&`, and `smap` / `fmap` are now taken by const reference
instead of by value. The `SPIRVShader` smap remains `std::unordered_map`
because `SPIRVShader` is a plain C++ struct, not FFI-storable.
- `OpenCLModuleCreate` (SPIRV overload): `spirv_text` changes from
`const std::string&` to `const ffi::String&`, and `fmap` is now taken by
const reference. The same `SPIRVShader` constraint applies.
- `metal_module.h`: removes stale `<unordered_map>`, `<vector>`,
`<memory>`, `<string>` includes left over from before #19447.
Internal `*ModuleCreateImpl` functions and the compiler-side
`SourceModuleCreate` family are unchanged.
---
src/relax/backend/vm/codegen_vm.cc | 18 ++++--------------
src/runtime/const_loader_module.cc | 3 +++
src/runtime/const_loader_module.h | 24 ++++--------------------
src/runtime/metal/metal_module.h | 5 -----
src/runtime/opencl/opencl_module.h | 6 ++----
src/runtime/vulkan/vulkan_module.h | 8 ++++----
6 files changed, 17 insertions(+), 47 deletions(-)
diff --git a/src/relax/backend/vm/codegen_vm.cc
b/src/relax/backend/vm/codegen_vm.cc
index 13dc02fde4..bf29556768 100644
--- a/src/relax/backend/vm/codegen_vm.cc
+++ b/src/relax/backend/vm/codegen_vm.cc
@@ -465,30 +465,20 @@ void LinkModules(ObjectPtr<VMExecutable> exec, const
ffi::Map<ffi::String, runti
const tvm::ffi::Module& lib, const ffi::Array<ffi::Module>&
ext_libs) {
// query if we need const loader for ext_modules
// Wrap all submodules in the initialization wrapper.
- std::unordered_map<std::string, std::vector<std::string>>
const_vars_by_symbol;
+ ffi::Map<ffi::String, ffi::Array<ffi::String>> const_vars_by_symbol;
for (tvm::ffi::Module mod : ext_libs) {
auto pf_sym = mod->GetFunction("get_symbol");
auto pf_var = mod->GetFunction("get_const_vars");
- std::vector<std::string> symbol_const_vars;
if (pf_sym.has_value() && pf_var.has_value()) {
ffi::String symbol = (*pf_sym)().cast<ffi::String>();
ffi::Array<ffi::String> variables =
(*pf_var)().cast<ffi::Array<ffi::String>>();
- for (size_t i = 0; i < variables.size(); i++) {
- symbol_const_vars.push_back(variables[i].operator std::string());
- }
- TVM_FFI_ICHECK_EQ(const_vars_by_symbol.count(symbol), 0U)
- << "Found duplicated symbol: " << symbol;
- const_vars_by_symbol[symbol] = symbol_const_vars;
+ TVM_FFI_ICHECK(!const_vars_by_symbol.count(symbol)) << "Found duplicated
symbol: " << symbol;
+ const_vars_by_symbol.Set(symbol, variables);
}
}
if (!const_vars_by_symbol.empty() || !params.empty()) {
// need runtime const information, run link const loader
- std::unordered_map<std::string, runtime::Tensor> const_var_tensor;
- for (const auto& [name, param] : params) {
- const_var_tensor[name] = param;
- }
- ffi::Module const_loader_mod =
- runtime::ConstLoaderModuleCreate(const_var_tensor,
const_vars_by_symbol);
+ ffi::Module const_loader_mod = runtime::ConstLoaderModuleCreate(params,
const_vars_by_symbol);
const_loader_mod->ImportModule(lib);
for (const auto& it : ext_libs) {
const_loader_mod->ImportModule(it);
diff --git a/src/runtime/const_loader_module.cc
b/src/runtime/const_loader_module.cc
index 006c1f1e1a..aaaeb9737e 100644
--- a/src/runtime/const_loader_module.cc
+++ b/src/runtime/const_loader_module.cc
@@ -39,6 +39,9 @@
#include <tvm/support/io.h>
#include <cstdint>
+#include <string>
+#include <unordered_map>
+#include <vector>
#include "../support/bytes_io.h"
diff --git a/src/runtime/const_loader_module.h
b/src/runtime/const_loader_module.h
index c97232016d..6722785cc9 100644
--- a/src/runtime/const_loader_module.h
+++ b/src/runtime/const_loader_module.h
@@ -29,13 +29,10 @@
#include <tvm/ffi/container/map.h>
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>
+#include <tvm/ffi/string.h>
#include <tvm/runtime/base.h>
#include <tvm/runtime/tensor.h>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
namespace tvm {
namespace runtime {
@@ -52,26 +49,13 @@ namespace runtime {
* The creator is always available (ConstLoaderModule is a runtime-universal
module).
*/
inline ffi::Module ConstLoaderModuleCreate(
- const std::unordered_map<std::string, Tensor>& const_var_tensor,
- const std::unordered_map<std::string, std::vector<std::string>>&
const_vars_by_symbol) {
+ const ffi::Map<ffi::String, Tensor>& const_var_tensor,
+ const ffi::Map<ffi::String, ffi::Array<ffi::String>>&
const_vars_by_symbol) {
static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.const_loader");
TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
<< "ffi.Module.create.const_loader is not registered in runtime. "
<< "Ensure libtvm_runtime is loaded.";
- // Convert to FFI-compatible types.
- ffi::Map<ffi::String, Tensor> ffi_const_var_tensor;
- for (const auto& kv : const_var_tensor) {
- ffi_const_var_tensor.Set(kv.first, kv.second);
- }
- ffi::Map<ffi::String, ffi::Array<ffi::String>> ffi_const_vars_by_symbol;
- for (const auto& kv : const_vars_by_symbol) {
- ffi::Array<ffi::String> vars;
- for (const auto& v : kv.second) {
- vars.push_back(ffi::String(v));
- }
- ffi_const_vars_by_symbol.Set(kv.first, vars);
- }
- return (*fcreate)(ffi_const_var_tensor,
ffi_const_vars_by_symbol).cast<ffi::Module>();
+ return (*fcreate)(const_var_tensor,
const_vars_by_symbol).cast<ffi::Module>();
}
} // namespace runtime
diff --git a/src/runtime/metal/metal_module.h b/src/runtime/metal/metal_module.h
index fe9454f674..3f4b3965ad 100644
--- a/src/runtime/metal/metal_module.h
+++ b/src/runtime/metal/metal_module.h
@@ -28,11 +28,6 @@
#include <tvm/ffi/extra/module.h>
#include <tvm/ffi/function.h>
-#include <memory>
-#include <string>
-#include <unordered_map>
-#include <vector>
-
#include "../metadata.h"
namespace tvm {
diff --git a/src/runtime/opencl/opencl_module.h
b/src/runtime/opencl/opencl_module.h
index 6697badd48..9d16ea9231 100644
--- a/src/runtime/opencl/opencl_module.h
+++ b/src/runtime/opencl/opencl_module.h
@@ -28,10 +28,8 @@
#include <tvm/ffi/function.h>
#include <tvm/runtime/base.h>
-#include <memory>
#include <string>
#include <unordered_map>
-#include <vector>
#include "../../support/bytes_io.h"
#include "../metadata.h"
@@ -74,7 +72,7 @@ inline ffi::Module OpenCLModuleCreate(ffi::String data,
ffi::String fmt,
*/
inline ffi::Module OpenCLModuleCreate(
const std::unordered_map<std::string, spirv::SPIRVShader>& shaders,
- const std::string& spirv_text, ffi::Map<ffi::String, FunctionInfo> fmap) {
+ const ffi::String& spirv_text, const ffi::Map<ffi::String, FunctionInfo>&
fmap) {
static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.opencl.spirv");
TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
<< "ffi.Module.create.opencl.spirv is not registered in runtime. "
@@ -87,7 +85,7 @@ inline ffi::Module OpenCLModuleCreate(
strm.Write(kv.second);
shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf)));
}
- return (*fcreate)(shader_bytes, ffi::String(spirv_text),
fmap).cast<ffi::Module>();
+ return (*fcreate)(shader_bytes, spirv_text, fmap).cast<ffi::Module>();
}
} // namespace runtime
} // namespace tvm
diff --git a/src/runtime/vulkan/vulkan_module.h
b/src/runtime/vulkan/vulkan_module.h
index 87df473753..d8fdda4d92 100644
--- a/src/runtime/vulkan/vulkan_module.h
+++ b/src/runtime/vulkan/vulkan_module.h
@@ -48,9 +48,9 @@ namespace vulkan {
* and rehydrated on the runtime side.
* Requires libtvm_runtime built with USE_VULKAN=ON to have registered the
creator.
*/
-inline ffi::Module VulkanModuleCreate(std::unordered_map<std::string,
SPIRVShader> smap,
- ffi::Map<ffi::String, FunctionInfo> fmap,
- std::string source) {
+inline ffi::Module VulkanModuleCreate(const std::unordered_map<std::string,
SPIRVShader>& smap,
+ const ffi::Map<ffi::String,
FunctionInfo>& fmap,
+ const ffi::String& source) {
static const auto fcreate =
ffi::Function::GetGlobal("ffi.Module.create.vulkan");
TVM_FFI_CHECK(fcreate.has_value(), RuntimeError)
<< "ffi.Module.create.vulkan is not registered in runtime. "
@@ -63,7 +63,7 @@ inline ffi::Module
VulkanModuleCreate(std::unordered_map<std::string, SPIRVShade
strm.Write(kv.second);
shader_bytes.Set(kv.first, ffi::Bytes(std::move(buf)));
}
- return (*fcreate)(shader_bytes, fmap,
ffi::String(source)).cast<ffi::Module>();
+ return (*fcreate)(shader_bytes, fmap, source).cast<ffi::Module>();
}
} // namespace vulkan