This is an automated email from the ASF dual-hosted git repository.
MasterJH5574 pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 649388ea39 [REFACTOR][IR] Simplify CallingConv attribute access
(#19799)
649388ea39 is described below
commit 649388ea3935c211acd9e86392e57a7400b8d342
Author: Tianqi Chen <[email protected]>
AuthorDate: Tue Jun 16 13:03:34 2026 -0400
[REFACTOR][IR] Simplify CallingConv attribute access (#19799)
CallingConv already participates in TVM FFI integral enum conversion, so
the manual integer casts at call sites add noise without affecting
behavior. This was not possible before TVM FFI gained `Any` support for
enums; now that `Any` natively converts `enum class` values to and from
their integer representation, we can simplify the code path.
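Main changes:
- Read `tvm::attr::kCallingConv` as `CallingConv` directly
- Compare optional/defaulted values against `CallingConv` enum values
- Store `CallingConv` enum values directly in the attribute writes touched
  by this cleanup
- In the Metal, OpenCL, SPIR-V, and WebGPU codegens, check for the presence
  of `kCallingConv` separately from its value, so the error message reports
  the actual calling convention received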
---
src/backend/cuda/codegen/codegen_cuda.cc | 21 +++++++++------------
src/backend/metal/codegen/codegen_metal.cc | 10 ++++++----
src/backend/opencl/codegen/codegen_opencl.cc | 10 ++++++----
src/backend/vulkan/codegen/spirv_utils.cc | 10 ++++++----
src/backend/webgpu/codegen/codegen_webgpu.cc | 10 ++++++----
src/tirx/analysis/verify_memory.cc | 4 ++--
src/tirx/transform/make_packed_api.cc | 13 ++++++-------
src/tirx/transform/split_host_device.cc | 8 ++++----
8 files changed, 45 insertions(+), 41 deletions(-)
diff --git a/src/backend/cuda/codegen/codegen_cuda.cc
b/src/backend/cuda/codegen/codegen_cuda.cc
index aa2ef63b14..e04541a73d 100644
--- a/src/backend/cuda/codegen/codegen_cuda.cc
+++ b/src/backend/cuda/codegen/codegen_cuda.cc
@@ -160,16 +160,15 @@ void CodeGenCUDA::Init(bool output_ssa) {
void CodeGenCUDA::PrintFunctionSignature(const ffi::String& function_name,
const PrimFunc& func,
std::ostream& os) {
- int64_t calling_conv = func->GetAttr<int64_t>(tvm::attr::kCallingConv,
-
static_cast<int64_t>(tvm::CallingConv::kDefault))
- .value();
- if (calling_conv == static_cast<int64_t>(CallingConv::kDeviceKernelLaunch)) {
+ CallingConv calling_conv =
+ func->GetAttr<CallingConv>(tvm::attr::kCallingConv,
CallingConv::kDefault).value();
+ if (calling_conv == CallingConv::kDeviceKernelLaunch) {
os << "extern \"C\" __global__ ";
- } else if (calling_conv == static_cast<int64_t>(CallingConv::kDefault)) {
+ } else if (calling_conv == CallingConv::kDefault) {
os << "extern \"C\" __device__ ";
} else {
TVM_FFI_THROW(InternalError) << "Unsupported calling convention for cuda
codegen: "
- << calling_conv;
+ << static_cast<int>(calling_conv);
}
CodeGenC::PrintFunctionSignature(function_name, func, os);
}
@@ -2107,12 +2106,10 @@ ffi::Module BuildCUDA(IRModule mod, Target target) {
for (auto [gvar, base_func] : mod->functions) {
TVM_FFI_ICHECK(base_func->IsInstance<PrimFuncNode>()) << "CodeGenCUDA: Can
only take PrimFunc";
auto prim_func = Downcast<PrimFunc>(base_func);
- int64_t calling_conv = prim_func
- ->GetAttr<int64_t>(tvm::attr::kCallingConv,
-
static_cast<int64_t>(tvm::CallingConv::kDefault))
- .value();
- TVM_FFI_ICHECK(calling_conv ==
static_cast<int64_t>(CallingConv::kDeviceKernelLaunch) ||
- calling_conv == static_cast<int64_t>(CallingConv::kDefault))
+ CallingConv calling_conv =
+ prim_func->GetAttr<CallingConv>(tvm::attr::kCallingConv,
CallingConv::kDefault).value();
+ TVM_FFI_ICHECK(calling_conv == CallingConv::kDeviceKernelLaunch ||
+ calling_conv == CallingConv::kDefault)
<< "CodeGenCUDA: expect calling_conv equals
CallingConv::kDeviceKernelLaunch or "
"CallingConv::kDefault";
functions.Set(gvar, prim_func);
diff --git a/src/backend/metal/codegen/codegen_metal.cc
b/src/backend/metal/codegen/codegen_metal.cc
index b68840f327..17668a4867 100644
--- a/src/backend/metal/codegen/codegen_metal.cc
+++ b/src/backend/metal/codegen/codegen_metal.cc
@@ -474,10 +474,12 @@ ffi::Module BuildMetal(IRModule mod, Target target) {
CodeGenMetal cg(target);
cg.Init(output_ssa);
auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<int64_t>(tvm::attr::kCallingConv);
- TVM_FFI_ICHECK(calling_conv.has_value() &&
- calling_conv.value() ==
static_cast<int64_t>(CallingConv::kDeviceKernelLaunch))
- << "CodeGenMetal: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
+ auto calling_conv = f->GetAttr<CallingConv>(tvm::attr::kCallingConv);
+ TVM_FFI_ICHECK(calling_conv.has_value())
+ << "CodeGenMetal: expected kCallingConv attribute to be set.";
+ TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch)
+ << "CodeGenMetal: expect calling_conv equals
CallingConv::kDeviceKernelLaunch, but got "
+ << static_cast<int>(calling_conv.value());
cg.AddFunction(kv.first, f);
diff --git a/src/backend/opencl/codegen/codegen_opencl.cc
b/src/backend/opencl/codegen/codegen_opencl.cc
index 5bad02e558..a5a94c41da 100644
--- a/src/backend/opencl/codegen/codegen_opencl.cc
+++ b/src/backend/opencl/codegen/codegen_opencl.cc
@@ -689,10 +689,12 @@ ffi::Module BuildOpenCL(IRModule mod, Target target) {
TVM_FFI_ICHECK(base_func->IsInstance<PrimFuncNode>())
<< "CodeGenOpenCL: Can only take PrimFunc";
auto prim_func = Downcast<PrimFunc>(base_func);
- auto calling_conv = prim_func->GetAttr<int64_t>(tvm::attr::kCallingConv);
- TVM_FFI_ICHECK(calling_conv.has_value() &&
- calling_conv.value() ==
static_cast<int64_t>(CallingConv::kDeviceKernelLaunch))
- << "CodeGenOpenCL: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
+ auto calling_conv =
prim_func->GetAttr<CallingConv>(tvm::attr::kCallingConv);
+ TVM_FFI_ICHECK(calling_conv.has_value())
+ << "CodeGenOpenCL: expected kCallingConv attribute to be set.";
+ TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch)
+ << "CodeGenOpenCL: expect calling_conv equals
CallingConv::kDeviceKernelLaunch, but got "
+ << static_cast<int>(calling_conv.value());
functions.Set(gvar, prim_func);
}
diff --git a/src/backend/vulkan/codegen/spirv_utils.cc
b/src/backend/vulkan/codegen/spirv_utils.cc
index 11aecf1c43..6ee872a33a 100644
--- a/src/backend/vulkan/codegen/spirv_utils.cc
+++ b/src/backend/vulkan/codegen/spirv_utils.cc
@@ -124,10 +124,12 @@ std::pair<std::unordered_map<std::string,
runtime::SPIRVShader>, std::string> Lo
for (auto kv : mod->functions) {
TVM_FFI_ICHECK(kv.second->IsInstance<PrimFuncNode>()) << "CodeGenSPIRV:
Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<int64_t>(tvm::attr::kCallingConv);
- TVM_FFI_ICHECK(calling_conv.has_value() &&
- calling_conv.value() ==
static_cast<int64_t>(CallingConv::kDeviceKernelLaunch))
- << "CodeGenSPIRV: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
+ auto calling_conv = f->GetAttr<CallingConv>(tvm::attr::kCallingConv);
+ TVM_FFI_ICHECK(calling_conv.has_value())
+ << "CodeGenSPIRV: expected kCallingConv attribute to be set.";
+ TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch)
+ << "CodeGenSPIRV: expect calling_conv equals
CallingConv::kDeviceKernelLaunch, but got "
+ << static_cast<int>(calling_conv.value());
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
TVM_FFI_ICHECK(global_symbol.has_value())
<< "CodeGenSPIRV: Expect PrimFunc to have the global_symbol attribute";
diff --git a/src/backend/webgpu/codegen/codegen_webgpu.cc
b/src/backend/webgpu/codegen/codegen_webgpu.cc
index 08c75ed840..9e7d2f5e84 100644
--- a/src/backend/webgpu/codegen/codegen_webgpu.cc
+++ b/src/backend/webgpu/codegen/codegen_webgpu.cc
@@ -760,10 +760,12 @@ ffi::Module BuildWebGPU(IRModule mod, Target target) {
TVM_FFI_ICHECK(kv.second->IsInstance<PrimFuncNode>())
<< "CodeGenWebGPU: Can only take PrimFunc";
auto f = Downcast<PrimFunc>(kv.second);
- auto calling_conv = f->GetAttr<int64_t>(tvm::attr::kCallingConv);
- TVM_FFI_ICHECK(calling_conv.has_value() &&
- calling_conv.value() ==
static_cast<int64_t>(CallingConv::kDeviceKernelLaunch))
- << "CodeGenWebGPU: expect calling_conv equals
CallingConv::kDeviceKernelLaunch";
+ auto calling_conv = f->GetAttr<CallingConv>(tvm::attr::kCallingConv);
+ TVM_FFI_ICHECK(calling_conv.has_value())
+ << "CodeGenWebGPU: expected kCallingConv attribute to be set.";
+ TVM_FFI_ICHECK(calling_conv.value() == CallingConv::kDeviceKernelLaunch)
+ << "CodeGenWebGPU: expect calling_conv equals
CallingConv::kDeviceKernelLaunch, but got "
+ << static_cast<int>(calling_conv.value());
auto global_symbol = f->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol);
TVM_FFI_ICHECK(global_symbol.has_value())
<< "CodeGenWebGPU: Expect PrimFunc to have the global_symbol
attribute";
diff --git a/src/tirx/analysis/verify_memory.cc
b/src/tirx/analysis/verify_memory.cc
index aa1a19cf0e..2c3396480c 100644
--- a/src/tirx/analysis/verify_memory.cc
+++ b/src/tirx/analysis/verify_memory.cc
@@ -177,8 +177,8 @@ std::vector<ffi::String> VerifyMemory_(const PrimFunc&
func) {
<< "' for primitive:" << std::endl
<< func;
- if (func->GetAttr<int64_t>(tvm::attr::kCallingConv,
static_cast<int64_t>(CallingConv::kDefault))
- .value() == static_cast<int64_t>(CallingConv::kDefault)) {
+ if (func->GetAttr<CallingConv>(tvm::attr::kCallingConv,
CallingConv::kDefault).value() ==
+ CallingConv::kDefault) {
MemoryAccessVerifier v(func, target.value()->GetTargetDeviceType());
v.Run();
return v.Errors();
diff --git a/src/tirx/transform/make_packed_api.cc
b/src/tirx/transform/make_packed_api.cc
index 4f8229080f..2d4eb80f03 100644
--- a/src/tirx/transform/make_packed_api.cc
+++ b/src/tirx/transform/make_packed_api.cc
@@ -178,8 +178,8 @@ class SubroutineCallRewriter : public StmtExprMutator {
ffi::Optional<ffi::String> RequiresPackedAPI(const PrimFunc& func) {
// A function with an explicit calling convention has already been
// lowered, and should not be modified.
- if (auto opt = func->GetAttr<int64_t>(tvm::attr::kCallingConv)) {
- if (CallingConv(opt.value()) != CallingConv::kDefault) {
+ if (auto opt = func->GetAttr<CallingConv>(tvm::attr::kCallingConv)) {
+ if (opt.value() != CallingConv::kDefault) {
return std::nullopt;
}
}
@@ -244,11 +244,10 @@ PrimFunc MakePackedAPI(PrimFunc func) {
ffi::Array<Var> args{v_self_handle, v_packed_args, v_num_packed_args,
v_result};
// reset global symbol to attach prefix
- func = WithAttrs(
- std::move(func),
- {{tvm::attr::kCallingConv, static_cast<int>(CallingConv::kCPackedFunc)},
- {tvm::attr::kTarget, target_host},
- {tvm::attr::kGlobalSymbol, ffi::symbol::tvm_ffi_symbol_prefix +
global_symbol.value()}});
+ func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv,
CallingConv::kCPackedFunc},
+ {tvm::attr::kTarget, target_host},
+ {tvm::attr::kGlobalSymbol,
+ ffi::symbol::tvm_ffi_symbol_prefix +
global_symbol.value()}});
Stmt body = ReturnRewriter(v_result)(func_ptr->body);
body = AttrStmt(make_zero(DataType::Int(32)), attr::compute_scope,
diff --git a/src/tirx/transform/split_host_device.cc
b/src/tirx/transform/split_host_device.cc
index acc5e473af..079309db3f 100644
--- a/src/tirx/transform/split_host_device.cc
+++ b/src/tirx/transform/split_host_device.cc
@@ -494,10 +494,10 @@ class DeviceKernelMutator : public StmtExprMutator {
write_ptr->body = ReturnRemover::Apply(write_ptr->body);
}
- func = WithAttrs(std::move(func), {{tvm::attr::kCallingConv,
-
static_cast<int>(tvm::CallingConv::kDeviceKernelLaunch)},
-
{tvm::tirx::attr::kKernelLaunchParams, info.launch_params},
- {tvm::attr::kGlobalSymbol,
info.global_symbol}});
+ func = WithAttrs(std::move(func),
+ {{tvm::attr::kCallingConv,
tvm::CallingConv::kDeviceKernelLaunch},
+ {tvm::tirx::attr::kKernelLaunchParams,
info.launch_params},
+ {tvm::attr::kGlobalSymbol, info.global_symbol}});
} else if (is_call_extern &&
!func->GetAttr<ffi::String>(tvm::attr::kGlobalSymbol)) {
func = WithAttr(func, tvm::attr::kGlobalSymbol, gvar->name_hint);