llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Sumit Agarwal (sumitsays) <details> <summary>Changes</summary> Resolves #<!-- -->99221 Key points: For the SPIR-V backend, dot2add is decomposed into a `dot` followed by an `add`. - [x] Implement dot2add clang builtin, - [x] Link dot2add clang builtin with hlsl_intrinsics.h - [x] Add sema checks for dot2add to CheckHLSLBuiltinFunctionCall in SemaHLSL.cpp - [x] Add codegen for dot2add to EmitHLSLBuiltinExpr in CGBuiltin.cpp - [x] Add codegen tests to clang/test/CodeGenHLSL/builtins/dot2add.hlsl - [x] Add sema tests to clang/test/SemaHLSL/BuiltIns/dot2add-errors.hlsl - [x] Create the int_dx_dot2add intrinsic in IntrinsicsDirectX.td - [x] Create the DXILOpMapping of int_dx_dot2add to 162 in DXIL.td - [x] Create the dot2add.ll and dot2add_errors.ll tests in llvm/test/CodeGen/DirectX/ - [ ] ~~Create the int_spv_dot2add intrinsic in IntrinsicsSPIRV.td~~ --- Not needed - [ ] ~~In SPIRVInstructionSelector.cpp create the dot2add lowering and map it to int_spv_dot2add in SPIRVInstructionSelector::selectIntrinsic.~~ --- Not needed - [ ] ~~Create SPIR-V backend test case in llvm/test/CodeGen/SPIRV/hlsl-intrinsics/dot2add.ll~~ --- Not needed --- Full diff: https://github.com/llvm/llvm-project/pull/131237.diff 11 Files Affected: - (modified) clang/include/clang/Basic/Builtins.td (+6) - (modified) clang/lib/CodeGen/CGBuiltin.cpp (+15) - (modified) clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h (+8) - (modified) clang/lib/Headers/hlsl/hlsl_intrinsics.h (+12) - (modified) clang/lib/Sema/SemaHLSL.cpp (+50-9) - (added) clang/test/CodeGenHLSL/builtins/dot2add.hlsl (+17) - (added) clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl (+11) - (modified) llvm/include/llvm/IR/IntrinsicsDirectX.td (+4) - (modified) llvm/lib/Target/DirectX/DXIL.td (+11) - (modified) llvm/lib/Target/DirectX/DXILOpLowering.cpp (+47-2) - (added) llvm/test/CodeGen/DirectX/dot2add.ll (+8) ``````````diff diff --git a/clang/include/clang/Basic/Builtins.td 
b/clang/include/clang/Basic/Builtins.td index 72a5e495c4059..76ab463ca0ed6 100644 --- a/clang/include/clang/Basic/Builtins.td +++ b/clang/include/clang/Basic/Builtins.td @@ -4891,6 +4891,12 @@ def HLSLDotProduct : LangBuiltin<"HLSL_LANG"> { let Prototype = "void(...)"; } +def HLSLDot2Add : LangBuiltin<"HLSL_LANG"> { + let Spellings = ["__builtin_hlsl_dot2add"]; + let Attributes = [NoThrow, Const, CustomTypeChecking]; + let Prototype = "void(...)"; +} + def HLSLDot4AddI8Packed : LangBuiltin<"HLSL_LANG"> { let Spellings = ["__builtin_hlsl_dot4add_i8packed"]; let Attributes = [NoThrow, Const]; diff --git a/clang/lib/CodeGen/CGBuiltin.cpp b/clang/lib/CodeGen/CGBuiltin.cpp index c126f88b9e3a5..b3d9db5be7d8d 100644 --- a/clang/lib/CodeGen/CGBuiltin.cpp +++ b/clang/lib/CodeGen/CGBuiltin.cpp @@ -19681,6 +19681,21 @@ Value *CodeGenFunction::EmitHLSLBuiltinExpr(unsigned BuiltinID, getDotProductIntrinsic(CGM.getHLSLRuntime(), VecTy0->getElementType()), ArrayRef<Value *>{Op0, Op1}, nullptr, "hlsl.dot"); } + case Builtin::BI__builtin_hlsl_dot2add: { + llvm::Triple::ArchType Arch = CGM.getTarget().getTriple().getArch(); + if (Arch != llvm::Triple::dxil) { + llvm_unreachable("Intrinsic dot2add can be executed as a builtin only on dxil"); + } + Value *A = EmitScalarExpr(E->getArg(0)); + Value *B = EmitScalarExpr(E->getArg(1)); + Value *C = EmitScalarExpr(E->getArg(2)); + + //llvm::Intrinsic::dx_##IntrinsicPostfix + Intrinsic::ID ID = llvm ::Intrinsic::dx_dot2add; + return Builder.CreateIntrinsic( + /*ReturnType=*/C->getType(), ID, ArrayRef<Value *>{A, B, C}, nullptr, + "hlsl.dot2add"); + } case Builtin::BI__builtin_hlsl_dot4add_i8packed: { Value *A = EmitScalarExpr(E->getArg(0)); Value *B = EmitScalarExpr(E->getArg(1)); diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h index 5f7c047dbf340..46653d7b295b2 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsic_helpers.h @@ 
-45,6 +45,14 @@ distance_vec_impl(vector<T, N> X, vector<T, N> Y) { return length_vec_impl(X - Y); } +constexpr float dot2add_impl(half2 a, half2 b, float c) { +#if defined(__DIRECTX__) + return __builtin_hlsl_dot2add(a, b, c); +#else + return dot(a, b) + c; +#endif +} + template <typename T> constexpr T reflect_impl(T I, T N) { return I - 2 * N * I * N; } diff --git a/clang/lib/Headers/hlsl/hlsl_intrinsics.h b/clang/lib/Headers/hlsl/hlsl_intrinsics.h index 5459cbeb34fd0..b1c1335ce3328 100644 --- a/clang/lib/Headers/hlsl/hlsl_intrinsics.h +++ b/clang/lib/Headers/hlsl/hlsl_intrinsics.h @@ -117,6 +117,18 @@ const inline float distance(__detail::HLSL_FIXED_VECTOR<float, N> X, return __detail::distance_vec_impl(X, Y); } +//===----------------------------------------------------------------------===// +// dot2add builtins +//===----------------------------------------------------------------------===// + +/// \fn float dot2add(half2 a, half2 b, float c) +/// \brief Dot product of 2 vector of type half and add a float scalar value. 
+ +_HLSL_AVAILABILITY(shadermodel, 6.4) +const inline float dot2add(half2 a, half2 b, float c) { + return __detail::dot2add_impl(a, b, c); +} + //===----------------------------------------------------------------------===// // fmod builtins //===----------------------------------------------------------------------===// diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 36de110e75e8a..399371c4ae2f6 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1989,7 +1989,7 @@ void SemaHLSL::diagnoseAvailabilityViolations(TranslationUnitDecl *TU) { } // Helper function for CheckHLSLBuiltinFunctionCall -static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { +static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall, unsigned NumArgs) { assert(TheCall->getNumArgs() > 1); ExprResult A = TheCall->getArg(0); @@ -1999,7 +1999,7 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { SourceLocation BuiltinLoc = TheCall->getBeginLoc(); bool AllBArgAreVectors = true; - for (unsigned i = 1; i < TheCall->getNumArgs(); ++i) { + for (unsigned i = 1; i < NumArgs; ++i) { ExprResult B = TheCall->getArg(i); QualType ArgTyB = B.get()->getType(); auto *VecTyB = ArgTyB->getAs<VectorType>(); @@ -2049,6 +2049,10 @@ static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { return false; } +static bool CheckVectorElementCallArgs(Sema *S, CallExpr *TheCall) { + return CheckVectorElementCallArgs(S, TheCall, TheCall->getNumArgs()); +} + static bool CheckAllArgsHaveSameType(Sema *S, CallExpr *TheCall) { assert(TheCall->getNumArgs() > 1); QualType ArgTy0 = TheCall->getArg(0)->getType(); @@ -2091,10 +2095,10 @@ static bool CheckArgTypeIsCorrect( return false; } -static bool CheckAllArgTypesAreCorrect( - Sema *S, CallExpr *TheCall, QualType ExpectedType, +static bool CheckArgTypesAreCorrect( + Sema *S, CallExpr *TheCall, unsigned NumArgs, QualType ExpectedType, llvm::function_ref<bool(clang::QualType 
PassedType)> Check) { - for (unsigned i = 0; i < TheCall->getNumArgs(); ++i) { + for (unsigned i = 0; i < NumArgs; ++i) { Expr *Arg = TheCall->getArg(i); if (CheckArgTypeIsCorrect(S, Arg, ExpectedType, Check)) { return true; @@ -2103,6 +2107,13 @@ static bool CheckAllArgTypesAreCorrect( return false; } +static bool CheckAllArgTypesAreCorrect( + Sema *S, CallExpr *TheCall, QualType ExpectedType, + llvm::function_ref<bool(clang::QualType PassedType)> Check) { + return CheckArgTypesAreCorrect(S, TheCall, TheCall->getNumArgs(), + ExpectedType, Check); +} + static bool CheckAllArgsHaveFloatRepresentation(Sema *S, CallExpr *TheCall) { auto checkAllFloatTypes = [](clang::QualType PassedType) -> bool { return !PassedType->hasFloatingRepresentation(); @@ -2146,15 +2157,17 @@ static bool CheckModifiableLValue(Sema *S, CallExpr *TheCall, return true; } -static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall) { +static bool CheckNoDoubleVectors(Sema *S, CallExpr *TheCall, + unsigned NumArgs, QualType ExpectedType) { auto checkDoubleVector = [](clang::QualType PassedType) -> bool { if (const auto *VecTy = PassedType->getAs<VectorType>()) return VecTy->getElementType()->isDoubleType(); return false; }; - return CheckAllArgTypesAreCorrect(S, TheCall, S->Context.FloatTy, - checkDoubleVector); + return CheckArgTypesAreCorrect(S, TheCall, NumArgs, + ExpectedType, checkDoubleVector); } + static bool CheckFloatingOrIntRepresentation(Sema *S, CallExpr *TheCall) { auto checkAllSignedTypes = [](clang::QualType PassedType) -> bool { return !PassedType->hasIntegerRepresentation() && @@ -2468,8 +2481,36 @@ bool SemaHLSL::CheckBuiltinFunctionCall(unsigned BuiltinID, CallExpr *TheCall) { return true; if (SemaRef.BuiltinVectorToScalarMath(TheCall)) return true; - if (CheckNoDoubleVectors(&SemaRef, TheCall)) + if (CheckNoDoubleVectors(&SemaRef, TheCall, + TheCall->getNumArgs(), SemaRef.Context.FloatTy)) + return true; + break; + } + case Builtin::BI__builtin_hlsl_dot2add: { + // Check 
number of arguments should be 3 + if (SemaRef.checkArgCount(TheCall, 3)) + return true; + + // Check first two arguments are vector of length 2 with half data type + auto checkHalfVectorOfSize2 = [](clang::QualType PassedType) -> bool { + if (const auto *VecTy = PassedType->getAs<VectorType>()) + return !(VecTy->getNumElements() == 2 && + VecTy->getElementType()->isHalfType()); + return true; + }; + if(CheckArgTypeIsCorrect(&SemaRef, TheCall->getArg(0), + SemaRef.getASTContext().HalfTy, + checkHalfVectorOfSize2)) + return true; + if(CheckArgTypeIsCorrect(&SemaRef, TheCall->getArg(1), + SemaRef.getASTContext().HalfTy, + checkHalfVectorOfSize2)) + return true; + + // Check third argument is a float + if (CheckArgTypeMatches(&SemaRef, TheCall->getArg(2), SemaRef.getASTContext().FloatTy)) return true; + TheCall->setType(TheCall->getArg(2)->getType()); break; } case Builtin::BI__builtin_hlsl_elementwise_firstbithigh: diff --git a/clang/test/CodeGenHLSL/builtins/dot2add.hlsl b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl new file mode 100644 index 0000000000000..ce325327a01b5 --- /dev/null +++ b/clang/test/CodeGenHLSL/builtins/dot2add.hlsl @@ -0,0 +1,17 @@ +// RUN: %clang_cc1 -finclude-default-header -triple \ +// RUN: dxil-pc-shadermodel6.3-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-DXIL +// RUN: %clang_cc1 -finclude-default-header -triple \ +// RUN: spirv-pc-vulkan-compute %s -emit-llvm -disable-llvm-passes -o - | \ +// RUN: FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV + +// Test basic lowering to runtime function call. 
+ +float test(half2 p1, half2 p2, float p3) { + // CHECK-SPIRV: %[[MUL:.*]] = call {{.*}} float @llvm.spv.fdot.v2f32(<2 x float> %1, <2 x float> %2) + // CHECK-SPIRV: %[[C:.*]] = load float, ptr %c.addr, align 4 + // CHECK-SPIRV: %[[RES:.*]] = fadd {{.*}} float %[[MUL]], %[[C]] + // CHECK-DXIL: %[[RES:.*]] = call {{.*}} float @llvm.dx.dot2add.v2f32(<2 x float> %0, <2 x float> %1, float %2) + // CHECK: ret float %[[RES]] + return dot2add(p1, p2, p3); +} \ No newline at end of file diff --git a/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl b/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl new file mode 100644 index 0000000000000..61282a319dafd --- /dev/null +++ b/clang/test/SemaHLSL/BuiltIns/Dot2Add-errors.hlsl @@ -0,0 +1,11 @@ +// RUN: %clang_cc1 -finclude-default-header -triple dxil-pc-shadermodel6.6-library %s -emit-llvm-only -disable-llvm-passes -verify + +bool test_too_few_arg() { + return __builtin_hlsl_dot2add(); + // expected-error@-1 {{too few arguments to function call, expected 3, have 0}} +} + +bool test_too_many_arg(half2 p1, half2 p2, float p3) { + return __builtin_hlsl_dot2add(p1, p2, p3, p1); + // expected-error@-1 {{too many arguments to function call, expected 3, have 4}} +} diff --git a/llvm/include/llvm/IR/IntrinsicsDirectX.td b/llvm/include/llvm/IR/IntrinsicsDirectX.td index ead7286f4311c..775d325feeb14 100644 --- a/llvm/include/llvm/IR/IntrinsicsDirectX.td +++ b/llvm/include/llvm/IR/IntrinsicsDirectX.td @@ -100,6 +100,10 @@ def int_dx_udot : DefaultAttrsIntrinsic<[LLVMVectorElementType<0>], [llvm_anyint_ty, LLVMScalarOrSameVectorWidth<0, LLVMVectorElementType<0>>], [IntrNoMem, Commutative] >; +def int_dx_dot2add : + DefaultAttrsIntrinsic<[llvm_float_ty], + [llvm_anyfloat_ty, LLVMMatchType<0>, llvm_float_ty], + [IntrNoMem, Commutative]>; def int_dx_dot4add_i8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; def int_dx_dot4add_u8packed : DefaultAttrsIntrinsic<[llvm_i32_ty], [llvm_i32_ty, 
llvm_i32_ty, llvm_i32_ty], [IntrNoMem]>; diff --git a/llvm/lib/Target/DirectX/DXIL.td b/llvm/lib/Target/DirectX/DXIL.td index ebe1d876d58b1..193b592a525a0 100644 --- a/llvm/lib/Target/DirectX/DXIL.td +++ b/llvm/lib/Target/DirectX/DXIL.td @@ -1098,6 +1098,17 @@ def RawBufferStore : DXILOp<140, rawBufferStore> { let stages = [Stages<DXIL1_2, [all_stages]>]; } +def Dot2AddHalf : DXILOp<162, dot2AddHalf> { + let Doc = "dot product of 2 vectors of half having size = 2, returns " + "float"; + let intrinsics = [IntrinSelect<int_dx_dot2add>]; + let arguments = [FloatTy, HalfTy, HalfTy, HalfTy, HalfTy]; + let result = FloatTy; + let overloads = [Overloads<DXIL1_0, []>]; + let stages = [Stages<DXIL1_0, [all_stages]>]; + let attributes = [Attributes<DXIL1_0, [ReadNone]>]; +} + def Dot4AddI8Packed : DXILOp<163, dot4AddPacked> { let Doc = "signed dot product of 4 x i8 vectors packed into i32, with " "accumulate to i32"; diff --git a/llvm/lib/Target/DirectX/DXILOpLowering.cpp b/llvm/lib/Target/DirectX/DXILOpLowering.cpp index dff9f3e03079e..f7ed0c5071d75 100644 --- a/llvm/lib/Target/DirectX/DXILOpLowering.cpp +++ b/llvm/lib/Target/DirectX/DXILOpLowering.cpp @@ -54,10 +54,36 @@ static SmallVector<Value *> populateOperands(Value *Arg, IRBuilder<> &Builder) { return ExtractedElements; } +static SmallVector<Value *> argVectorFlatten(CallInst *Orig, + IRBuilder<> &Builder, + unsigned NumOperands) { + assert(NumOperands > 0); + Value *Arg0 = Orig->getOperand(0); + [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); + assert(VecArg0); + SmallVector<Value *> NewOperands = populateOperands(Arg0, Builder); + for (unsigned I = 1; I < NumOperands; ++I) { + Value *Arg = Orig->getOperand(I); + [[maybe_unused]] auto *VecArg = dyn_cast<FixedVectorType>(Arg->getType()); + assert(VecArg); + assert(VecArg0->getElementType() == VecArg->getElementType()); + assert(VecArg0->getNumElements() == VecArg->getNumElements()); + auto NextOperandList = populateOperands(Arg, Builder); 
+ NewOperands.append(NextOperandList.begin(), NextOperandList.end()); + } + return NewOperands; +} + static SmallVector<Value *> argVectorFlatten(CallInst *Orig, IRBuilder<> &Builder) { // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. - unsigned NumOperands = Orig->getNumOperands() - 1; + return argVectorFlatten(Orig, Builder, Orig->getNumOperands() - 1); +} +/* +static SmallVector<Value *> argVectorFlattenExcludeLastElement(CallInst *Orig, + IRBuilder<> &Builder) { + // Note: arg[NumOperands-1] is a pointer and is not needed by our flattening. + unsigned NumOperands = Orig->getNumOperands() - 2; assert(NumOperands > 0); Value *Arg0 = Orig->getOperand(0); [[maybe_unused]] auto *VecArg0 = dyn_cast<FixedVectorType>(Arg0->getType()); @@ -74,7 +100,7 @@ static SmallVector<Value *> argVectorFlatten(CallInst *Orig, } return NewOperands; } - +*/ namespace { class OpLowerer { Module &M; @@ -168,6 +194,25 @@ class OpLowerer { } } else if (IsVectorArgExpansion) { Args = argVectorFlatten(CI, OpBuilder.getIRB()); + } else if (F.getIntrinsicID() == Intrinsic::dx_dot2add) { + // arg[NumOperands-1] is a pointer and is not needed by our flattening. + // arg[NumOperands-2] also does not need to be flattened because it is a scalar. 
+ unsigned NumOperands = CI->getNumOperands() - 2; + Args.push_back(CI->getArgOperand(NumOperands)); + Args.append(argVectorFlatten(CI, OpBuilder.getIRB(), NumOperands)); + + /*unsigned NumOperands = CI->getNumOperands() - 1; + assert(NumOperands > 0); + Value *LastArg = CI->getOperand(NumOperands - 1); + + Args.push_back(LastArg); + + //dbgs() << "Value of LastArg" << LastArg->getName() << "\n"; + + + //Args = populateOperands(LastArg, OpBuilder.getIRB()); + Args.append(argVectorFlattenExcludeLastElement(CI, OpBuilder.getIRB())); + */ } else { Args.append(CI->arg_begin(), CI->arg_end()); } diff --git a/llvm/test/CodeGen/DirectX/dot2add.ll b/llvm/test/CodeGen/DirectX/dot2add.ll new file mode 100644 index 0000000000000..b1019c36b56e8 --- /dev/null +++ b/llvm/test/CodeGen/DirectX/dot2add.ll @@ -0,0 +1,8 @@ +; RUN: opt -S -dxil-op-lower -mtriple=dxil-pc-shadermodel6.3-compute %s | FileCheck %s + +define noundef float @dot2add_simple(<2 x half> noundef %a, <2 x half> noundef %b, float %c) { +entry: +; CHECK: call float @dx.op.dot2AddHalf(i32 162, float %c, half %0, half %1, half %2, half %3) + %ret = call float @llvm.dx.dot2add(<2 x half> %a, <2 x half> %b, float %c) + ret float %ret +} \ No newline at end of file `````````` </details> https://github.com/llvm/llvm-project/pull/131237 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits