This is an automated email from the ASF dual-hosted git repository.
tqchen pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git
The following commit(s) were added to refs/heads/main by this push:
new 349225ae23 [REFACTOR][PYTHON] Revisit lifted support modules from
tvm.contrib (#19653)
349225ae23 is described below
commit 349225ae23c060509de5b83730fe955b09f7cf65
Author: Balint Cristian <[email protected]>
AuthorDate: Mon Jun 1 15:21:07 2026 +0300
[REFACTOR][PYTHON] Revisit lifted support modules from tvm.contrib (#19653)
Continuing #19624, this catches some entries that had not yet been lifted.
Hopefully none are left now. For consistency, it also covers comments and
possibly inactive (non-hot-path) parts of the code.
---
python/tvm/support/nvcc.py | 8 ++++----
src/runtime/cuda/cuda_module.cc | 2 +-
src/target/rocm/llvm/codegen_amdgpu.cc | 2 +-
src/tirx/transform/unsupported_dtype_legalize.cc | 10 +++++-----
tests/python/{contrib => support}/test_ccache.py | 2 +-
tests/python/{contrib => support}/test_popen_pool.py | 0
tests/python/{contrib => support}/test_util.py | 0
web/README.md | 2 +-
8 files changed, 13 insertions(+), 13 deletions(-)
diff --git a/python/tvm/support/nvcc.py b/python/tvm/support/nvcc.py
index b985e74778..94dbd59ff6 100644
--- a/python/tvm/support/nvcc.py
+++ b/python/tvm/support/nvcc.py
@@ -906,7 +906,7 @@ def callback_libdevice_path(arch):
return ""
-@tvm_ffi.register_global_func("tvm.contrib.nvcc.get_compute_version")
+@tvm_ffi.register_global_func("tvm.support.nvcc.get_compute_version")
def get_target_compute_version(target=None):
"""Utility function to get compute capability of compilation target.
@@ -1060,7 +1060,7 @@ def have_cudagraph():
return False
-@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_bf16")
+@tvm_ffi.register_global_func("tvm.support.nvcc.supports_bf16")
def have_bf16(compute_version):
"""Either bf16 support is provided in the compute capability or not
@@ -1076,7 +1076,7 @@ def have_bf16(compute_version):
return False
-@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp8")
+@tvm_ffi.register_global_func("tvm.support.nvcc.supports_fp8")
def have_fp8(compute_version):
"""Whether fp8 support is provided in the specified compute capability or
not
@@ -1094,7 +1094,7 @@ def have_fp8(compute_version):
return False
-@tvm_ffi.register_global_func("tvm.contrib.nvcc.supports_fp4")
+@tvm_ffi.register_global_func("tvm.support.nvcc.supports_fp4")
def have_fp4(compute_version):
"""Whether fp4 support is provided in the specified compute capability or
not
diff --git a/src/runtime/cuda/cuda_module.cc b/src/runtime/cuda/cuda_module.cc
index 9492f943a8..3f182afb82 100644
--- a/src/runtime/cuda/cuda_module.cc
+++ b/src/runtime/cuda/cuda_module.cc
@@ -162,7 +162,7 @@ class CUDAModuleNode : public ffi::ModuleObj {
auto fcompile = ffi::Function::GetGlobal("tvm_callback_cuda_compile");
TVM_FFI_CHECK(fcompile.has_value(), RuntimeError)
<< "fmt=='cuda' requires tvm_callback_cuda_compile to be registered. "
- << "Import tvm.contrib.nvcc.";
+ << "Import tvm.support.nvcc.";
return (*fcompile)(source).cast<ffi::Bytes>();
}
diff --git a/src/target/rocm/llvm/codegen_amdgpu.cc
b/src/target/rocm/llvm/codegen_amdgpu.cc
index 2da399231e..12a8aed79b 100644
--- a/src/target/rocm/llvm/codegen_amdgpu.cc
+++ b/src/target/rocm/llvm/codegen_amdgpu.cc
@@ -306,7 +306,7 @@ ffi::Module BuildAMDGPU(IRModule mod, Target target) {
auto flink = tvm::ffi::Function::GetGlobal("tvm_callback_rocm_link");
TVM_FFI_ICHECK(flink.has_value())
- << "Require tvm_callback_rocm_link to exist, do import tvm.contrib.rocm";
+ << "Require tvm_callback_rocm_link to exist, do import tvm.support.rocm";
TVMFFIByteArray arr;
arr.data = &obj[0];
diff --git a/src/tirx/transform/unsupported_dtype_legalize.cc
b/src/tirx/transform/unsupported_dtype_legalize.cc
index 558a3ca437..0bd703358c 100644
--- a/src/tirx/transform/unsupported_dtype_legalize.cc
+++ b/src/tirx/transform/unsupported_dtype_legalize.cc
@@ -736,7 +736,7 @@ namespace transform {
bool CheckDataTypeSupport(const Target& target, const std::string&
support_func_name) {
bool has_native_support = false;
if (target->kind->name == "cuda") {
- if (auto get_cv =
tvm::ffi::Function::GetGlobal("tvm.contrib.nvcc.get_compute_version")) {
+ if (auto get_cv =
tvm::ffi::Function::GetGlobal("tvm.support.nvcc.get_compute_version")) {
std::string compute_version = (*get_cv)(target).cast<std::string>();
if (auto check_support =
tvm::ffi::Function::GetGlobal(support_func_name)) {
has_native_support = (*check_support)(compute_version).cast<bool>();
@@ -750,7 +750,7 @@ Pass BF16ComputeLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
if (opt_target.defined() &&
- CheckDataTypeSupport(opt_target.value(),
"tvm.contrib.nvcc.supports_bf16")) {
+ CheckDataTypeSupport(opt_target.value(),
"tvm.support.nvcc.supports_bf16")) {
return f;
}
return BF16ComputeLegalizer().Legalize(f);
@@ -767,7 +767,7 @@ Pass BF16StorageLegalize() {
auto pass_func = [](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
if (opt_target.defined() &&
- CheckDataTypeSupport(opt_target.value(),
"tvm.contrib.nvcc.supports_bf16")) {
+ CheckDataTypeSupport(opt_target.value(),
"tvm.support.nvcc.supports_bf16")) {
return f;
}
return BF16StorageLegalizer().Legalize(f);
@@ -784,7 +784,7 @@ Pass FP8ComputeLegalize(ffi::String promote_dtype) {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
if (opt_target.defined() &&
- CheckDataTypeSupport(opt_target.value(),
"tvm.contrib.nvcc.supports_fp8")) {
+ CheckDataTypeSupport(opt_target.value(),
"tvm.support.nvcc.supports_fp8")) {
return f;
}
return
FP8ComputeLegalizer(DataType(ffi::StringToDLDataType(promote_dtype))).Legalize(f);
@@ -801,7 +801,7 @@ Pass FP8StorageLegalize() {
auto pass_func = [=](PrimFunc f, IRModule m, PassContext ctx) {
auto opt_target = f->GetAttr<Target>(tvm::attr::kTarget);
if (opt_target.defined() &&
- CheckDataTypeSupport(opt_target.value(),
"tvm.contrib.nvcc.supports_fp8")) {
+ CheckDataTypeSupport(opt_target.value(),
"tvm.support.nvcc.supports_fp8")) {
return f;
}
return FP8StorageLegalizer().Legalize(f);
diff --git a/tests/python/contrib/test_ccache.py
b/tests/python/support/test_ccache.py
similarity index 98%
rename from tests/python/contrib/test_ccache.py
rename to tests/python/support/test_ccache.py
index 013b6896cb..f1f182562c 100644
--- a/tests/python/contrib/test_ccache.py
+++ b/tests/python/support/test_ccache.py
@@ -14,7 +14,7 @@
# KIND, either express or implied. See the License for the
# specific language governing permissions and limitations
# under the License.
-"""Test contrib.cc with ccache"""
+"""Test support.cc with ccache"""
import os
import shutil
diff --git a/tests/python/contrib/test_popen_pool.py
b/tests/python/support/test_popen_pool.py
similarity index 100%
rename from tests/python/contrib/test_popen_pool.py
rename to tests/python/support/test_popen_pool.py
diff --git a/tests/python/contrib/test_util.py
b/tests/python/support/test_util.py
similarity index 100%
rename from tests/python/contrib/test_util.py
rename to tests/python/support/test_util.py
diff --git a/web/README.md b/web/README.md
index 9b3cda1fb7..9488389e9b 100644
--- a/web/README.md
+++ b/web/README.md
@@ -43,7 +43,7 @@ make
```
This command will create the follow files:
-- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.contrib.emcc` will link
into.
+- `dist/wasm/libtvm_runtime.bc` bitcode library `tvm.support.emcc` will link
into.
- `dist/wasm/tvmjs_runtime.wasm` a standalone wasm runtime for testing
purposes.
- `dist/wasm/tvmjs_runtime.wasi.js` a WASI compatible library generated by
emscripten that can be fed into runtime.