This is an automated email from the ASF dual-hosted git repository.

tlopex pushed a commit to branch main
in repository https://gitbox.apache.org/repos/asf/tvm.git


The following commit(s) were added to refs/heads/main by this push:
     new a6625cbb94 [Fix][CUDA] Version compatibility of CUDA symbols (#19432)
a6625cbb94 is described below

commit a6625cbb944c7c57e78f8314e3668e60520b8a0e
Author: Xuhui Zheng <[email protected]>
AuthorDate: Fri Apr 24 07:46:58 2026 +0800

    [Fix][CUDA] Version compatibility of CUDA symbols (#19432)
    
    This PR closes issue #17771.
    
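    `CUDA_R_8F_E4M3`, `CUBLASLT_MATMUL_DESC_A_SCALE_POINTER` and
    `CUBLASLT_MATMUL_DESC_B_SCALE_POINTER` are only available in CUDA 11.8
    and above, so referencing them unconditionally breaks builds against older toolkits.
    
    This PR guards their uses with `#if CUDART_VERSION >= 11080`. On older toolkits, requesting float8 (E4M3) inputs or A/B scale pointers now raises an `InternalError` at runtime instead of failing to compile.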
---
 src/runtime/contrib/cublas/cublas.cc | 10 ++++++++++
 1 file changed, 10 insertions(+)

diff --git a/src/runtime/contrib/cublas/cublas.cc 
b/src/runtime/contrib/cublas/cublas.cc
index 6e667ee378..dcaf93d2da 100644
--- a/src/runtime/contrib/cublas/cublas.cc
+++ b/src/runtime/contrib/cublas/cublas.cc
@@ -170,8 +170,12 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t 
stream,
   } else if (TypeMatch(A->dtype, kDLInt, 8)) {
     ab_type = CUDA_R_8I;
   } else if (TypeMatch(A->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8)) {
+#if CUDART_VERSION >= 11080
     TVM_FFI_ICHECK(TypeMatch(B->dtype, DataType::TypeCode::kFloat8_e4m3fn, 8));
     ab_type = CUDA_R_8F_E4M3;
+#else
+    TVM_FFI_THROW(InternalError) << "Float8 (E4M3) is only supported in CUDA 
11.8 and above.";
+#endif
   }
 
   if (TypeMatch(C->dtype, kDLFloat, 16)) {
@@ -201,6 +205,7 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t stream,
                                                       &bias->data, 
sizeof(float*)));
   }
 
+#if CUDART_VERSION >= 11080
   if (scaleA != nullptr) {
     auto scaleA_data = static_cast<char*>(scaleA->data) + scaleA->byte_offset;
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_A_SCALE_POINTER,
@@ -211,6 +216,11 @@ void CallCublasLt(cublasLtHandle_t hdl, cudaStream_t 
stream,
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_B_SCALE_POINTER,
                                                       &scaleB_data, 
sizeof(float*)));
   }
+#else
+  if (scaleA != nullptr || scaleB != nullptr) {
+    TVM_FFI_THROW(InternalError) << "Scaling pointers are only supported in 
CUDA 11.8 and above.";
+  }
+#endif
 
   if (epilogue != CUBLASLT_EPILOGUE_DEFAULT) {
     CHECK_CUBLAS_ERROR(cublasLtMatmulDescSetAttribute(op_desc, 
CUBLASLT_MATMUL_DESC_EPILOGUE,

Reply via email to