https://github.com/Wolfram70 updated https://github.com/llvm/llvm-project/pull/168359
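Patch 1 below exposes the new sm_100 mixed-precision FP instructions through `__nvvm_*_mixed_*` builtins; patch 3 then drops those builtins and intrinsics again in favor of instruction-selection idioms plus saturating variants of the existing f32 builtins. As a rough sketch of how the patch-1 builtins would be called from CUDA device code (the wrapper function and values are illustrative; only the builtin names and signatures come from the patch):

  // Hypothetical wrapper, not part of the patch; needs sm_100 and PTX ISA 8.6
  // (e.g. --offload-arch=sm_100). Builtin names/signatures are from patch 1.
  __device__ float fma_then_clamped_add(__fp16 a, __fp16 b, float acc) {
    // fma.rn.f32.f16: f16 multiplicands, f32 addend and result,
    // rounded to nearest even.
    float r = __nvvm_fma_mixed_rn_f16_f32(a, b, acc);
    // add.sat.f32.f16: .sat clamps the f32 result to [0.0, 1.0].
    return __nvvm_add_mixed_sat_f16_f32(a, r);
  }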
>From 02db2fe3ee31ea2a2183d8cbc4b7572bed839c65 Mon Sep 17 00:00:00 2001
From: Srinivasa Ravi <[email protected]>
Date: Wed, 12 Nov 2025 09:02:24 +0000
Subject: [PATCH 1/3] [clang][NVPTX] Add intrinsics and builtins for
 mixed-precision FP arithmetic

This change adds NVVM intrinsics and clang builtins for mixed-precision
FP arithmetic instructions. Tests are added in `mixed-precision-fp.ll`
and `builtins-nvptx.c` and verified with `ptxas-13.0`.

PTX Spec Reference:
https://docs.nvidia.com/cuda/parallel-thread-execution/#mixed-precision-floating-point-instructions
---
 clang/include/clang/Basic/BuiltinsNVPTX.td    |  64 +++++
 clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp    | 123 ++++++++++
 clang/test/CodeGen/builtins-nvptx.c           | 133 +++++++++++
 llvm/include/llvm/IR/IntrinsicsNVVM.td        |  25 ++
 llvm/lib/Target/NVPTX/NVPTXIntrinsics.td      |  44 ++++
 llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll | 225 ++++++++++++++++++
 6 files changed, 614 insertions(+)
 create mode 100644 llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll

diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td
index d923d2a90e908..47ba12bef058c 100644
--- a/clang/include/clang/Basic/BuiltinsNVPTX.td
+++ b/clang/include/clang/Basic/BuiltinsNVPTX.td
@@ -401,6 +401,24 @@ def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">;
 def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">;
 def __nvvm_fma_rp_d : NVPTXBuiltin<"double(double, double, double)">;
 
+def __nvvm_fma_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>;
+
+def __nvvm_fma_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+def __nvvm_fma_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>;
+
 // Rcp
 
 def __nvvm_rcp_rn_ftz_f : NVPTXBuiltin<"float(float)">;
@@ -460,6 +478,52 @@ def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">;
 def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">;
 def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">;
 
+def __nvvm_add_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>;
+def __nvvm_add_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; + +def __nvvm_add_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_add_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; + +// Sub + +def __nvvm_sub_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; + +def __nvvm_sub_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def 
__nvvm_sub_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; + // Convert def __nvvm_d2f_rn_ftz : NVPTXBuiltin<"float(double)">; diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp index 8a1cab3417d98..6f57620f0fb00 100644 --- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp @@ -415,6 +415,17 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID, return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF); } +static Value *MakeMixedPrecisionFPArithmetic(unsigned IntrinsicID, + const CallExpr *E, + CodeGenFunction &CGF) { + SmallVector<llvm::Value *, 3> Args; + for (unsigned i = 0; i < E->getNumArgs(); ++i) { + Args.push_back(CGF.EmitScalarExpr(E->getArg(i))); + } + return CGF.Builder.CreateCall( + CGF.CGM.getIntrinsic(IntrinsicID, {Args[0]->getType()}), Args); +} + } // namespace Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, @@ -1197,6 +1208,118 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, return Builder.CreateCall( CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count), {EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))}); + case NVPTX::BI__nvvm_add_mixed_f16_f32: + case NVPTX::BI__nvvm_add_mixed_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_f32, E, + *this); + case NVPTX::BI__nvvm_add_mixed_rn_f16_f32: + case NVPTX::BI__nvvm_add_mixed_rn_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_f32, E, + *this); + case NVPTX::BI__nvvm_add_mixed_rz_f16_f32: + case NVPTX::BI__nvvm_add_mixed_rz_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_f32, E, + *this); + case NVPTX::BI__nvvm_add_mixed_rm_f16_f32: + case NVPTX::BI__nvvm_add_mixed_rm_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_f32, E, + *this); + case NVPTX::BI__nvvm_add_mixed_rp_f16_f32: + case NVPTX::BI__nvvm_add_mixed_rp_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_f32, E, + *this); + case NVPTX::BI__nvvm_add_mixed_sat_f16_f32: + case NVPTX::BI__nvvm_add_mixed_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_sat_f32, E, + *this); + case NVPTX::BI__nvvm_add_mixed_rn_sat_f16_f32: + case NVPTX::BI__nvvm_add_mixed_rn_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_sat_f32, + E, *this); + case NVPTX::BI__nvvm_add_mixed_rz_sat_f16_f32: + case NVPTX::BI__nvvm_add_mixed_rz_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_sat_f32, + E, *this); + case NVPTX::BI__nvvm_add_mixed_rm_sat_f16_f32: + case NVPTX::BI__nvvm_add_mixed_rm_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_sat_f32, + E, *this); + case NVPTX::BI__nvvm_add_mixed_rp_sat_f16_f32: + case NVPTX::BI__nvvm_add_mixed_rp_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_sat_f32, + E, *this); + case NVPTX::BI__nvvm_sub_mixed_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_f32, E, + *this); + case NVPTX::BI__nvvm_sub_mixed_rn_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_rn_bf16_f32: + return 
MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_f32, E, + *this); + case NVPTX::BI__nvvm_sub_mixed_rz_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_rz_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_f32, E, + *this); + case NVPTX::BI__nvvm_sub_mixed_rm_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_rm_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_f32, E, + *this); + case NVPTX::BI__nvvm_sub_mixed_rp_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_rp_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_f32, E, + *this); + case NVPTX::BI__nvvm_sub_mixed_sat_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_sat_f32, E, + *this); + case NVPTX::BI__nvvm_sub_mixed_rn_sat_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_rn_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_sat_f32, + E, *this); + case NVPTX::BI__nvvm_sub_mixed_rz_sat_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_rz_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_sat_f32, + E, *this); + case NVPTX::BI__nvvm_sub_mixed_rm_sat_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_rm_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_sat_f32, + E, *this); + case NVPTX::BI__nvvm_sub_mixed_rp_sat_f16_f32: + case NVPTX::BI__nvvm_sub_mixed_rp_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_sat_f32, + E, *this); + case NVPTX::BI__nvvm_fma_mixed_rn_f16_f32: + case NVPTX::BI__nvvm_fma_mixed_rn_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_f32, E, + *this); + case NVPTX::BI__nvvm_fma_mixed_rz_f16_f32: + case NVPTX::BI__nvvm_fma_mixed_rz_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_f32, E, + *this); + case NVPTX::BI__nvvm_fma_mixed_rm_f16_f32: + case NVPTX::BI__nvvm_fma_mixed_rm_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_f32, E, + *this); + case NVPTX::BI__nvvm_fma_mixed_rp_f16_f32: + case NVPTX::BI__nvvm_fma_mixed_rp_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_f32, E, + *this); + case NVPTX::BI__nvvm_fma_mixed_rn_sat_f16_f32: + case NVPTX::BI__nvvm_fma_mixed_rn_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_sat_f32, + E, *this); + case NVPTX::BI__nvvm_fma_mixed_rz_sat_f16_f32: + case NVPTX::BI__nvvm_fma_mixed_rz_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_sat_f32, + E, *this); + case NVPTX::BI__nvvm_fma_mixed_rm_sat_f16_f32: + case NVPTX::BI__nvvm_fma_mixed_rm_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_sat_f32, + E, *this); + case NVPTX::BI__nvvm_fma_mixed_rp_sat_f16_f32: + case NVPTX::BI__nvvm_fma_mixed_rp_sat_bf16_f32: + return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_sat_f32, + E, *this); default: return nullptr; } diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index e3be262622844..1753b4c7767e9 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1466,3 +1466,136 @@ __device__ void nvvm_min_max_sm86() { #endif // CHECK: ret void } + +#define F16 (__fp16)0.1f +#define F16_2 (__fp16)0.2f + +__device__ void nvvm_add_mixed_precision_sm100() { +#if __CUDA_ARCH__ >= 1000 + // CHECK_PTX86_SM100: call float 
@llvm.nvvm.add.mixed.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_rn_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_rz_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_rm_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_rp_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_sat_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_rn_sat_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_rz_sat_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_rm_sat_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_add_mixed_rp_sat_f16_f32(F16, 1.0f); + + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_rn_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_rz_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_rm_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_rp_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_sat_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_rn_sat_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_rz_sat_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_rm_sat_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_add_mixed_rp_sat_bf16_f32(BF16, 1.0f); +#endif +} + +__device__ void nvvm_sub_mixed_precision_sm100() { +#if __CUDA_ARCH__ >= 1000 + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_rn_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_rz_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00) + 
__nvvm_sub_mixed_rm_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_rp_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_sat_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_rn_sat_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_rz_sat_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_rm_sat_f16_f32(F16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00) + __nvvm_sub_mixed_rp_sat_f16_f32(F16, 1.0f); + + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_rn_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_rz_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_rm_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_rp_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_sat_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_rn_sat_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_rz_sat_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_rm_sat_bf16_f32(BF16, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) + __nvvm_sub_mixed_rp_sat_bf16_f32(BF16, 1.0f); +#endif +} + +__device__ void nvvm_fma_mixed_precision_sm100() { +#if __CUDA_ARCH__ >= 1000 + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00) + __nvvm_fma_mixed_rn_f16_f32(F16, F16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00) + __nvvm_fma_mixed_rz_f16_f32(F16, F16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00) + __nvvm_fma_mixed_rm_f16_f32(F16, F16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00) + __nvvm_fma_mixed_rp_f16_f32(F16, F16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00) + __nvvm_fma_mixed_rn_sat_f16_f32(F16, F16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00) + __nvvm_fma_mixed_rz_sat_f16_f32(F16, F16_2, 1.0f); + // 
CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00) + __nvvm_fma_mixed_rm_sat_f16_f32(F16, F16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00) + __nvvm_fma_mixed_rp_sat_f16_f32(F16, F16_2, 1.0f); + + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00) + __nvvm_fma_mixed_rn_bf16_f32(BF16, BF16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00) + __nvvm_fma_mixed_rz_bf16_f32(BF16, BF16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00) + __nvvm_fma_mixed_rm_bf16_f32(BF16, BF16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00) + __nvvm_fma_mixed_rp_bf16_f32(BF16, BF16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00) + __nvvm_fma_mixed_rn_sat_bf16_f32(BF16, BF16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00) + __nvvm_fma_mixed_rz_sat_bf16_f32(BF16, BF16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00) + __nvvm_fma_mixed_rm_sat_bf16_f32(BF16, BF16_2, 1.0f); + // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00) + __nvvm_fma_mixed_rp_sat_bf16_f32(BF16, BF16_2, 1.0f); +#endif +} diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td index 21badc2692037..7a7fce42e55b0 100644 --- a/llvm/include/llvm/IR/IntrinsicsNVVM.td +++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td @@ -1386,6 +1386,14 @@ let TargetPrefix = "nvvm" in { PureIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty, llvm_double_ty]>; } + + foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in { + foreach sat = ["", "_sat"] in { + def int_nvvm_fma_mixed # rnd # sat # _f32 : + PureIntrinsic<[llvm_float_ty], + [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty]>; + } + } // // Rcp @@ -1453,6 +1461,23 @@ let TargetPrefix = "nvvm" in { } } + foreach rnd = ["", "_rn", "_rz", "_rm", "_rp"] in { + foreach sat = ["", "_sat"] in { + def int_nvvm_add_mixed # rnd # sat # _f32 : + PureIntrinsic<[llvm_float_ty], [llvm_anyfloat_ty, llvm_float_ty]>; + } + } + + // + // Sub + // + foreach rnd = ["", "_rn", "_rz", "_rm", "_rp"] in { + foreach sat = ["", "_sat"] in { + def int_nvvm_sub_mixed # rnd # sat # _f32 : + PureIntrinsic<[llvm_float_ty], [llvm_anyfloat_ty, llvm_float_ty]>; + } + } + // // Dot Product // diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td index ea69a54e6db37..07483bf5a3e3d 100644 --- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td +++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td @@ -1694,6 +1694,20 @@ multiclass FMA_INST { defm INT_NVVM_FMA : FMA_INST; +foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in { + foreach sat = ["", "_SAT"] in { + foreach type = ["F16", "BF16"] in { + def INT_NVVM_FMA # rnd # sat # _F32_ # type : + BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c), + !tolower(!subst("_", ".", "fma" # rnd # sat # "_f32_" # type)), + [(set f32:$dst, + 
(!cast<Intrinsic>(!tolower("int_nvvm_fma_mixed" # rnd # sat # "_f32")) + !cast<ValueType>(!tolower(type)):$a, !cast<ValueType>(!tolower(type)):$b, f32:$c))]>, + Requires<[hasSM<100>, hasPTX<86>]>; + } + } +} + // // Rcp // @@ -1806,6 +1820,36 @@ def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d> def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>; def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>; +foreach rnd = ["", "_RN", "_RZ", "_RM", "_RP"] in { + foreach sat = ["", "_SAT"] in { + foreach type = ["F16", "BF16"] in { + def INT_NVVM_ADD # rnd # sat # _F32_ # type : + BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b), + !tolower(!subst("_", ".", "add" # rnd # sat # "_f32_" # type)), + [(set f32:$dst, + (!cast<Intrinsic>(!tolower("int_nvvm_add_mixed" # rnd # sat # "_f32")) + !cast<ValueType>(!tolower(type)):$a, f32:$b))]>, + Requires<[hasSM<100>, hasPTX<86>]>; + } + } +} +// +// Sub +// + +foreach rnd = ["", "_RN", "_RZ", "_RM", "_RP"] in { + foreach sat = ["", "_SAT"] in { + foreach type = ["F16", "BF16"] in { + def INT_NVVM_SUB # rnd # sat # _F32_ # type : + BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b), + !tolower(!subst("_", ".", "sub" # rnd # sat # "_f32_" # type)), + [(set f32:$dst, + (!cast<Intrinsic>(!tolower("int_nvvm_sub_mixed" # rnd # sat # "_f32")) + !cast<ValueType>(!tolower(type)):$a, f32:$b))]>, + Requires<[hasSM<100>, hasPTX<86>]>; + } + } +} // // BFIND // diff --git a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll new file mode 100644 index 0000000000000..a4f2fe68830f5 --- /dev/null +++ b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll @@ -0,0 +1,225 @@ +; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6 +; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | FileCheck %s +; RUN: %if ptxas-sm_100 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_100 -mattr=+ptx86 | %ptxas-verify -arch=sm_100 %} + +; ADD + +define float @test_add_f32_f16(half %a, float %b) { +; CHECK-LABEL: test_add_f32_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<12>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_f16_param_0]; +; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_f16_param_1]; +; CHECK-NEXT: add.f32.f16 %r2, %rs1, %r1; +; CHECK-NEXT: add.rn.f32.f16 %r3, %rs1, %r2; +; CHECK-NEXT: add.rz.f32.f16 %r4, %rs1, %r3; +; CHECK-NEXT: add.rm.f32.f16 %r5, %rs1, %r4; +; CHECK-NEXT: add.rp.f32.f16 %r6, %rs1, %r5; +; CHECK-NEXT: add.sat.f32.f16 %r7, %rs1, %r6; +; CHECK-NEXT: add.rn.sat.f32.f16 %r8, %rs1, %r7; +; CHECK-NEXT: add.rz.sat.f32.f16 %r9, %rs1, %r8; +; CHECK-NEXT: add.rm.sat.f32.f16 %r10, %rs1, %r9; +; CHECK-NEXT: add.rp.sat.f32.f16 %r11, %rs1, %r10; +; CHECK-NEXT: st.param.b32 [func_retval0], %r11; +; CHECK-NEXT: ret; + %r1 = call float @llvm.nvvm.add.mixed.f32.f16(half %a, float %b) + %r2 = call float @llvm.nvvm.add.mixed.rn.f32.f16(half %a, float %r1) + %r3 = call float @llvm.nvvm.add.mixed.rz.f32.f16(half %a, float %r2) + %r4 = call float @llvm.nvvm.add.mixed.rm.f32.f16(half %a, float %r3) + %r5 = call float @llvm.nvvm.add.mixed.rp.f32.f16(half %a, float %r4) + + ; SAT + %r6 = call float @llvm.nvvm.add.mixed.sat.f32.f16(half %a, float %r5) + %r7 = call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half %a, float %r6) + %r8 = call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half %a, float %r7) + %r9 = call float 
@llvm.nvvm.add.mixed.rm.sat.f32.f16(half %a, float %r8) + %r10 = call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half %a, float %r9) + + ret float %r10 +} + +define float @test_add_f32_bf16(bfloat %a, float %b) { +; CHECK-LABEL: test_add_f32_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<12>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_bf16_param_0]; +; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_bf16_param_1]; +; CHECK-NEXT: add.f32.bf16 %r2, %rs1, %r1; +; CHECK-NEXT: add.rn.f32.bf16 %r3, %rs1, %r2; +; CHECK-NEXT: add.rz.f32.bf16 %r4, %rs1, %r3; +; CHECK-NEXT: add.rm.f32.bf16 %r5, %rs1, %r4; +; CHECK-NEXT: add.rp.f32.bf16 %r6, %rs1, %r5; +; CHECK-NEXT: add.sat.f32.bf16 %r7, %rs1, %r6; +; CHECK-NEXT: add.rn.sat.f32.bf16 %r8, %rs1, %r7; +; CHECK-NEXT: add.rz.sat.f32.bf16 %r9, %rs1, %r8; +; CHECK-NEXT: add.rm.sat.f32.bf16 %r10, %rs1, %r9; +; CHECK-NEXT: add.rp.sat.f32.bf16 %r11, %rs1, %r10; +; CHECK-NEXT: st.param.b32 [func_retval0], %r11; +; CHECK-NEXT: ret; + %r1 = call float @llvm.nvvm.add.mixed.f32.bf16(bfloat %a, float %b) + %r2 = call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat %a, float %r1) + %r3 = call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat %a, float %r2) + %r4 = call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat %a, float %r3) + %r5 = call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat %a, float %r4) + + ; SAT + %r6 = call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat %a, float %r5) + %r7 = call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat %a, float %r6) + %r8 = call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat %a, float %r7) + %r9 = call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat %a, float %r8) + %r10 = call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat %a, float %r9) + + ret float %r10 +} + +; SUB + +define float @test_sub_f32_f16(half %a, float %b) { +; CHECK-LABEL: test_sub_f32_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<12>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_f16_param_0]; +; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_f16_param_1]; +; CHECK-NEXT: sub.f32.f16 %r2, %rs1, %r1; +; CHECK-NEXT: sub.rn.f32.f16 %r3, %rs1, %r2; +; CHECK-NEXT: sub.rz.f32.f16 %r4, %rs1, %r3; +; CHECK-NEXT: sub.rm.f32.f16 %r5, %rs1, %r4; +; CHECK-NEXT: sub.rp.f32.f16 %r6, %rs1, %r5; +; CHECK-NEXT: sub.sat.f32.f16 %r7, %rs1, %r6; +; CHECK-NEXT: sub.rn.sat.f32.f16 %r8, %rs1, %r7; +; CHECK-NEXT: sub.rz.sat.f32.f16 %r9, %rs1, %r8; +; CHECK-NEXT: sub.rm.sat.f32.f16 %r10, %rs1, %r9; +; CHECK-NEXT: sub.rp.sat.f32.f16 %r11, %rs1, %r10; +; CHECK-NEXT: st.param.b32 [func_retval0], %r11; +; CHECK-NEXT: ret; + %r1 = call float @llvm.nvvm.sub.mixed.f32.f16(half %a, float %b) + %r2 = call float @llvm.nvvm.sub.mixed.rn.f32.f16(half %a, float %r1) + %r3 = call float @llvm.nvvm.sub.mixed.rz.f32.f16(half %a, float %r2) + %r4 = call float @llvm.nvvm.sub.mixed.rm.f32.f16(half %a, float %r3) + %r5 = call float @llvm.nvvm.sub.mixed.rp.f32.f16(half %a, float %r4) + + ; SAT + %r6 = call float @llvm.nvvm.sub.mixed.sat.f32.f16(half %a, float %r5) + %r7 = call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half %a, float %r6) + %r8 = call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half %a, float %r7) + %r9 = call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half %a, float %r8) + %r10 = call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half %a, float %r9) + + ret float %r10 +} + +define float @test_sub_f32_bf16(bfloat %a, float %b) { +; CHECK-LABEL: 
test_sub_f32_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<2>; +; CHECK-NEXT: .reg .b32 %r<12>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_bf16_param_0]; +; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_bf16_param_1]; +; CHECK-NEXT: sub.f32.bf16 %r2, %rs1, %r1; +; CHECK-NEXT: sub.rn.f32.bf16 %r3, %rs1, %r2; +; CHECK-NEXT: sub.rz.f32.bf16 %r4, %rs1, %r3; +; CHECK-NEXT: sub.rm.f32.bf16 %r5, %rs1, %r4; +; CHECK-NEXT: sub.rp.f32.bf16 %r6, %rs1, %r5; +; CHECK-NEXT: sub.sat.f32.bf16 %r7, %rs1, %r6; +; CHECK-NEXT: sub.rn.sat.f32.bf16 %r8, %rs1, %r7; +; CHECK-NEXT: sub.rz.sat.f32.bf16 %r9, %rs1, %r8; +; CHECK-NEXT: sub.rm.sat.f32.bf16 %r10, %rs1, %r9; +; CHECK-NEXT: sub.rp.sat.f32.bf16 %r11, %rs1, %r10; +; CHECK-NEXT: st.param.b32 [func_retval0], %r11; +; CHECK-NEXT: ret; + %r1 = call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat %a, float %b) + %r2 = call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat %a, float %r1) + %r3 = call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat %a, float %r2) + %r4 = call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat %a, float %r3) + %r5 = call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat %a, float %r4) + + ; SAT + %r6 = call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat %a, float %r5) + %r7 = call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat %a, float %r6) + %r8 = call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat %a, float %r7) + %r9 = call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat %a, float %r8) + %r10 = call float @llvm.nvvm.sub.mixed.rp.sat.f32.bf16(bfloat %a, float %r9) + + ret float %r10 +} + +; FMA + +define float @test_fma_f32_f16(half %a, half %b, float %c) { +; CHECK-LABEL: test_fma_f32_f16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .b32 %r<10>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_f16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_f16_param_1]; +; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_f16_param_2]; +; CHECK-NEXT: fma.rn.f32.f16 %r2, %rs1, %rs2, %r1; +; CHECK-NEXT: fma.rz.f32.f16 %r3, %rs1, %rs2, %r2; +; CHECK-NEXT: fma.rm.f32.f16 %r4, %rs1, %rs2, %r3; +; CHECK-NEXT: fma.rp.f32.f16 %r5, %rs1, %rs2, %r4; +; CHECK-NEXT: fma.rn.sat.f32.f16 %r6, %rs1, %rs2, %r5; +; CHECK-NEXT: fma.rz.sat.f32.f16 %r7, %rs1, %rs2, %r6; +; CHECK-NEXT: fma.rm.sat.f32.f16 %r8, %rs1, %rs2, %r7; +; CHECK-NEXT: fma.rp.sat.f32.f16 %r9, %rs1, %rs2, %r8; +; CHECK-NEXT: st.param.b32 [func_retval0], %r9; +; CHECK-NEXT: ret; + %r1= call float @llvm.nvvm.fma.mixed.rn.f32.f16(half %a, half %b, float %c) + %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.f16(half %a, half %b, float %r1) + %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.f16(half %a, half %b, float %r2) + %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.f16(half %a, half %b, float %r3) + + ; SAT + %r5 = call float @llvm.nvvm.fma.mixed.rn.sat.f32.f16(half %a, half %b, float %r4) + %r6 = call float @llvm.nvvm.fma.mixed.rz.sat.f32.f16(half %a, half %b, float %r5) + %r7 = call float @llvm.nvvm.fma.mixed.rm.sat.f32.f16(half %a, half %b, float %r6) + %r8 = call float @llvm.nvvm.fma.mixed.rp.sat.f32.f16(half %a, half %b, float %r7) + + ret float %r8 +} + +define float @test_fma_f32_bf16(bfloat %a, bfloat %b, float %c) { +; CHECK-LABEL: test_fma_f32_bf16( +; CHECK: { +; CHECK-NEXT: .reg .b16 %rs<3>; +; CHECK-NEXT: .reg .b32 %r<10>; +; CHECK-EMPTY: +; CHECK-NEXT: // %bb.0: +; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_bf16_param_0]; +; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_bf16_param_1]; +; 
CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_bf16_param_2]; +; CHECK-NEXT: fma.rn.f32.bf16 %r2, %rs1, %rs2, %r1; +; CHECK-NEXT: fma.rz.f32.bf16 %r3, %rs1, %rs2, %r2; +; CHECK-NEXT: fma.rm.f32.bf16 %r4, %rs1, %rs2, %r3; +; CHECK-NEXT: fma.rp.f32.bf16 %r5, %rs1, %rs2, %r4; +; CHECK-NEXT: fma.rn.sat.f32.bf16 %r6, %rs1, %rs2, %r5; +; CHECK-NEXT: fma.rz.sat.f32.bf16 %r7, %rs1, %rs2, %r6; +; CHECK-NEXT: fma.rm.sat.f32.bf16 %r8, %rs1, %rs2, %r7; +; CHECK-NEXT: fma.rp.sat.f32.bf16 %r9, %rs1, %rs2, %r8; +; CHECK-NEXT: st.param.b32 [func_retval0], %r9; +; CHECK-NEXT: ret; + %r1 = call float @llvm.nvvm.fma.mixed.rn.f32.bf16(bfloat %a, bfloat %b, float %c) + %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.bf16(bfloat %a, bfloat %b, float %r1) + %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.bf16(bfloat %a, bfloat %b, float %r2) + %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.bf16(bfloat %a, bfloat %b, float %r3) + + ; SAT + %r5 = call float @llvm.nvvm.fma.mixed.rn.sat.f32.bf16(bfloat %a, bfloat %b, float %r4) + %r6 = call float @llvm.nvvm.fma.mixed.rz.sat.f32.bf16(bfloat %a, bfloat %b, float %r5) + %r7 = call float @llvm.nvvm.fma.mixed.rm.sat.f32.bf16(bfloat %a, bfloat %b, float %r6) + %r8 = call float @llvm.nvvm.fma.mixed.rp.sat.f32.bf16(bfloat %a, bfloat %b, float %r7) + + ret float %r8 +} >From 444a0a76bfc48750f490abcb854e46057743e723 Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi <[email protected]> Date: Mon, 17 Nov 2025 12:03:05 +0000 Subject: [PATCH 2/3] fix whitespace error --- llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll index a4f2fe68830f5..adebcf868b2e6 100644 --- a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll +++ b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll @@ -176,7 +176,7 @@ define float @test_fma_f32_f16(half %a, half %b, float %c) { ; CHECK-NEXT: fma.rp.sat.f32.f16 %r9, %rs1, %rs2, %r8; ; CHECK-NEXT: st.param.b32 [func_retval0], %r9; ; CHECK-NEXT: ret; - %r1= call float @llvm.nvvm.fma.mixed.rn.f32.f16(half %a, half %b, float %c) + %r1 = call float @llvm.nvvm.fma.mixed.rn.f32.f16(half %a, half %b, float %c) %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.f16(half %a, half %b, float %r1) %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.f16(half %a, half %b, float %r2) %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.f16(half %a, half %b, float %r3) >From 9b02a28d4f02e732672f97835d2f2a8a8c655bfa Mon Sep 17 00:00:00 2001 From: Srinivasa Ravi <[email protected]> Date: Thu, 20 Nov 2025 17:52:51 +0000 Subject: [PATCH 3/3] remove mixed precision intrinsics and use idioms --- clang/include/clang/Basic/BuiltinsNVPTX.td | 98 +++----- clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp | 123 ---------- clang/test/CodeGen/builtins-nvptx.c | 213 +++++++---------- llvm/include/llvm/IR/IntrinsicsNVVM.td | 55 ++--- llvm/lib/Target/NVPTX/NVPTXIntrinsics.td | 60 ++++- llvm/test/CodeGen/NVPTX/fp-arith-sat.ll | 103 ++++++++ llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll | 55 +++++ llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll | 222 +++++++++--------- 8 files changed, 465 insertions(+), 464 deletions(-) create mode 100644 llvm/test/CodeGen/NVPTX/fp-arith-sat.ll create mode 100644 llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll diff --git a/clang/include/clang/Basic/BuiltinsNVPTX.td b/clang/include/clang/Basic/BuiltinsNVPTX.td index 47ba12bef058c..5409c25c38508 100644 --- a/clang/include/clang/Basic/BuiltinsNVPTX.td +++ b/clang/include/clang/Basic/BuiltinsNVPTX.td @@ -389,36 +389,26 
@@ def __nvvm_fma_rn_relu_bf16 : NVPTXBuiltinSMAndPTX<"__bf16(__bf16, __bf16, __bf1 def __nvvm_fma_rn_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>; def __nvvm_fma_rn_relu_bf16x2 : NVPTXBuiltinSMAndPTX<"_Vector<2, __bf16>(_Vector<2, __bf16>, _Vector<2, __bf16>, _Vector<2, __bf16>)", SM_80, PTX70>; def __nvvm_fma_rn_ftz_f : NVPTXBuiltin<"float(float, float, float)">; +def __nvvm_fma_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rn_f : NVPTXBuiltin<"float(float, float, float)">; +def __nvvm_fma_rn_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rz_ftz_f : NVPTXBuiltin<"float(float, float, float)">; +def __nvvm_fma_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rz_f : NVPTXBuiltin<"float(float, float, float)">; +def __nvvm_fma_rz_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rm_ftz_f : NVPTXBuiltin<"float(float, float, float)">; +def __nvvm_fma_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rm_f : NVPTXBuiltin<"float(float, float, float)">; +def __nvvm_fma_rm_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rp_ftz_f : NVPTXBuiltin<"float(float, float, float)">; +def __nvvm_fma_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rp_f : NVPTXBuiltin<"float(float, float, float)">; +def __nvvm_fma_rp_sat_f : NVPTXBuiltin<"float(float, float, float)">; def __nvvm_fma_rn_d : NVPTXBuiltin<"double(double, double, double)">; def __nvvm_fma_rz_d : NVPTXBuiltin<"double(double, double, double)">; def __nvvm_fma_rm_d : NVPTXBuiltin<"double(double, double, double)">; def __nvvm_fma_rp_d : NVPTXBuiltin<"double(double, double, double)">; -def __nvvm_fma_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, __fp16, float)", SM_100, PTX86>; - -def __nvvm_fma_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>; -def __nvvm_fma_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, __bf16, float)", SM_100, PTX86>; - // Rcp def __nvvm_rcp_rn_ftz_f : 
NVPTXBuiltin<"float(float)">; @@ -465,64 +455,50 @@ def __nvvm_rsqrt_approx_d : NVPTXBuiltin<"double(double)">; // Add def __nvvm_add_rn_ftz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_add_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rn_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_add_rn_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rz_ftz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_add_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_add_rz_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rm_ftz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_add_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rm_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_add_rm_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rp_ftz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_add_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rp_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_add_rp_sat_f : NVPTXBuiltin<"float(float, float)">; def __nvvm_add_rn_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rz_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rm_d : NVPTXBuiltin<"double(double, double)">; def __nvvm_add_rp_d : NVPTXBuiltin<"double(double, double)">; -def __nvvm_add_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; - -def __nvvm_add_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_add_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; - // Sub -def __nvvm_sub_mixed_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rn_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def 
__nvvm_sub_mixed_rz_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rm_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rp_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rn_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rz_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rm_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rp_sat_f16_f32 : NVPTXBuiltinSMAndPTX<"float(__fp16, float)", SM_100, PTX86>; - -def __nvvm_sub_mixed_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rn_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rz_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rm_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rp_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rn_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rz_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rm_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; -def __nvvm_sub_mixed_rp_sat_bf16_f32 : NVPTXBuiltinSMAndPTX<"float(__bf16, float)", SM_100, PTX86>; +def __nvvm_sub_rn_ftz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rn_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rn_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rn_sat_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rz_ftz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rz_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rz_sat_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rm_ftz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rm_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rm_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rm_sat_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rp_ftz_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rp_ftz_sat_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rp_f : NVPTXBuiltin<"float(float, float)">; +def __nvvm_sub_rp_sat_f : NVPTXBuiltin<"float(float, float)">; + +def __nvvm_sub_rn_d : NVPTXBuiltin<"double(double, double)">; +def __nvvm_sub_rz_d : NVPTXBuiltin<"double(double, double)">; +def __nvvm_sub_rm_d : NVPTXBuiltin<"double(double, double)">; +def __nvvm_sub_rp_d : NVPTXBuiltin<"double(double, double)">; // Convert diff --git a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp index 6f57620f0fb00..8a1cab3417d98 100644 --- a/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp +++ b/clang/lib/CodeGen/TargetBuiltins/NVPTX.cpp @@ -415,17 +415,6 @@ static Value *MakeHalfType(unsigned IntrinsicID, unsigned BuiltinID, return MakeHalfType(CGF.CGM.getIntrinsic(IntrinsicID), BuiltinID, E, CGF); } -static Value *MakeMixedPrecisionFPArithmetic(unsigned IntrinsicID, - const CallExpr *E, - 
CodeGenFunction &CGF) { - SmallVector<llvm::Value *, 3> Args; - for (unsigned i = 0; i < E->getNumArgs(); ++i) { - Args.push_back(CGF.EmitScalarExpr(E->getArg(i))); - } - return CGF.Builder.CreateCall( - CGF.CGM.getIntrinsic(IntrinsicID, {Args[0]->getType()}), Args); -} - } // namespace Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, @@ -1208,118 +1197,6 @@ Value *CodeGenFunction::EmitNVPTXBuiltinExpr(unsigned BuiltinID, return Builder.CreateCall( CGM.getIntrinsic(Intrinsic::nvvm_barrier_cta_sync_count), {EmitScalarExpr(E->getArg(0)), EmitScalarExpr(E->getArg(1))}); - case NVPTX::BI__nvvm_add_mixed_f16_f32: - case NVPTX::BI__nvvm_add_mixed_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_f32, E, - *this); - case NVPTX::BI__nvvm_add_mixed_rn_f16_f32: - case NVPTX::BI__nvvm_add_mixed_rn_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_f32, E, - *this); - case NVPTX::BI__nvvm_add_mixed_rz_f16_f32: - case NVPTX::BI__nvvm_add_mixed_rz_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_f32, E, - *this); - case NVPTX::BI__nvvm_add_mixed_rm_f16_f32: - case NVPTX::BI__nvvm_add_mixed_rm_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_f32, E, - *this); - case NVPTX::BI__nvvm_add_mixed_rp_f16_f32: - case NVPTX::BI__nvvm_add_mixed_rp_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_f32, E, - *this); - case NVPTX::BI__nvvm_add_mixed_sat_f16_f32: - case NVPTX::BI__nvvm_add_mixed_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_sat_f32, E, - *this); - case NVPTX::BI__nvvm_add_mixed_rn_sat_f16_f32: - case NVPTX::BI__nvvm_add_mixed_rn_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rn_sat_f32, - E, *this); - case NVPTX::BI__nvvm_add_mixed_rz_sat_f16_f32: - case NVPTX::BI__nvvm_add_mixed_rz_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rz_sat_f32, - E, *this); - case NVPTX::BI__nvvm_add_mixed_rm_sat_f16_f32: - case NVPTX::BI__nvvm_add_mixed_rm_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rm_sat_f32, - E, *this); - case NVPTX::BI__nvvm_add_mixed_rp_sat_f16_f32: - case NVPTX::BI__nvvm_add_mixed_rp_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_add_mixed_rp_sat_f32, - E, *this); - case NVPTX::BI__nvvm_sub_mixed_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_f32, E, - *this); - case NVPTX::BI__nvvm_sub_mixed_rn_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_rn_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_f32, E, - *this); - case NVPTX::BI__nvvm_sub_mixed_rz_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_rz_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_f32, E, - *this); - case NVPTX::BI__nvvm_sub_mixed_rm_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_rm_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_f32, E, - *this); - case NVPTX::BI__nvvm_sub_mixed_rp_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_rp_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_f32, E, - *this); - case NVPTX::BI__nvvm_sub_mixed_sat_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_sat_f32, E, - *this); - case NVPTX::BI__nvvm_sub_mixed_rn_sat_f16_f32: - case 
NVPTX::BI__nvvm_sub_mixed_rn_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rn_sat_f32, - E, *this); - case NVPTX::BI__nvvm_sub_mixed_rz_sat_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_rz_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rz_sat_f32, - E, *this); - case NVPTX::BI__nvvm_sub_mixed_rm_sat_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_rm_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rm_sat_f32, - E, *this); - case NVPTX::BI__nvvm_sub_mixed_rp_sat_f16_f32: - case NVPTX::BI__nvvm_sub_mixed_rp_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_sub_mixed_rp_sat_f32, - E, *this); - case NVPTX::BI__nvvm_fma_mixed_rn_f16_f32: - case NVPTX::BI__nvvm_fma_mixed_rn_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_f32, E, - *this); - case NVPTX::BI__nvvm_fma_mixed_rz_f16_f32: - case NVPTX::BI__nvvm_fma_mixed_rz_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_f32, E, - *this); - case NVPTX::BI__nvvm_fma_mixed_rm_f16_f32: - case NVPTX::BI__nvvm_fma_mixed_rm_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_f32, E, - *this); - case NVPTX::BI__nvvm_fma_mixed_rp_f16_f32: - case NVPTX::BI__nvvm_fma_mixed_rp_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_f32, E, - *this); - case NVPTX::BI__nvvm_fma_mixed_rn_sat_f16_f32: - case NVPTX::BI__nvvm_fma_mixed_rn_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rn_sat_f32, - E, *this); - case NVPTX::BI__nvvm_fma_mixed_rz_sat_f16_f32: - case NVPTX::BI__nvvm_fma_mixed_rz_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rz_sat_f32, - E, *this); - case NVPTX::BI__nvvm_fma_mixed_rm_sat_f16_f32: - case NVPTX::BI__nvvm_fma_mixed_rm_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rm_sat_f32, - E, *this); - case NVPTX::BI__nvvm_fma_mixed_rp_sat_f16_f32: - case NVPTX::BI__nvvm_fma_mixed_rp_sat_bf16_f32: - return MakeMixedPrecisionFPArithmetic(Intrinsic::nvvm_fma_mixed_rp_sat_f32, - E, *this); default: return nullptr; } diff --git a/clang/test/CodeGen/builtins-nvptx.c b/clang/test/CodeGen/builtins-nvptx.c index 1753b4c7767e9..0409603598b6f 100644 --- a/clang/test/CodeGen/builtins-nvptx.c +++ b/clang/test/CodeGen/builtins-nvptx.c @@ -1467,135 +1467,94 @@ __device__ void nvvm_min_max_sm86() { // CHECK: ret void } -#define F16 (__fp16)0.1f -#define F16_2 (__fp16)0.2f - -__device__ void nvvm_add_mixed_precision_sm100() { -#if __CUDA_ARCH__ >= 1000 - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_rn_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_rz_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_rm_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_rp_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_sat_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float 
@llvm.nvvm.add.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_rn_sat_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_rz_sat_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_rm_sat_f16_f32(F16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00) - __nvvm_add_mixed_rp_sat_f16_f32(F16, 1.0f); - - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_rn_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_rz_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_rm_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_rp_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_sat_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_rn_sat_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_rz_sat_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_rm_sat_bf16_f32(BF16, 1.0f); - // CHECK_PTX86_SM100: call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00) - __nvvm_add_mixed_rp_sat_bf16_f32(BF16, 1.0f); -#endif +// CHECK-LABEL: nvvm_add_sub_fma_f32_sat +__device__ void nvvm_add_sub_fma_f32_sat() { + // CHECK: call float @llvm.nvvm.add.rn.sat.f + __nvvm_add_rn_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.add.rn.ftz.sat.f + __nvvm_add_rn_ftz_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.add.rz.sat.f + __nvvm_add_rz_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.add.rz.ftz.sat.f + __nvvm_add_rz_ftz_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.add.rm.sat.f + __nvvm_add_rm_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.add.rm.ftz.sat.f + __nvvm_add_rm_ftz_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.add.rp.sat.f + __nvvm_add_rp_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.add.rp.ftz.sat.f + __nvvm_add_rp_ftz_sat_f(1.0f, 2.0f); + + // CHECK: call float @llvm.nvvm.sub.rn.sat.f + __nvvm_sub_rn_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.sub.rn.ftz.sat.f + __nvvm_sub_rn_ftz_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.sub.rz.sat.f + __nvvm_sub_rz_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.sub.rz.ftz.sat.f + __nvvm_sub_rz_ftz_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.sub.rm.sat.f + __nvvm_sub_rm_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.sub.rm.ftz.sat.f + __nvvm_sub_rm_ftz_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.sub.rp.sat.f + __nvvm_sub_rp_sat_f(1.0f, 2.0f); + // CHECK: call float @llvm.nvvm.sub.rp.ftz.sat.f + __nvvm_sub_rp_ftz_sat_f(1.0f, 2.0f); + 
+  // CHECK: call float @llvm.nvvm.fma.rn.sat.f
+  __nvvm_fma_rn_sat_f(1.0f, 2.0f, 3.0f);
+  // CHECK: call float @llvm.nvvm.fma.rn.ftz.sat.f
+  __nvvm_fma_rn_ftz_sat_f(1.0f, 2.0f, 3.0f);
+  // CHECK: call float @llvm.nvvm.fma.rz.sat.f
+  __nvvm_fma_rz_sat_f(1.0f, 2.0f, 3.0f);
+  // CHECK: call float @llvm.nvvm.fma.rz.ftz.sat.f
+  __nvvm_fma_rz_ftz_sat_f(1.0f, 2.0f, 3.0f);
+  // CHECK: call float @llvm.nvvm.fma.rm.sat.f
+  __nvvm_fma_rm_sat_f(1.0f, 2.0f, 3.0f);
+  // CHECK: call float @llvm.nvvm.fma.rm.ftz.sat.f
+  __nvvm_fma_rm_ftz_sat_f(1.0f, 2.0f, 3.0f);
+  // CHECK: call float @llvm.nvvm.fma.rp.sat.f
+  __nvvm_fma_rp_sat_f(1.0f, 2.0f, 3.0f);
+  // CHECK: call float @llvm.nvvm.fma.rp.ftz.sat.f
+  __nvvm_fma_rp_ftz_sat_f(1.0f, 2.0f, 3.0f);
+
+  // CHECK: ret void
 }
-__device__ void nvvm_sub_mixed_precision_sm100() {
-#if __CUDA_ARCH__ >= 1000
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_rn_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_rz_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_rm_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_rp_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_sat_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_rn_sat_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_rz_sat_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_rm_sat_f16_f32(F16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half 0xH2E66, float 1.000000e+00)
-  __nvvm_sub_mixed_rp_sat_f16_f32(F16, 1.0f);
-
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_rn_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_rz_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_rm_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_rp_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_sat_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_rn_sat_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_rz_sat_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_rm_sat_bf16_f32(BF16, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.sub.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, float 1.000000e+00)
-  __nvvm_sub_mixed_rp_sat_bf16_f32(BF16, 1.0f);
-#endif
+// CHECK-LABEL: nvvm_sub_f32
+__device__ void nvvm_sub_f32() {
+  // CHECK: call float @llvm.nvvm.sub.rn.f
+  __nvvm_sub_rn_f(1.0f, 2.0f);
+  // CHECK: call float @llvm.nvvm.sub.rn.ftz.f
+  __nvvm_sub_rn_ftz_f(1.0f, 2.0f);
+  // CHECK: call float @llvm.nvvm.sub.rz.f
+  __nvvm_sub_rz_f(1.0f, 2.0f);
+  // CHECK: call float @llvm.nvvm.sub.rz.ftz.f
+  __nvvm_sub_rz_ftz_f(1.0f, 2.0f);
+  // CHECK: call float @llvm.nvvm.sub.rm.f
+  __nvvm_sub_rm_f(1.0f, 2.0f);
+  // CHECK: call float @llvm.nvvm.sub.rm.ftz.f
+  __nvvm_sub_rm_ftz_f(1.0f, 2.0f);
+  // CHECK: call float @llvm.nvvm.sub.rp.f
+  __nvvm_sub_rp_f(1.0f, 2.0f);
+  // CHECK: call float @llvm.nvvm.sub.rp.ftz.f
+  __nvvm_sub_rp_ftz_f(1.0f, 2.0f);
+
+  // CHECK: ret void
 }
-__device__ void nvvm_fma_mixed_precision_sm100() {
-#if __CUDA_ARCH__ >= 1000
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
-  __nvvm_fma_mixed_rn_f16_f32(F16, F16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
-  __nvvm_fma_mixed_rz_f16_f32(F16, F16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
-  __nvvm_fma_mixed_rm_f16_f32(F16, F16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
-  __nvvm_fma_mixed_rp_f16_f32(F16, F16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
-  __nvvm_fma_mixed_rn_sat_f16_f32(F16, F16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
-  __nvvm_fma_mixed_rz_sat_f16_f32(F16, F16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
-  __nvvm_fma_mixed_rm_sat_f16_f32(F16, F16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.sat.f32.f16(half 0xH2E66, half 0xH3266, float 1.000000e+00)
-  __nvvm_fma_mixed_rp_sat_f16_f32(F16, F16_2, 1.0f);
-
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
-  __nvvm_fma_mixed_rn_bf16_f32(BF16, BF16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
-  __nvvm_fma_mixed_rz_bf16_f32(BF16, BF16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
-  __nvvm_fma_mixed_rm_bf16_f32(BF16, BF16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
-  __nvvm_fma_mixed_rp_bf16_f32(BF16, BF16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rn.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
-  __nvvm_fma_mixed_rn_sat_bf16_f32(BF16, BF16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rz.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
-  __nvvm_fma_mixed_rz_sat_bf16_f32(BF16, BF16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rm.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
-  __nvvm_fma_mixed_rm_sat_bf16_f32(BF16, BF16_2, 1.0f);
-  // CHECK_PTX86_SM100: call float @llvm.nvvm.fma.mixed.rp.sat.f32.bf16(bfloat 0xR3DCD, bfloat 0xR3E4D, float 1.000000e+00)
-  __nvvm_fma_mixed_rp_sat_bf16_f32(BF16, BF16_2, 1.0f);
-#endif
+// CHECK-LABEL: nvvm_sub_f64
+__device__ void nvvm_sub_f64() {
+  // CHECK: call double @llvm.nvvm.sub.rn.d
+  __nvvm_sub_rn_d(1.0f, 2.0f);
+  // CHECK: call double @llvm.nvvm.sub.rz.d
+  __nvvm_sub_rz_d(1.0f, 2.0f);
+  // CHECK: call double @llvm.nvvm.sub.rm.d
+  __nvvm_sub_rm_d(1.0f, 2.0f);
+  // CHECK: call double @llvm.nvvm.sub.rp.d
+  __nvvm_sub_rp_d(1.0f, 2.0f);
+
+  // CHECK: ret void
 }
diff --git a/llvm/include/llvm/IR/IntrinsicsNVVM.td b/llvm/include/llvm/IR/IntrinsicsNVVM.td
index 7a7fce42e55b0..11e14f29fa74c 100644
--- a/llvm/include/llvm/IR/IntrinsicsNVVM.td
+++ b/llvm/include/llvm/IR/IntrinsicsNVVM.td
@@ -1376,24 +1376,18 @@ let TargetPrefix = "nvvm" in {
     } // ftz
   } // variant
 
-  foreach rnd = ["rn", "rz", "rm", "rp"] in {
-    foreach ftz = ["", "_ftz"] in
-      def int_nvvm_fma_ # rnd # ftz # _f : NVVMBuiltin,
-        PureIntrinsic<[llvm_float_ty],
-                      [llvm_float_ty, llvm_float_ty, llvm_float_ty]>;
+  foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+    foreach ftz = ["", "_ftz"] in {
+      foreach sat = ["", "_sat"] in
+        def int_nvvm_fma # rnd # ftz # sat # _f : NVVMBuiltin,
+          PureIntrinsic<[llvm_float_ty],
+                        [llvm_float_ty, llvm_float_ty, llvm_float_ty]>;
+    }
 
-    def int_nvvm_fma_ # rnd # _d : NVVMBuiltin,
+    def int_nvvm_fma # rnd # _d : NVVMBuiltin,
       PureIntrinsic<[llvm_double_ty],
                     [llvm_double_ty, llvm_double_ty, llvm_double_ty]>;
   }
-
-  foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
-    foreach sat = ["", "_sat"] in {
-      def int_nvvm_fma_mixed # rnd # sat # _f32 :
-        PureIntrinsic<[llvm_float_ty],
-                      [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty]>;
-    }
-  }
 
   //
   // Rcp
@@ -1451,30 +1445,31 @@ let TargetPrefix = "nvvm" in {
   //
   // Add
   //
   let IntrProperties = [IntrNoMem, IntrSpeculatable, Commutative] in {
-    foreach rnd = ["rn", "rz", "rm", "rp"] in {
-      foreach ftz = ["", "_ftz"] in
-        def int_nvvm_add_ # rnd # ftz # _f : NVVMBuiltin,
-          DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
+    foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+      foreach ftz = ["", "_ftz"] in {
+        foreach sat = ["", "_sat"] in
+          def int_nvvm_add # rnd # ftz # sat # _f : NVVMBuiltin,
+            DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
+      }
 
-      def int_nvvm_add_ # rnd # _d : NVVMBuiltin,
+      def int_nvvm_add # rnd # _d : NVVMBuiltin,
        DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
    }
  }
-
-  foreach rnd = ["", "_rn", "_rz", "_rm", "_rp"] in {
-    foreach sat = ["", "_sat"] in {
-      def int_nvvm_add_mixed # rnd # sat # _f32 :
-        PureIntrinsic<[llvm_float_ty], [llvm_anyfloat_ty, llvm_float_ty]>;
-    }
-  }
 
   //
   // Sub
   //
-  foreach rnd = ["", "_rn", "_rz", "_rm", "_rp"] in {
-    foreach sat = ["", "_sat"] in {
-      def int_nvvm_sub_mixed # rnd # sat # _f32 :
-        PureIntrinsic<[llvm_float_ty], [llvm_anyfloat_ty, llvm_float_ty]>;
+  let IntrProperties = [IntrNoMem, IntrSpeculatable] in {
+    foreach rnd = ["_rn", "_rz", "_rm", "_rp"] in {
+      foreach ftz = ["", "_ftz"] in {
+        foreach sat = ["", "_sat"] in
+          def int_nvvm_sub # rnd # ftz # sat # _f : NVVMBuiltin,
+            DefaultAttrsIntrinsic<[llvm_float_ty], [llvm_float_ty, llvm_float_ty]>;
+      }
+
+      def int_nvvm_sub # rnd # _d : NVVMBuiltin,
+        DefaultAttrsIntrinsic<[llvm_double_ty], [llvm_double_ty, llvm_double_ty]>;
    }
  }
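For context, a minimal CUDA sketch of how the intrinsics expanded by the foreach loops above surface through the corresponding clang builtins exercised in builtins-nvptx.c; the helper name is invented for illustration and nothing below is part of the patch:

// Illustrative device code (hypothetical helper, not in the patch): each
// builtin maps 1:1 to one intrinsic produced by the foreach expansion above.
__device__ float saturating_ops(float a, float b, float c) {
  float s = __nvvm_add_rn_ftz_sat_f(a, b);   // -> @llvm.nvvm.add.rn.ftz.sat.f
  s = __nvvm_sub_rz_sat_f(s, c);             // -> @llvm.nvvm.sub.rz.sat.f
  return __nvvm_fma_rp_ftz_sat_f(a, s, c);   // -> @llvm.nvvm.fma.rp.ftz.sat.f
}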
diff --git a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
index 07483bf5a3e3d..255b25ed78fc5 100644
--- a/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
+++ b/llvm/lib/Target/NVPTX/NVPTXIntrinsics.td
@@ -1637,13 +1637,21 @@ multiclass FMA_INST {
     FMA_TUPLE<"_rp_f64", int_nvvm_fma_rp_d, B64>,
 
     FMA_TUPLE<"_rn_ftz_f32", int_nvvm_fma_rn_ftz_f, B32>,
+    FMA_TUPLE<"_rn_ftz_sat_f32", int_nvvm_fma_rn_ftz_sat_f, B32>,
     FMA_TUPLE<"_rn_f32", int_nvvm_fma_rn_f, B32>,
+    FMA_TUPLE<"_rn_sat_f32", int_nvvm_fma_rn_sat_f, B32>,
     FMA_TUPLE<"_rz_ftz_f32", int_nvvm_fma_rz_ftz_f, B32>,
+    FMA_TUPLE<"_rz_ftz_sat_f32", int_nvvm_fma_rz_ftz_sat_f, B32>,
     FMA_TUPLE<"_rz_f32", int_nvvm_fma_rz_f, B32>,
+    FMA_TUPLE<"_rz_sat_f32", int_nvvm_fma_rz_sat_f, B32>,
     FMA_TUPLE<"_rm_f32", int_nvvm_fma_rm_f, B32>,
+    FMA_TUPLE<"_rm_sat_f32", int_nvvm_fma_rm_sat_f, B32>,
     FMA_TUPLE<"_rm_ftz_f32", int_nvvm_fma_rm_ftz_f, B32>,
+    FMA_TUPLE<"_rm_ftz_sat_f32", int_nvvm_fma_rm_ftz_sat_f, B32>,
     FMA_TUPLE<"_rp_f32", int_nvvm_fma_rp_f, B32>,
+    FMA_TUPLE<"_rp_sat_f32", int_nvvm_fma_rp_sat_f, B32>,
     FMA_TUPLE<"_rp_ftz_f32", int_nvvm_fma_rp_ftz_f, B32>,
+    FMA_TUPLE<"_rp_ftz_sat_f32", int_nvvm_fma_rp_ftz_sat_f, B32>,
 
     FMA_TUPLE<"_rn_f16", int_nvvm_fma_rn_f16, B16, [hasPTX<42>, hasSM<53>]>,
     FMA_TUPLE<"_rn_ftz_f16", int_nvvm_fma_rn_ftz_f16, B16,
@@ -1701,8 +1709,10 @@ foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
         BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B16:$b, B32:$c),
           !tolower(!subst("_", ".", "fma" # rnd # sat # "_f32_" # type)),
           [(set f32:$dst,
-            (!cast<Intrinsic>(!tolower("int_nvvm_fma_mixed" # rnd # sat # "_f32"))
-              !cast<ValueType>(!tolower(type)):$a, !cast<ValueType>(!tolower(type)):$b, f32:$c))]>,
+            (!cast<Intrinsic>(!tolower("int_nvvm_fma" # rnd # sat # "_f"))
+              (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
+              (f32 (fpextend !cast<ValueType>(!tolower(type)):$b)),
+              f32:$c))]>,
         Requires<[hasSM<100>, hasPTX<86>]>;
   }
 }
@@ -1807,45 +1817,77 @@ let Predicates = [doRsqrtOpt] in {
 //
 def INT_NVVM_ADD_RN_FTZ_F : F_MATH_2<"add.rn.ftz.f32", B32, B32, B32, int_nvvm_add_rn_ftz_f>;
+def INT_NVVM_ADD_RN_FTZ_SAT_F : F_MATH_2<"add.rn.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rn_ftz_sat_f>;
 def INT_NVVM_ADD_RN_F : F_MATH_2<"add.rn.f32", B32, B32, B32, int_nvvm_add_rn_f>;
+def INT_NVVM_ADD_RN_SAT_F : F_MATH_2<"add.rn.sat.f32", B32, B32, B32, int_nvvm_add_rn_sat_f>;
 def INT_NVVM_ADD_RZ_FTZ_F : F_MATH_2<"add.rz.ftz.f32", B32, B32, B32, int_nvvm_add_rz_ftz_f>;
+def INT_NVVM_ADD_RZ_FTZ_SAT_F : F_MATH_2<"add.rz.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rz_ftz_sat_f>;
 def INT_NVVM_ADD_RZ_F : F_MATH_2<"add.rz.f32", B32, B32, B32, int_nvvm_add_rz_f>;
+def INT_NVVM_ADD_RZ_SAT_F : F_MATH_2<"add.rz.sat.f32", B32, B32, B32, int_nvvm_add_rz_sat_f>;
 def INT_NVVM_ADD_RM_FTZ_F : F_MATH_2<"add.rm.ftz.f32", B32, B32, B32, int_nvvm_add_rm_ftz_f>;
+def INT_NVVM_ADD_RM_FTZ_SAT_F : F_MATH_2<"add.rm.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rm_ftz_sat_f>;
 def INT_NVVM_ADD_RM_F : F_MATH_2<"add.rm.f32", B32, B32, B32, int_nvvm_add_rm_f>;
+def INT_NVVM_ADD_RM_SAT_F : F_MATH_2<"add.rm.sat.f32", B32, B32, B32, int_nvvm_add_rm_sat_f>;
 def INT_NVVM_ADD_RP_FTZ_F : F_MATH_2<"add.rp.ftz.f32", B32, B32, B32, int_nvvm_add_rp_ftz_f>;
+def INT_NVVM_ADD_RP_FTZ_SAT_F : F_MATH_2<"add.rp.ftz.sat.f32", B32, B32, B32, int_nvvm_add_rp_ftz_sat_f>;
 def INT_NVVM_ADD_RP_F : F_MATH_2<"add.rp.f32", B32, B32, B32, int_nvvm_add_rp_f>;
+def INT_NVVM_ADD_RP_SAT_F : F_MATH_2<"add.rp.sat.f32", B32, B32, B32, int_nvvm_add_rp_sat_f>;
 
 def INT_NVVM_ADD_RN_D : F_MATH_2<"add.rn.f64", B64, B64, B64, int_nvvm_add_rn_d>;
 def INT_NVVM_ADD_RZ_D : F_MATH_2<"add.rz.f64", B64, B64, B64, int_nvvm_add_rz_d>;
 def INT_NVVM_ADD_RM_D : F_MATH_2<"add.rm.f64", B64, B64, B64, int_nvvm_add_rm_d>;
 def INT_NVVM_ADD_RP_D : F_MATH_2<"add.rp.f64", B64, B64, B64, int_nvvm_add_rp_d>;
 
-foreach rnd = ["", "_RN", "_RZ", "_RM", "_RP"] in {
+foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
   foreach sat = ["", "_SAT"] in {
     foreach type = ["F16", "BF16"] in {
       def INT_NVVM_ADD # rnd # sat # _F32_ # type :
         BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
           !tolower(!subst("_", ".", "add" # rnd # sat # "_f32_" # type)),
           [(set f32:$dst,
-            (!cast<Intrinsic>(!tolower("int_nvvm_add_mixed" # rnd # sat # "_f32"))
-              !cast<ValueType>(!tolower(type)):$a, f32:$b))]>,
+            (!cast<Intrinsic>(!tolower("int_nvvm_add" # rnd # sat # "_f"))
+              (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
+              f32:$b))]>,
         Requires<[hasSM<100>, hasPTX<86>]>;
     }
   }
 }
-//
+
 // Sub
 //
-foreach rnd = ["", "_RN", "_RZ", "_RM", "_RP"] in {
+def INT_NVVM_SUB_RN_FTZ_F : F_MATH_2<"sub.rn.ftz.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_f>;
+def INT_NVVM_SUB_RN_FTZ_SAT_F : F_MATH_2<"sub.rn.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rn_ftz_sat_f>;
+def INT_NVVM_SUB_RN_F : F_MATH_2<"sub.rn.f32", B32, B32, B32, int_nvvm_sub_rn_f>;
+def INT_NVVM_SUB_RN_SAT_F : F_MATH_2<"sub.rn.sat.f32", B32, B32, B32, int_nvvm_sub_rn_sat_f>;
+def INT_NVVM_SUB_RZ_FTZ_F : F_MATH_2<"sub.rz.ftz.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_f>;
+def INT_NVVM_SUB_RZ_FTZ_SAT_F : F_MATH_2<"sub.rz.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_ftz_sat_f>;
+def INT_NVVM_SUB_RZ_F : F_MATH_2<"sub.rz.f32", B32, B32, B32, int_nvvm_sub_rz_f>;
+def INT_NVVM_SUB_RZ_SAT_F : F_MATH_2<"sub.rz.sat.f32", B32, B32, B32, int_nvvm_sub_rz_sat_f>;
+def INT_NVVM_SUB_RM_FTZ_F : F_MATH_2<"sub.rm.ftz.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_f>;
+def INT_NVVM_SUB_RM_FTZ_SAT_F : F_MATH_2<"sub.rm.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rm_ftz_sat_f>;
+def INT_NVVM_SUB_RM_F : F_MATH_2<"sub.rm.f32", B32, B32, B32, int_nvvm_sub_rm_f>;
+def INT_NVVM_SUB_RM_SAT_F : F_MATH_2<"sub.rm.sat.f32", B32, B32, B32, int_nvvm_sub_rm_sat_f>;
+def INT_NVVM_SUB_RP_FTZ_F : F_MATH_2<"sub.rp.ftz.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_f>;
+def INT_NVVM_SUB_RP_FTZ_SAT_F : F_MATH_2<"sub.rp.ftz.sat.f32", B32, B32, B32, int_nvvm_sub_rp_ftz_sat_f>;
+def INT_NVVM_SUB_RP_F : F_MATH_2<"sub.rp.f32", B32, B32, B32, int_nvvm_sub_rp_f>;
+def INT_NVVM_SUB_RP_SAT_F : F_MATH_2<"sub.rp.sat.f32", B32, B32, B32, int_nvvm_sub_rp_sat_f>;
+
+def INT_NVVM_SUB_RN_D : F_MATH_2<"sub.rn.f64", B64, B64, B64, int_nvvm_sub_rn_d>;
+def INT_NVVM_SUB_RZ_D : F_MATH_2<"sub.rz.f64", B64, B64, B64, int_nvvm_sub_rz_d>;
+def INT_NVVM_SUB_RM_D : F_MATH_2<"sub.rm.f64", B64, B64, B64, int_nvvm_sub_rm_d>;
+def INT_NVVM_SUB_RP_D : F_MATH_2<"sub.rp.f64", B64, B64, B64, int_nvvm_sub_rp_d>;
+
+foreach rnd = ["_RN", "_RZ", "_RM", "_RP"] in {
   foreach sat = ["", "_SAT"] in {
     foreach type = ["F16", "BF16"] in {
       def INT_NVVM_SUB # rnd # sat # _F32_ # type :
         BasicNVPTXInst<(outs B32:$dst), (ins B16:$a, B32:$b),
           !tolower(!subst("_", ".", "sub" # rnd # sat # "_f32_" # type)),
           [(set f32:$dst,
-            (!cast<Intrinsic>(!tolower("int_nvvm_sub_mixed" # rnd # sat # "_f32"))
-              !cast<ValueType>(!tolower(type)):$a, f32:$b))]>,
+            (!cast<Intrinsic>(!tolower("int_nvvm_sub" # rnd # sat # "_f"))
+              (f32 (fpextend !cast<ValueType>(!tolower(type)):$a)),
+              f32:$b))]>,
         Requires<[hasSM<100>, hasPTX<86>]>;
    }
  }
diff --git a/llvm/test/CodeGen/NVPTX/fp-arith-sat.ll b/llvm/test/CodeGen/NVPTX/fp-arith-sat.ll
new file mode 100644
index 0000000000000..20afa329599b1
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp-arith-sat.ll
@@ -0,0 +1,103 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas-sm_20 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify -arch=sm_20 %}
+
+define float @add_sat_f32(float %a, float %b) {
+; CHECK-LABEL: add_sat_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [add_sat_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [add_sat_f32_param_1];
+; CHECK-NEXT: add.rn.sat.f32 %r3, %r1, %r2;
+; CHECK-NEXT: add.rn.ftz.sat.f32 %r4, %r1, %r3;
+; CHECK-NEXT: add.rz.sat.f32 %r5, %r1, %r4;
+; CHECK-NEXT: add.rz.ftz.sat.f32 %r6, %r1, %r5;
+; CHECK-NEXT: add.rm.sat.f32 %r7, %r1, %r6;
+; CHECK-NEXT: add.rm.ftz.sat.f32 %r8, %r1, %r7;
+; CHECK-NEXT: add.rp.sat.f32 %r9, %r1, %r8;
+; CHECK-NEXT: add.rp.ftz.sat.f32 %r10, %r1, %r9;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-NEXT: ret;
+  %r1 = call float @llvm.nvvm.add.rn.sat.f(float %a, float %b)
+  %r2 = call float @llvm.nvvm.add.rn.ftz.sat.f(float %a, float %r1)
+
+  %r3 = call float @llvm.nvvm.add.rz.sat.f(float %a, float %r2)
+  %r4 = call float @llvm.nvvm.add.rz.ftz.sat.f(float %a, float %r3)
+
+  %r5 = call float @llvm.nvvm.add.rm.sat.f(float %a, float %r4)
+  %r6 = call float @llvm.nvvm.add.rm.ftz.sat.f(float %a, float %r5)
+
+  %r7 = call float @llvm.nvvm.add.rp.sat.f(float %a, float %r6)
+  %r8 = call float @llvm.nvvm.add.rp.ftz.sat.f(float %a, float %r7)
+
+  ret float %r8
+}
+
+define float @sub_sat_f32(float %a, float %b) {
+; CHECK-LABEL: sub_sat_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [sub_sat_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [sub_sat_f32_param_1];
+; CHECK-NEXT: sub.rn.sat.f32 %r3, %r1, %r2;
+; CHECK-NEXT: sub.rn.ftz.sat.f32 %r4, %r1, %r3;
+; CHECK-NEXT: sub.rz.sat.f32 %r5, %r1, %r4;
+; CHECK-NEXT: sub.rz.ftz.sat.f32 %r6, %r1, %r5;
+; CHECK-NEXT: sub.rm.sat.f32 %r7, %r1, %r6;
+; CHECK-NEXT: sub.rm.ftz.sat.f32 %r8, %r1, %r7;
+; CHECK-NEXT: sub.rp.sat.f32 %r9, %r1, %r8;
+; CHECK-NEXT: sub.rp.ftz.sat.f32 %r10, %r1, %r9;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-NEXT: ret;
+  %r1 = call float @llvm.nvvm.sub.rn.sat.f(float %a, float %b)
+  %r2 = call float @llvm.nvvm.sub.rn.ftz.sat.f(float %a, float %r1)
+
+  %r3 = call float @llvm.nvvm.sub.rz.sat.f(float %a, float %r2)
+  %r4 = call float @llvm.nvvm.sub.rz.ftz.sat.f(float %a, float %r3)
+
+  %r5 = call float @llvm.nvvm.sub.rm.sat.f(float %a, float %r4)
+  %r6 = call float @llvm.nvvm.sub.rm.ftz.sat.f(float %a, float %r5)
+
+  %r7 = call float @llvm.nvvm.sub.rp.sat.f(float %a, float %r6)
+  %r8 = call float @llvm.nvvm.sub.rp.ftz.sat.f(float %a, float %r7)
+
+  ret float %r8
+}
+
+define float @fma_sat_f32(float %a, float %b, float %c) {
+; CHECK-LABEL: fma_sat_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [fma_sat_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [fma_sat_f32_param_1];
+; CHECK-NEXT: ld.param.b32 %r3, [fma_sat_f32_param_2];
+; CHECK-NEXT: fma.rn.sat.f32 %r4, %r1, %r2, %r3;
+; CHECK-NEXT: fma.rn.ftz.sat.f32 %r5, %r1, %r2, %r4;
+; CHECK-NEXT: fma.rz.sat.f32 %r6, %r1, %r2, %r5;
+; CHECK-NEXT: fma.rz.ftz.sat.f32 %r7, %r1, %r2, %r6;
+; CHECK-NEXT: fma.rm.sat.f32 %r8, %r1, %r2, %r7;
+; CHECK-NEXT: fma.rm.ftz.sat.f32 %r9, %r1, %r2, %r8;
+; CHECK-NEXT: fma.rp.sat.f32 %r10, %r1, %r2, %r9;
+; CHECK-NEXT: fma.rp.ftz.sat.f32 %r11, %r1, %r2, %r10;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: ret;
+  %r1 = call float @llvm.nvvm.fma.rn.sat.f(float %a, float %b, float %c)
+  %r2 = call float @llvm.nvvm.fma.rn.ftz.sat.f(float %a, float %b, float %r1)
+
+  %r3 = call float @llvm.nvvm.fma.rz.sat.f(float %a, float %b, float %r2)
+  %r4 = call float @llvm.nvvm.fma.rz.ftz.sat.f(float %a, float %b, float %r3)
+
+  %r5 = call float @llvm.nvvm.fma.rm.sat.f(float %a, float %b, float %r4)
+  %r6 = call float @llvm.nvvm.fma.rm.ftz.sat.f(float %a, float %b, float %r5)
+
+  %r7 = call float @llvm.nvvm.fma.rp.sat.f(float %a, float %b, float %r6)
+  %r8 = call float @llvm.nvvm.fma.rp.ftz.sat.f(float %a, float %b, float %r7)
+
+  ret float %r8
+}
diff --git a/llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll b/llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll
new file mode 100644
index 0000000000000..1f6bf5f9e16f2
--- /dev/null
+++ b/llvm/test/CodeGen/NVPTX/fp-sub-intrins.ll
@@ -0,0 +1,55 @@
+; NOTE: Assertions have been autogenerated by utils/update_llc_test_checks.py UTC_ARGS: --version 6
+; RUN: llc < %s -mtriple=nvptx64 -mcpu=sm_20 | FileCheck %s
+; RUN: %if ptxas-sm_20 %{ llc < %s -mtriple=nvptx64 -mcpu=sm_20 | %ptxas-verify -arch=sm_20 %}
+
+define float @sub_f32(float %a, float %b) {
+; CHECK-LABEL: sub_f32(
+; CHECK: {
+; CHECK-NEXT: .reg .b32 %r<11>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b32 %r1, [sub_f32_param_0];
+; CHECK-NEXT: ld.param.b32 %r2, [sub_f32_param_1];
+; CHECK-NEXT: sub.rn.f32 %r3, %r1, %r2;
+; CHECK-NEXT: sub.rn.ftz.f32 %r4, %r1, %r3;
+; CHECK-NEXT: sub.rz.f32 %r5, %r1, %r4;
+; CHECK-NEXT: sub.rz.ftz.f32 %r6, %r1, %r5;
+; CHECK-NEXT: sub.rm.f32 %r7, %r1, %r6;
+; CHECK-NEXT: sub.rm.ftz.f32 %r8, %r1, %r7;
+; CHECK-NEXT: sub.rp.f32 %r9, %r1, %r8;
+; CHECK-NEXT: sub.rp.ftz.f32 %r10, %r1, %r9;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r10;
+; CHECK-NEXT: ret;
+  %r1 = call float @llvm.nvvm.sub.rn.f(float %a, float %b)
+  %r2 = call float @llvm.nvvm.sub.rn.ftz.f(float %a, float %r1)
+  %r3 = call float @llvm.nvvm.sub.rz.f(float %a, float %r2)
+  %r4 = call float @llvm.nvvm.sub.rz.ftz.f(float %a, float %r3)
+  %r5 = call float @llvm.nvvm.sub.rm.f(float %a, float %r4)
+  %r6 = call float @llvm.nvvm.sub.rm.ftz.f(float %a, float %r5)
+  %r7 = call float @llvm.nvvm.sub.rp.f(float %a, float %r6)
+  %r8 = call float @llvm.nvvm.sub.rp.ftz.f(float %a, float %r7)
+
+  ret float %r8
+}
+
+define double @sub_f64(double %a, double %b) {
+; CHECK-LABEL: sub_f64(
+; CHECK: {
+; CHECK-NEXT: .reg .b64 %rd<7>;
+; CHECK-EMPTY:
+; CHECK-NEXT: // %bb.0:
+; CHECK-NEXT: ld.param.b64 %rd1, [sub_f64_param_0];
+; CHECK-NEXT: ld.param.b64 %rd2, [sub_f64_param_1];
+; CHECK-NEXT: sub.rn.f64 %rd3, %rd1, %rd2;
+; CHECK-NEXT: sub.rz.f64 %rd4, %rd1, %rd3;
+; CHECK-NEXT: sub.rm.f64 %rd5, %rd1, %rd4;
+; CHECK-NEXT: sub.rp.f64 %rd6, %rd1, %rd5;
+; CHECK-NEXT: st.param.b64 [func_retval0], %rd6;
+; CHECK-NEXT: ret;
+  %r1 = call double @llvm.nvvm.sub.rn.d(double %a, double %b)
+  %r2 = call double @llvm.nvvm.sub.rz.d(double %a, double %r1)
+  %r3 = call double @llvm.nvvm.sub.rm.d(double %a, double %r2)
+  %r4 = call double @llvm.nvvm.sub.rp.d(double %a, double %r3)
+
+  ret double %r4
+}
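The mixed-precision test below is the counterpart of the ISel change in NVPTXIntrinsics.td: with the llvm.nvvm.*.mixed.* intrinsics dropped, a mixed-precision operation is written as an fpext feeding the plain f32 intrinsic, and the BasicNVPTXInst patterns fold the pair into a single instruction. A minimal CUDA sketch of the same idea from source code, illustrative only (the helper name is invented, and instruction selection is assumed per the patterns above on an sm_100 / PTX 8.6 target):

// Illustrative only: the (float) cast lowers to fpext, which the ISel
// pattern is expected to fold with the f32 intrinsic into one
// mixed-precision instruction (e.g. add.rn.f32.f16) on sm_100+/PTX 8.6.
__device__ float mixed_add(__fp16 a, float b) {
  return __nvvm_add_rn_f((float)a, b);
}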
diff --git a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
index adebcf868b2e6..535e60c99526a 100644
--- a/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
+++ b/llvm/test/CodeGen/NVPTX/mixed-precision-fp.ll
@@ -8,74 +8,69 @@
 define float @test_add_f32_f16(half %a, float %b) {
 ; CHECK-LABEL: test_add_f32_f16(
 ; CHECK: {
 ; CHECK-NEXT: .reg .b16 %rs<2>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
 ; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_f16_param_0];
 ; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_f16_param_1];
-; CHECK-NEXT: add.f32.f16 %r2, %rs1, %r1;
-; CHECK-NEXT: add.rn.f32.f16 %r3, %rs1, %r2;
-; CHECK-NEXT: add.rz.f32.f16 %r4, %rs1, %r3;
-; CHECK-NEXT: add.rm.f32.f16 %r5, %rs1, %r4;
-; CHECK-NEXT: add.rp.f32.f16 %r6, %rs1, %r5;
-; CHECK-NEXT: add.sat.f32.f16 %r7, %rs1, %r6;
-; CHECK-NEXT: add.rn.sat.f32.f16 %r8, %rs1, %r7;
-; CHECK-NEXT: add.rz.sat.f32.f16 %r9, %rs1, %r8;
-; CHECK-NEXT: add.rm.sat.f32.f16 %r10, %rs1, %r9;
-; CHECK-NEXT: add.rp.sat.f32.f16 %r11, %rs1, %r10;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: add.rn.f32.f16 %r2, %rs1, %r1;
+; CHECK-NEXT: add.rz.f32.f16 %r3, %rs1, %r2;
+; CHECK-NEXT: add.rm.f32.f16 %r4, %rs1, %r3;
+; CHECK-NEXT: add.rp.f32.f16 %r5, %rs1, %r4;
+; CHECK-NEXT: add.rn.sat.f32.f16 %r6, %rs1, %r5;
+; CHECK-NEXT: add.rz.sat.f32.f16 %r7, %rs1, %r6;
+; CHECK-NEXT: add.rm.sat.f32.f16 %r8, %rs1, %r7;
+; CHECK-NEXT: add.rp.sat.f32.f16 %r9, %rs1, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
 ; CHECK-NEXT: ret;
-  %r1 = call float @llvm.nvvm.add.mixed.f32.f16(half %a, float %b)
-  %r2 = call float @llvm.nvvm.add.mixed.rn.f32.f16(half %a, float %r1)
-  %r3 = call float @llvm.nvvm.add.mixed.rz.f32.f16(half %a, float %r2)
-  %r4 = call float @llvm.nvvm.add.mixed.rm.f32.f16(half %a, float %r3)
-  %r5 = call float @llvm.nvvm.add.mixed.rp.f32.f16(half %a, float %r4)
+  %r0 = fpext half %a to float
+
+  %r1 = call float @llvm.nvvm.add.rn.f(float %r0, float %b)
+  %r2 = call float @llvm.nvvm.add.rz.f(float %r0, float %r1)
+  %r3 = call float @llvm.nvvm.add.rm.f(float %r0, float %r2)
+  %r4 = call float @llvm.nvvm.add.rp.f(float %r0, float %r3)
 
   ; SAT
-  %r6 = call float @llvm.nvvm.add.mixed.sat.f32.f16(half %a, float %r5)
-  %r7 = call float @llvm.nvvm.add.mixed.rn.sat.f32.f16(half %a, float %r6)
-  %r8 = call float @llvm.nvvm.add.mixed.rz.sat.f32.f16(half %a, float %r7)
-  %r9 = call float @llvm.nvvm.add.mixed.rm.sat.f32.f16(half %a, float %r8)
-  %r10 = call float @llvm.nvvm.add.mixed.rp.sat.f32.f16(half %a, float %r9)
+  %r5 = call float @llvm.nvvm.add.rn.sat.f(float %r0, float %r4)
+  %r6 = call float @llvm.nvvm.add.rz.sat.f(float %r0, float %r5)
+  %r7 = call float @llvm.nvvm.add.rm.sat.f(float %r0, float %r6)
+  %r8 = call float @llvm.nvvm.add.rp.sat.f(float %r0, float %r7)
 
-  ret float %r10
+  ret float %r8
 }
 
 define float @test_add_f32_bf16(bfloat %a, float %b) {
 ; CHECK-LABEL: test_add_f32_bf16(
 ; CHECK: {
 ; CHECK-NEXT: .reg .b16 %rs<2>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
 ; CHECK-NEXT: ld.param.b16 %rs1, [test_add_f32_bf16_param_0];
 ; CHECK-NEXT: ld.param.b32 %r1, [test_add_f32_bf16_param_1];
-; CHECK-NEXT: add.f32.bf16 %r2, %rs1, %r1;
-; CHECK-NEXT: add.rn.f32.bf16 %r3, %rs1, %r2;
-; CHECK-NEXT: add.rz.f32.bf16 %r4, %rs1, %r3;
-; CHECK-NEXT: add.rm.f32.bf16 %r5, %rs1, %r4;
-; CHECK-NEXT: add.rp.f32.bf16 %r6, %rs1, %r5;
-; CHECK-NEXT: add.sat.f32.bf16 %r7, %rs1, %r6;
-; CHECK-NEXT: add.rn.sat.f32.bf16 %r8, %rs1, %r7;
-; CHECK-NEXT: add.rz.sat.f32.bf16 %r9, %rs1, %r8;
-; CHECK-NEXT: add.rm.sat.f32.bf16 %r10, %rs1, %r9;
-; CHECK-NEXT: add.rp.sat.f32.bf16 %r11, %rs1, %r10;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: add.rn.f32.bf16 %r2, %rs1, %r1;
+; CHECK-NEXT: add.rz.f32.bf16 %r3, %rs1, %r2;
+; CHECK-NEXT: add.rm.f32.bf16 %r4, %rs1, %r3;
+; CHECK-NEXT: add.rp.f32.bf16 %r5, %rs1, %r4;
+; CHECK-NEXT: add.rn.sat.f32.bf16 %r6, %rs1, %r5;
+; CHECK-NEXT: add.rz.sat.f32.bf16 %r7, %rs1, %r6;
+; CHECK-NEXT: add.rm.sat.f32.bf16 %r8, %rs1, %r7;
+; CHECK-NEXT: add.rp.sat.f32.bf16 %r9, %rs1, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
 ; CHECK-NEXT: ret;
-  %r1 = call float @llvm.nvvm.add.mixed.f32.bf16(bfloat %a, float %b)
-  %r2 = call float @llvm.nvvm.add.mixed.rn.f32.bf16(bfloat %a, float %r1)
-  %r3 = call float @llvm.nvvm.add.mixed.rz.f32.bf16(bfloat %a, float %r2)
-  %r4 = call float @llvm.nvvm.add.mixed.rm.f32.bf16(bfloat %a, float %r3)
-  %r5 = call float @llvm.nvvm.add.mixed.rp.f32.bf16(bfloat %a, float %r4)
+  %r0 = fpext bfloat %a to float
 
-  ; SAT
-  %r6 = call float @llvm.nvvm.add.mixed.sat.f32.bf16(bfloat %a, float %r5)
-  %r7 = call float @llvm.nvvm.add.mixed.rn.sat.f32.bf16(bfloat %a, float %r6)
-  %r8 = call float @llvm.nvvm.add.mixed.rz.sat.f32.bf16(bfloat %a, float %r7)
-  %r9 = call float @llvm.nvvm.add.mixed.rm.sat.f32.bf16(bfloat %a, float %r8)
-  %r10 = call float @llvm.nvvm.add.mixed.rp.sat.f32.bf16(bfloat %a, float %r9)
+  %r1 = call float @llvm.nvvm.add.rn.f(float %r0, float %b)
+  %r2 = call float @llvm.nvvm.add.rz.f(float %r0, float %r1)
+  %r3 = call float @llvm.nvvm.add.rm.f(float %r0, float %r2)
+  %r4 = call float @llvm.nvvm.add.rp.f(float %r0, float %r3)
 
-  ret float %r10
+  ; SAT
+  %r5 = call float @llvm.nvvm.add.rn.sat.f(float %r0, float %r4)
+  %r6 = call float @llvm.nvvm.add.rz.sat.f(float %r0, float %r5)
+  %r7 = call float @llvm.nvvm.add.rm.sat.f(float %r0, float %r6)
+  %r8 = call float @llvm.nvvm.add.rp.sat.f(float %r0, float %r7)
+  ret float %r8
 }
 
 ; SUB
 
@@ -84,74 +79,70 @@
 define float @test_sub_f32_f16(half %a, float %b) {
 ; CHECK-LABEL: test_sub_f32_f16(
 ; CHECK: {
 ; CHECK-NEXT: .reg .b16 %rs<2>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
 ; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_f16_param_0];
 ; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_f16_param_1];
-; CHECK-NEXT: sub.f32.f16 %r2, %rs1, %r1;
-; CHECK-NEXT: sub.rn.f32.f16 %r3, %rs1, %r2;
-; CHECK-NEXT: sub.rz.f32.f16 %r4, %rs1, %r3;
-; CHECK-NEXT: sub.rm.f32.f16 %r5, %rs1, %r4;
-; CHECK-NEXT: sub.rp.f32.f16 %r6, %rs1, %r5;
-; CHECK-NEXT: sub.sat.f32.f16 %r7, %rs1, %r6;
-; CHECK-NEXT: sub.rn.sat.f32.f16 %r8, %rs1, %r7;
-; CHECK-NEXT: sub.rz.sat.f32.f16 %r9, %rs1, %r8;
-; CHECK-NEXT: sub.rm.sat.f32.f16 %r10, %rs1, %r9;
-; CHECK-NEXT: sub.rp.sat.f32.f16 %r11, %rs1, %r10;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: sub.rn.f32.f16 %r2, %rs1, %r1;
+; CHECK-NEXT: sub.rz.f32.f16 %r3, %rs1, %r2;
+; CHECK-NEXT: sub.rm.f32.f16 %r4, %rs1, %r3;
+; CHECK-NEXT: sub.rp.f32.f16 %r5, %rs1, %r4;
+; CHECK-NEXT: sub.rn.sat.f32.f16 %r6, %rs1, %r5;
+; CHECK-NEXT: sub.rz.sat.f32.f16 %r7, %rs1, %r6;
+; CHECK-NEXT: sub.rm.sat.f32.f16 %r8, %rs1, %r7;
+; CHECK-NEXT: sub.rp.sat.f32.f16 %r9, %rs1, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
 ; CHECK-NEXT: ret;
-  %r1 = call float @llvm.nvvm.sub.mixed.f32.f16(half %a, float %b)
-  %r2 = call float @llvm.nvvm.sub.mixed.rn.f32.f16(half %a, float %r1)
-  %r3 = call float @llvm.nvvm.sub.mixed.rz.f32.f16(half %a, float %r2)
-  %r4 = call float @llvm.nvvm.sub.mixed.rm.f32.f16(half %a, float %r3)
-  %r5 = call float @llvm.nvvm.sub.mixed.rp.f32.f16(half %a, float %r4)
+  %r0 = fpext half %a to float
+
+  %r1 = call float @llvm.nvvm.sub.rn.f(float %r0, float %b)
+  %r2 = call float @llvm.nvvm.sub.rz.f(float %r0, float %r1)
+  %r3 = call float @llvm.nvvm.sub.rm.f(float %r0, float %r2)
+  %r4 = call float @llvm.nvvm.sub.rp.f(float %r0, float %r3)
 
   ; SAT
-  %r6 = call float @llvm.nvvm.sub.mixed.sat.f32.f16(half %a, float %r5)
-  %r7 = call float @llvm.nvvm.sub.mixed.rn.sat.f32.f16(half %a, float %r6)
-  %r8 = call float @llvm.nvvm.sub.mixed.rz.sat.f32.f16(half %a, float %r7)
-  %r9 = call float @llvm.nvvm.sub.mixed.rm.sat.f32.f16(half %a, float %r8)
-  %r10 = call float @llvm.nvvm.sub.mixed.rp.sat.f32.f16(half %a, float %r9)
+  %r5 = call float @llvm.nvvm.sub.rn.sat.f(float %r0, float %r4)
+  %r6 = call float @llvm.nvvm.sub.rz.sat.f(float %r0, float %r5)
+  %r7 = call float @llvm.nvvm.sub.rm.sat.f(float %r0, float %r6)
+  %r8 = call float @llvm.nvvm.sub.rp.sat.f(float %r0, float %r7)
 
-  ret float %r10
+  ret float %r8
 }
 
 define float @test_sub_f32_bf16(bfloat %a, float %b) {
 ; CHECK-LABEL: test_sub_f32_bf16(
 ; CHECK: {
 ; CHECK-NEXT: .reg .b16 %rs<2>;
-; CHECK-NEXT: .reg .b32 %r<12>;
+; CHECK-NEXT: .reg .b32 %r<10>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
 ; CHECK-NEXT: ld.param.b16 %rs1, [test_sub_f32_bf16_param_0];
 ; CHECK-NEXT: ld.param.b32 %r1, [test_sub_f32_bf16_param_1];
-; CHECK-NEXT: sub.f32.bf16 %r2, %rs1, %r1;
-; CHECK-NEXT: sub.rn.f32.bf16 %r3, %rs1, %r2;
-; CHECK-NEXT: sub.rz.f32.bf16 %r4, %rs1, %r3;
-; CHECK-NEXT: sub.rm.f32.bf16 %r5, %rs1, %r4;
-; CHECK-NEXT: sub.rp.f32.bf16 %r6, %rs1, %r5;
-; CHECK-NEXT: sub.sat.f32.bf16 %r7, %rs1, %r6;
-; CHECK-NEXT: sub.rn.sat.f32.bf16 %r8, %rs1, %r7;
-; CHECK-NEXT: sub.rz.sat.f32.bf16 %r9, %rs1, %r8;
-; CHECK-NEXT: sub.rm.sat.f32.bf16 %r10, %rs1, %r9;
-; CHECK-NEXT: sub.rp.sat.f32.bf16 %r11, %rs1, %r10;
-; CHECK-NEXT: st.param.b32 [func_retval0], %r11;
+; CHECK-NEXT: sub.rn.f32.bf16 %r2, %rs1, %r1;
+; CHECK-NEXT: sub.rz.f32.bf16 %r3, %rs1, %r2;
+; CHECK-NEXT: sub.rm.f32.bf16 %r4, %rs1, %r3;
+; CHECK-NEXT: sub.rp.f32.bf16 %r5, %rs1, %r4;
+; CHECK-NEXT: sub.rn.sat.f32.bf16 %r6, %rs1, %r5;
+; CHECK-NEXT: sub.rz.sat.f32.bf16 %r7, %rs1, %r6;
+; CHECK-NEXT: sub.rm.sat.f32.bf16 %r8, %rs1, %r7;
+; CHECK-NEXT: sub.rp.sat.f32.bf16 %r9, %rs1, %r8;
+; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
 ; CHECK-NEXT: ret;
-  %r1 = call float @llvm.nvvm.sub.mixed.f32.bf16(bfloat %a, float %b)
-  %r2 = call float @llvm.nvvm.sub.mixed.rn.f32.bf16(bfloat %a, float %r1)
-  %r3 = call float @llvm.nvvm.sub.mixed.rz.f32.bf16(bfloat %a, float %r2)
-  %r4 = call float @llvm.nvvm.sub.mixed.rm.f32.bf16(bfloat %a, float %r3)
-  %r5 = call float @llvm.nvvm.sub.mixed.rp.f32.bf16(bfloat %a, float %r4)
+  %r0 = fpext bfloat %a to float
+
+  %r1 = call float @llvm.nvvm.sub.rn.f(float %r0, float %b)
+  %r2 = call float @llvm.nvvm.sub.rz.f(float %r0, float %r1)
+  %r3 = call float @llvm.nvvm.sub.rm.f(float %r0, float %r2)
+  %r4 = call float @llvm.nvvm.sub.rp.f(float %r0, float %r3)
 
   ; SAT
-  %r6 = call float @llvm.nvvm.sub.mixed.sat.f32.bf16(bfloat %a, float %r5)
-  %r7 = call float @llvm.nvvm.sub.mixed.rn.sat.f32.bf16(bfloat %a, float %r6)
-  %r8 = call float @llvm.nvvm.sub.mixed.rz.sat.f32.bf16(bfloat %a, float %r7)
-  %r9 = call float @llvm.nvvm.sub.mixed.rm.sat.f32.bf16(bfloat %a, float %r8)
-  %r10 = call float @llvm.nvvm.sub.mixed.rp.sat.f32.bf16(bfloat %a, float %r9)
+  %r5 = call float @llvm.nvvm.sub.rn.sat.f(float %r0, float %r4)
+  %r6 = call float @llvm.nvvm.sub.rz.sat.f(float %r0, float %r5)
+  %r7 = call float @llvm.nvvm.sub.rm.sat.f(float %r0, float %r6)
+  %r8 = call float @llvm.nvvm.sub.rp.sat.f(float %r0, float %r7)
 
-  ret float %r10
+  ret float %r8
 }
 
 ; FMA
 
@@ -160,33 +150,36 @@
 define float @test_fma_f32_f16(half %a, half %b, float %c) {
 ; CHECK-LABEL: test_fma_f32_f16(
 ; CHECK: {
 ; CHECK-NEXT: .reg .b16 %rs<3>;
 ; CHECK-NEXT: .reg .b32 %r<10>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
 ; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_f16_param_0];
 ; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_f16_param_1];
 ; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_f16_param_2];
 ; CHECK-NEXT: fma.rn.f32.f16 %r2, %rs1, %rs2, %r1;
 ; CHECK-NEXT: fma.rz.f32.f16 %r3, %rs1, %rs2, %r2;
 ; CHECK-NEXT: fma.rm.f32.f16 %r4, %rs1, %rs2, %r3;
 ; CHECK-NEXT: fma.rp.f32.f16 %r5, %rs1, %rs2, %r4;
 ; CHECK-NEXT: fma.rn.sat.f32.f16 %r6, %rs1, %rs2, %r5;
 ; CHECK-NEXT: fma.rz.sat.f32.f16 %r7, %rs1, %rs2, %r6;
 ; CHECK-NEXT: fma.rm.sat.f32.f16 %r8, %rs1, %rs2, %r7;
 ; CHECK-NEXT: fma.rp.sat.f32.f16 %r9, %rs1, %rs2, %r8;
 ; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
 ; CHECK-NEXT: ret;
-  %r1 = call float @llvm.nvvm.fma.mixed.rn.f32.f16(half %a, half %b, float %c)
-  %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.f16(half %a, half %b, float %r1)
-  %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.f16(half %a, half %b, float %r2)
-  %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.f16(half %a, half %b, float %r3)
+  %r0 = fpext half %a to float
+  %r1 = fpext half %b to float
+
+  %r2 = call float @llvm.nvvm.fma.rn.f(float %r0, float %r1, float %c)
+  %r3 = call float @llvm.nvvm.fma.rz.f(float %r0, float %r1, float %r2)
+  %r4 = call float @llvm.nvvm.fma.rm.f(float %r0, float %r1, float %r3)
+  %r5 = call float @llvm.nvvm.fma.rp.f(float %r0, float %r1, float %r4)
 
   ; SAT
-  %r5 = call float @llvm.nvvm.fma.mixed.rn.sat.f32.f16(half %a, half %b, float %r4)
-  %r6 = call float @llvm.nvvm.fma.mixed.rz.sat.f32.f16(half %a, half %b, float %r5)
-  %r7 = call float @llvm.nvvm.fma.mixed.rm.sat.f32.f16(half %a, half %b, float %r6)
-  %r8 = call float @llvm.nvvm.fma.mixed.rp.sat.f32.f16(half %a, half %b, float %r7)
+  %r6 = call float @llvm.nvvm.fma.rn.sat.f(float %r0, float %r1, float %r5)
+  %r7 = call float @llvm.nvvm.fma.rz.sat.f(float %r0, float %r1, float %r6)
+  %r8 = call float @llvm.nvvm.fma.rm.sat.f(float %r0, float %r1, float %r7)
+  %r9 = call float @llvm.nvvm.fma.rp.sat.f(float %r0, float %r1, float %r8)
 
-  ret float %r8
+  ret float %r9
 }
 
@@ -194,33 +186,36 @@
 define float @test_fma_f32_bf16(bfloat %a, bfloat %b, float %c) {
 ; CHECK-LABEL: test_fma_f32_bf16(
 ; CHECK: {
 ; CHECK-NEXT: .reg .b16 %rs<3>;
 ; CHECK-NEXT: .reg .b32 %r<10>;
 ; CHECK-EMPTY:
 ; CHECK-NEXT: // %bb.0:
 ; CHECK-NEXT: ld.param.b16 %rs1, [test_fma_f32_bf16_param_0];
 ; CHECK-NEXT: ld.param.b16 %rs2, [test_fma_f32_bf16_param_1];
 ; CHECK-NEXT: ld.param.b32 %r1, [test_fma_f32_bf16_param_2];
 ; CHECK-NEXT: fma.rn.f32.bf16 %r2, %rs1, %rs2, %r1;
 ; CHECK-NEXT: fma.rz.f32.bf16 %r3, %rs1, %rs2, %r2;
 ; CHECK-NEXT: fma.rm.f32.bf16 %r4, %rs1, %rs2, %r3;
 ; CHECK-NEXT: fma.rp.f32.bf16 %r5, %rs1, %rs2, %r4;
 ; CHECK-NEXT: fma.rn.sat.f32.bf16 %r6, %rs1, %rs2, %r5;
 ; CHECK-NEXT: fma.rz.sat.f32.bf16 %r7, %rs1, %rs2, %r6;
 ; CHECK-NEXT: fma.rm.sat.f32.bf16 %r8, %rs1, %rs2, %r7;
 ; CHECK-NEXT: fma.rp.sat.f32.bf16 %r9, %rs1, %rs2, %r8;
 ; CHECK-NEXT: st.param.b32 [func_retval0], %r9;
 ; CHECK-NEXT: ret;
-  %r1 = call float @llvm.nvvm.fma.mixed.rn.f32.bf16(bfloat %a, bfloat %b, float %c)
-  %r2 = call float @llvm.nvvm.fma.mixed.rz.f32.bf16(bfloat %a, bfloat %b, float %r1)
-  %r3 = call float @llvm.nvvm.fma.mixed.rm.f32.bf16(bfloat %a, bfloat %b, float %r2)
-  %r4 = call float @llvm.nvvm.fma.mixed.rp.f32.bf16(bfloat %a, bfloat %b, float %r3)
+  %r0 = fpext bfloat %a to float
+  %r1 = fpext bfloat %b to float
+
+  %r2 = call float @llvm.nvvm.fma.rn.f(float %r0, float %r1, float %c)
+  %r3 = call float @llvm.nvvm.fma.rz.f(float %r0, float %r1, float %r2)
+  %r4 = call float @llvm.nvvm.fma.rm.f(float %r0, float %r1, float %r3)
+  %r5 = call float @llvm.nvvm.fma.rp.f(float %r0, float %r1, float %r4)
 
   ; SAT
-  %r5 = call float @llvm.nvvm.fma.mixed.rn.sat.f32.bf16(bfloat %a, bfloat %b, float %r4)
-  %r6 = call float @llvm.nvvm.fma.mixed.rz.sat.f32.bf16(bfloat %a, bfloat %b, float %r5)
-  %r7 = call float @llvm.nvvm.fma.mixed.rm.sat.f32.bf16(bfloat %a, bfloat %b, float %r6)
-  %r8 = call float @llvm.nvvm.fma.mixed.rp.sat.f32.bf16(bfloat %a, bfloat %b, float %r7)
+  %r6 = call float @llvm.nvvm.fma.rn.sat.f(float %r0, float %r1, float %r5)
+  %r7 = call float @llvm.nvvm.fma.rz.sat.f(float %r0, float %r1, float %r6)
+  %r8 = call float @llvm.nvvm.fma.rm.sat.f(float %r0, float %r1, float %r7)
+  %r9 = call float @llvm.nvvm.fma.rp.sat.f(float %r0, float %r1, float %r8)
 
-  ret float %r8
+  ret float %r9
 }

_______________________________________________
cfe-commits mailing list
[email protected]
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits
