Author: Zhengxing li Date: 2024-11-26T10:45:31-08:00 New Revision: 5fd4f32f985f83414d82a1c2c55741e363693352
URL: https://github.com/llvm/llvm-project/commit/5fd4f32f985f83414d82a1c2c55741e363693352 DIFF: https://github.com/llvm/llvm-project/commit/5fd4f32f985f83414d82a1c2c55741e363693352.diff LOG: [HLSL] Implement SV_GroupID semantic (#115911) Support SV_GroupID attribute. Translate it into dx.group.id in clang codeGen. Fixes: #70120 Added: clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl Modified: clang/include/clang/Basic/Attr.td clang/include/clang/Basic/AttrDocs.td clang/include/clang/Sema/SemaHLSL.h clang/lib/CodeGen/CGHLSLRuntime.cpp clang/lib/Parse/ParseHLSL.cpp clang/lib/Sema/SemaDeclAttr.cpp clang/lib/Sema/SemaHLSL.cpp clang/test/SemaHLSL/Semantics/entry_parameter.hlsl clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl Removed: ################################################################################ diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 6db36a015acfd7..b055cbd769bb50 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4621,6 +4621,13 @@ def HLSLNumThreads: InheritableAttr { let Documentation = [NumThreadsDocs]; } +def HLSLSV_GroupID: HLSLAnnotationAttr { + let Spellings = [HLSLAnnotation<"SV_GroupID">]; + let Subjects = SubjectList<[ParmVar, Field]>; + let LangOpts = [HLSL]; + let Documentation = [HLSLSV_GroupIDDocs]; +} + def HLSLSV_GroupIndex: HLSLAnnotationAttr { let Spellings = [HLSLAnnotation<"SV_GroupIndex">]; let Subjects = SubjectList<[ParmVar, GlobalVar]>; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index cbbfedeec46cee..aafd4449e47004 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -7934,6 +7934,16 @@ randomized. 
}]; } +def HLSLSV_GroupIDDocs : Documentation { + let Category = DocCatFunction; + let Content = [{ +The ``SV_GroupID`` semantic, when applied to an input parameter, specifies which +thread group a shader is executing in. This attribute is only supported in compute shaders. + +The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupid + }]; +} + def HLSLSV_GroupIndexDocs : Documentation { let Category = DocCatFunction; let Content = [{ diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 06c541dec08cc8..ee685d95c96154 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase { void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL); + void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL); void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL); void handleShaderAttr(Decl *D, const ParsedAttr &AL); void handleResourceBindingAttr(Decl *D, const ParsedAttr &AL); @@ -136,6 +137,9 @@ class SemaHLSL : public SemaBase { bool CheckCompatibleParameterABI(FunctionDecl *New, FunctionDecl *Old); + // Diagnose whether the input ID is uint/uint2/uint3 type. 
+ bool diagnoseInputIDType(QualType T, const ParsedAttr &AL); + ExprResult ActOnOutParamExpr(ParmVarDecl *Param, Expr *Arg); QualType getInoutParameterType(QualType Ty); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 7ba0d615018181..2c293523fca8ca 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -389,6 +389,10 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, CGM.getIntrinsic(getThreadIdIntrinsic()); return buildVectorInput(B, ThreadIDIntrinsic, Ty); } + if (D.hasAttr<HLSLSV_GroupIDAttr>()) { + llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id); + return buildVectorInput(B, GroupIDIntrinsic, Ty); + } assert(false && "Unhandled parameter attribute"); return nullptr; } diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp index 46a37e94353533..4de342b63ed802 100644 --- a/clang/lib/Parse/ParseHLSL.cpp +++ b/clang/lib/Parse/ParseHLSL.cpp @@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs, case ParsedAttr::UnknownAttribute: Diag(Loc, diag::err_unknown_hlsl_semantic) << II; return; + case ParsedAttr::AT_HLSLSV_GroupID: case ParsedAttr::AT_HLSLSV_GroupIndex: case ParsedAttr::AT_HLSLSV_DispatchThreadID: break; diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 146d9c86e0715a..53cc8cb6afd7dc 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -7103,6 +7103,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HLSLWaveSize: S.HLSL().handleWaveSizeAttr(D, AL); break; + case ParsedAttr::AT_HLSLSV_GroupID: + S.HLSL().handleSV_GroupIDAttr(D, AL); + break; case ParsedAttr::AT_HLSLSV_GroupIndex: handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 8109c3a2cc0f1b..8b2f24a8e4be0a 100644 --- 
a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation( switch (AnnotationAttr->getKind()) { case attr::HLSLSV_DispatchThreadID: case attr::HLSLSV_GroupIndex: + case attr::HLSLSV_GroupID: if (ST == llvm::Triple::Compute) return; DiagnoseAttrStageMismatch(AnnotationAttr, ST, {llvm::Triple::Compute}); @@ -764,26 +765,36 @@ void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) { D->addAttr(NewAttr); } -static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) { - if (!T->hasUnsignedIntegerRepresentation()) +bool SemaHLSL::diagnoseInputIDType(QualType T, const ParsedAttr &AL) { + const auto *VT = T->getAs<VectorType>(); + + if (!T->hasUnsignedIntegerRepresentation() || + (VT && VT->getNumElements() > 3)) { + Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) + << AL << "uint/uint2/uint3"; return false; - if (const auto *VT = T->getAs<VectorType>()) - return VT->getNumElements() <= 3; + } + return true; } void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { auto *VD = cast<ValueDecl>(D); - if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) { - Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) - << AL << "uint/uint2/uint3"; + if (!diagnoseInputIDType(VD->getType(), AL)) return; - } D->addAttr(::new (getASTContext()) HLSLSV_DispatchThreadIDAttr(getASTContext(), AL)); } +void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) { + auto *VD = cast<ValueDecl>(D); + if (!diagnoseInputIDType(VD->getType(), AL)) + return; + + D->addAttr(::new (getASTContext()) HLSLSV_GroupIDAttr(getASTContext(), AL)); +} + void SemaHLSL::handlePackOffsetAttr(Decl *D, const ParsedAttr &AL) { if (!isa<VarDecl>(D) || !isa<HLSLBufferDecl>(D->getDeclContext())) { Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_ast_node) diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl new file mode 100644 index 
00000000000000..5e09f0fe06d4e6 --- /dev/null +++ b/clang/test/CodeGenHLSL/semantics/SV_GroupID.hlsl @@ -0,0 +1,32 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +// Make sure SV_GroupID translated into dx.group.id. + +// CHECK: define void @foo() +// CHECK: %[[#ID:]] = call i32 @llvm.dx.group.id(i32 0) +// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]]) +[shader("compute")] +[numthreads(8,8,1)] +void foo(uint Idx : SV_GroupID) {} + +// CHECK: define void @bar() +// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) +[shader("compute")] +[numthreads(8,8,1)] +void bar(uint2 Idx : SV_GroupID) {} + +// CHECK: define void @test() +// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.group.id(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.group.id(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK: %[[#ID_Z:]] = call i32 @llvm.dx.group.id(i32 2) +// CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2 +// CHECK: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]]) +[shader("compute")] +[numthreads(8,8,1)] +void test(uint3 Idx : SV_GroupID) {} diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl index 8484259f84692b..13c07038d2e4a4 100644 --- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl @@ -2,12 +2,15 @@ // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header 
-verify -o - %s [numthreads(8,8,1)] -// expected-error@+2 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+1 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} -void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID) { -// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint)' +// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}} +void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)' // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int' // CHECK-NEXT: HLSLSV_GroupIndexAttr // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint' // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint' +// CHECK-NEXT: HLSLSV_GroupIDAttr } diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl index bc3cf8bc51daf4..4e1f88aa2294b5 100644 --- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl @@ -27,3 +27,25 @@ struct ST2 { static uint X : SV_DispatchThreadID; uint s : SV_DispatchThreadID; }; + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain_GID(float ID : SV_GroupID) { +} + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void 
CSMain2_GID(ST GID : SV_GroupID) { + +} + +void foo_GID() { +// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}} + uint GIS : SV_GroupID; +} + +struct ST2_GID { +// expected-warning@+1 {{'SV_GroupID' attribute only applies to parameters and non-static data members}} + static uint GID : SV_GroupID; + uint s_gid : SV_GroupID; +}; diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl index 8e79fc4d85ec91..10a5e5dabac87b 100644 --- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl @@ -24,3 +24,28 @@ void CSMain3(uint3 : SV_DispatchThreadID) { // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:20 'uint3' // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr } + +[numthreads(8,8,1)] +void CSMain_GID(uint ID : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GID 'void (uint)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:22 ID 'uint' +// CHECK-NEXT: HLSLSV_GroupIDAttr +} +[numthreads(8,8,1)] +void CSMain1_GID(uint2 ID : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GID 'void (uint2)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint2' +// CHECK-NEXT: HLSLSV_GroupIDAttr +} +[numthreads(8,8,1)] +void CSMain2_GID(uint3 ID : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 ID 'uint3' +// CHECK-NEXT: HLSLSV_GroupIDAttr +} +[numthreads(8,8,1)] +void CSMain3_GID(uint3 : SV_GroupID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3' +// CHECK-NEXT: HLSLSV_GroupIDAttr +} _______________________________________________ 
cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits