https://github.com/lizhengxing updated https://github.com/llvm/llvm-project/pull/117781
>From 2941d87dbaf091aa443ad57ce55e98e7bab83d2b Mon Sep 17 00:00:00 2001 From: Zhengxing Li <zhengxin...@microsoft.com> Date: Wed, 13 Nov 2024 10:54:16 -0800 Subject: [PATCH 1/3] [HLSL] Implement SV_GroupThreadId semantic Support SV_GroupThreadId attribute. Translate it into dx.thread.id.in.group in clang codeGen. Fixes: #70122 --- clang/include/clang/Basic/Attr.td | 7 ++++ clang/include/clang/Basic/AttrDocs.td | 11 +++++++ clang/include/clang/Sema/SemaHLSL.h | 1 + clang/lib/CodeGen/CGHLSLRuntime.cpp | 5 +++ clang/lib/Parse/ParseHLSL.cpp | 1 + clang/lib/Sema/SemaDeclAttr.cpp | 3 ++ clang/lib/Sema/SemaHLSL.cpp | 10 ++++++ .../semantics/SV_GroupThreadID.hlsl | 32 +++++++++++++++++++ .../SemaHLSL/Semantics/entry_parameter.hlsl | 13 +++++--- .../Semantics/invalid_entry_parameter.hlsl | 22 +++++++++++++ .../Semantics/valid_entry_parameter.hlsl | 25 +++++++++++++++ 11 files changed, 125 insertions(+), 5 deletions(-) create mode 100644 clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 17fc36fbe2ac8c..90d2a2056fe1ba 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4651,6 +4651,13 @@ def HLSLNumThreads: InheritableAttr { let Documentation = [NumThreadsDocs]; } +def HLSLSV_GroupThreadID: HLSLAnnotationAttr { + let Spellings = [HLSLAnnotation<"SV_GroupThreadID">]; + let Subjects = SubjectList<[ParmVar, Field]>; + let LangOpts = [HLSL]; + let Documentation = [HLSLSV_GroupThreadIDDocs]; +} + def HLSLSV_GroupID: HLSLAnnotationAttr { let Spellings = [HLSLAnnotation<"SV_GroupID">]; let Subjects = SubjectList<[ParmVar, Field]>; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index 7a82b8fa320590..fdad4c9a3ea191 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -7941,6 +7941,17 @@ randomized. 
}]; } +def HLSLSV_GroupThreadIDDocs : Documentation { + let Category = DocCatFunction; + let Content = [{ +The ``SV_GroupThreadID`` semantic, when applied to an input parameter, specifies which +individual thread within a thread group is executing. This attribute is +only supported in compute shaders. + +The full documentation is available here: https://docs.microsoft.com/en-us/windows/win32/direct3dhlsl/sv-groupthreadid + }]; +} + def HLSLSV_GroupIDDocs : Documentation { let Category = DocCatFunction; let Content = [{ diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index ee685d95c96154..f4cd11f423a84a 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -119,6 +119,7 @@ class SemaHLSL : public SemaBase { void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL); + void handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL); void handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL); void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL); void handleShaderAttr(Decl *D, const ParsedAttr &AL); diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 2c293523fca8ca..19db7faddaeac0 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -389,6 +389,11 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, CGM.getIntrinsic(getThreadIdIntrinsic()); return buildVectorInput(B, ThreadIDIntrinsic, Ty); } + if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) { + llvm::Function *GroupThreadIDIntrinsic = + CGM.getIntrinsic(Intrinsic::dx_thread_id_in_group); + return buildVectorInput(B, GroupThreadIDIntrinsic, Ty); + } if (D.hasAttr<HLSLSV_GroupIDAttr>()) { llvm::Function *GroupIDIntrinsic = CGM.getIntrinsic(Intrinsic::dx_group_id); return buildVectorInput(B, 
GroupIDIntrinsic, Ty); diff --git
a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp index 4de342b63ed802..443bf2b9ec626a 100644 --- a/clang/lib/Parse/ParseHLSL.cpp +++ b/clang/lib/Parse/ParseHLSL.cpp @@ -280,6 +280,7 @@ void Parser::ParseHLSLAnnotations(ParsedAttributes &Attrs, case ParsedAttr::UnknownAttribute: Diag(Loc, diag::err_unknown_hlsl_semantic) << II; return; + case ParsedAttr::AT_HLSLSV_GroupThreadID: case ParsedAttr::AT_HLSLSV_GroupID: case ParsedAttr::AT_HLSLSV_GroupIndex: case ParsedAttr::AT_HLSLSV_DispatchThreadID: diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 4fd8ef6dbebf84..5d7ee097383771 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -7114,6 +7114,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HLSLWaveSize: S.HLSL().handleWaveSizeAttr(D, AL); break; + case ParsedAttr::AT_HLSLSV_GroupThreadID: + S.HLSL().handleSV_GroupThreadIDAttr(D, AL); + break; case ParsedAttr::AT_HLSLSV_GroupID: S.HLSL().handleSV_GroupIDAttr(D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 88db3e12541193..600c800029fd05 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -434,6 +434,7 @@ void SemaHLSL::CheckSemanticAnnotation( switch (AnnotationAttr->getKind()) { case attr::HLSLSV_DispatchThreadID: case attr::HLSLSV_GroupIndex: + case attr::HLSLSV_GroupThreadID: case attr::HLSLSV_GroupID: if (ST == llvm::Triple::Compute) return; @@ -787,6 +788,15 @@ void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { HLSLSV_DispatchThreadIDAttr(getASTContext(), AL)); } +void SemaHLSL::handleSV_GroupThreadIDAttr(Decl *D, const ParsedAttr &AL) { + auto *VD = cast<ValueDecl>(D); + if (!diagnoseInputIDType(VD->getType(), AL)) + return; + + D->addAttr(::new (getASTContext()) + HLSLSV_GroupThreadIDAttr(getASTContext(), AL)); +} + void SemaHLSL::handleSV_GroupIDAttr(Decl *D, const ParsedAttr &AL) { 
auto *VD = cast<ValueDecl>(D); if (!diagnoseInputIDType(VD->getType(), AL)) diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl new file mode 100644 index 00000000000000..3533331c6f091c --- /dev/null +++ b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl @@ -0,0 +1,32 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s + +// Make sure SV_GroupThreadID translated into dx.thread.id.in.group. + +// CHECK: define void @foo() +// CHECK: %[[#ID:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) +// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]]) +[shader("compute")] +[numthreads(8,8,1)] +void foo(uint Idx : SV_GroupThreadID) {} + +// CHECK: define void @bar() +// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) +[shader("compute")] +[numthreads(8,8,1)] +void bar(uint2 Idx : SV_GroupThreadID) {} + +// CHECK: define void @test() +// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK: %[[#ID_Z:]] = call i32 @llvm.dx.thread.id.in.group(i32 2) +// CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2 +// CHECK: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]]) +[shader("compute")] +[numthreads(8,8,1)] +void test(uint3 Idx : SV_GroupThreadID) {} diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl 
b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl index 13c07038d2e4a4..71d32cd13832e1 100644 --- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl @@ -2,15 +2,18 @@ // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s [numthreads(8,8,1)] -// expected-error@+3 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+2 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+1 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}} -void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID) { -// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint)' +// expected-error@+4 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+3 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+2 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}} +// expected-error@+1 {{attribute 'SV_GroupThreadID' is unsupported in 'mesh' shaders, requires compute}} +void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)' // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int' // CHECK-NEXT: HLSLSV_GroupIndexAttr // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:42 ID 'uint' // CHECK-NEXT: HLSLSV_DispatchThreadIDAttr // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:73 GID 'uint' // CHECK-NEXT: HLSLSV_GroupIDAttr +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:96 GThreadID 'uint' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr } diff --git 
a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl index 4e1f88aa2294b5..a24112c8e1bb8f 100644 --- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl @@ -49,3 +49,25 @@ struct ST2_GID { static uint GID : SV_GroupID; uint s_gid : SV_GroupID; }; + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain_GThreadID(float ID : SV_GroupThreadID) { +} + +[numthreads(8,8,1)] +// expected-error@+1 {{attribute 'SV_GroupThreadID' only applies to a field or parameter of type 'uint/uint2/uint3'}} +void CSMain2_GThreadID(ST GID : SV_GroupThreadID) { + +} + +void foo_GThreadID() { +// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}} + uint GThreadIS : SV_GroupThreadID; +} + +struct ST2_GThreadID { +// expected-warning@+1 {{'SV_GroupThreadID' attribute only applies to parameters and non-static data members}} + static uint GThreadID : SV_GroupThreadID; + uint s_gthreadid : SV_GroupThreadID; +}; diff --git a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl index 10a5e5dabac87b..6781f9241df240 100644 --- a/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/valid_entry_parameter.hlsl @@ -49,3 +49,28 @@ void CSMain3_GID(uint3 : SV_GroupID) { // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:24 'uint3' // CHECK-NEXT: HLSLSV_GroupIDAttr } + +[numthreads(8,8,1)] +void CSMain_GThreadID(uint ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain_GThreadID 'void (uint)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:28 ID 'uint' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void 
CSMain1_GThreadID(uint2 ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain1_GThreadID 'void (uint2)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint2' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void CSMain2_GThreadID(uint3 ID : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain2_GThreadID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 ID 'uint3' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} +[numthreads(8,8,1)] +void CSMain3_GThreadID(uint3 : SV_GroupThreadID) { +// CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain3_GThreadID 'void (uint3)' +// CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:30 'uint3' +// CHECK-NEXT: HLSLSV_GroupThreadIDAttr +} >From dc8d779f067e5b8d22e56036c4ba7320e297f339 Mon Sep 17 00:00:00 2001 From: Zhengxing Li <zhengxin...@microsoft.com> Date: Thu, 5 Dec 2024 10:54:54 -0800 Subject: [PATCH 2/3] Don't test the Group/Thread input IDs with mesh shader The SV_GroupIndex, SV_DispatchThreadID, SV_GroupID and SV_GroupThreadID semantics are actually legal in the mesh shader stage, so the test should not use a mesh shader to exercise them. This commit tests them with a vertex shader instead and moves the test into invalid_entry_parameter.hlsl, which is a better place for it.
--- clang/test/SemaHLSL/Semantics/entry_parameter.hlsl | 5 ----- .../test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl | 8 ++++++++ 2 files changed, 8 insertions(+), 5 deletions(-) diff --git a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl index 71d32cd13832e1..393d7300605c09 100644 --- a/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/entry_parameter.hlsl @@ -1,11 +1,6 @@ // RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-compute -hlsl-entry CSMain -x hlsl -finclude-default-header -ast-dump -o - %s | FileCheck %s -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.0-mesh -hlsl-entry CSMain -x hlsl -finclude-default-header -verify -o - %s [numthreads(8,8,1)] -// expected-error@+4 {{attribute 'SV_GroupIndex' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+3 {{attribute 'SV_DispatchThreadID' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+2 {{attribute 'SV_GroupID' is unsupported in 'mesh' shaders, requires compute}} -// expected-error@+1 {{attribute 'SV_GroupThreadID' is unsupported in 'mesh' shaders, requires compute}} void CSMain(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) { // CHECK: FunctionDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> line:[[@LINE-1]]:6 CSMain 'void (int, uint, uint, uint)' // CHECK-NEXT: ParmVarDecl 0x{{[0-9a-fA-F]+}} <{{.*}}> col:17 GI 'int' diff --git a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl index a24112c8e1bb8f..1bb4ee5182d621 100644 --- a/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl +++ b/clang/test/SemaHLSL/Semantics/invalid_entry_parameter.hlsl @@ -71,3 +71,11 @@ struct ST2_GThreadID { static uint GThreadID : SV_GroupThreadID; uint s_gthreadid : SV_GroupThreadID; }; + + +[shader("vertex")] +// expected-error@+4 {{attribute 'SV_GroupIndex' is 
unsupported in 'vertex' shaders, requires compute}} +// expected-error@+3 {{attribute 'SV_DispatchThreadID' is unsupported in 'vertex' shaders, requires compute}} +// expected-error@+2 {{attribute 'SV_GroupID' is unsupported in 'vertex' shaders, requires compute}} +// expected-error@+1 {{attribute 'SV_GroupThreadID' is unsupported in 'vertex' shaders, requires compute}} +void vs_main(int GI : SV_GroupIndex, uint ID : SV_DispatchThreadID, uint GID : SV_GroupID, uint GThreadID : SV_GroupThreadID) {} >From 28f823454873d4bc029f2ec57bed3a9707bbf1b2 Mon Sep 17 00:00:00 2001 From: Zhengxing Li <zhengxin...@microsoft.com> Date: Thu, 5 Dec 2024 15:06:01 -0800 Subject: [PATCH 3/3] [HLSL][SPIR-V] Add SV_GroupThreadID semantic support The HLSL SV_GroupThreadID semantic attribute is lowered into @llvm.spv.thread.id.in.group intrinsic in LLVM IR for SPIR-V target. In the SPIR-V backend, this is now correctly translated to a `LocalInvocationId` builtin variable. Fixes #70122 --- clang/lib/CodeGen/CGHLSLRuntime.cpp | 2 +- clang/lib/CodeGen/CGHLSLRuntime.h | 1 + .../semantics/SV_GroupThreadID.hlsl | 34 +++++---- llvm/include/llvm/IR/IntrinsicsSPIRV.td | 1 + .../Target/SPIRV/SPIRVInstructionSelector.cpp | 59 ++++++++++---- .../SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll | 76 +++++++++++++++++++ 6 files changed, 144 insertions(+), 29 deletions(-) create mode 100644 llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll diff --git a/clang/lib/CodeGen/CGHLSLRuntime.cpp b/clang/lib/CodeGen/CGHLSLRuntime.cpp index 19db7faddaeac0..fb15b1993e74ad 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.cpp +++ b/clang/lib/CodeGen/CGHLSLRuntime.cpp @@ -391,7 +391,7 @@ llvm::Value *CGHLSLRuntime::emitInputSemantic(IRBuilder<> &B, } if (D.hasAttr<HLSLSV_GroupThreadIDAttr>()) { llvm::Function *GroupThreadIDIntrinsic = - CGM.getIntrinsic(Intrinsic::dx_thread_id_in_group); + CGM.getIntrinsic(getGroupThreadIdIntrinsic()); return buildVectorInput(B, GroupThreadIDIntrinsic, Ty); } if 
(D.hasAttr<HLSLSV_GroupIDAttr>()) { diff --git a/clang/lib/CodeGen/CGHLSLRuntime.h b/clang/lib/CodeGen/CGHLSLRuntime.h index bb120c8b5e9e60..f9efb1bc996412 100644 --- a/clang/lib/CodeGen/CGHLSLRuntime.h +++ b/clang/lib/CodeGen/CGHLSLRuntime.h @@ -86,6 +86,7 @@ class CGHLSLRuntime { GENERATE_HLSL_INTRINSIC_FUNCTION(Step, step) GENERATE_HLSL_INTRINSIC_FUNCTION(Radians, radians) GENERATE_HLSL_INTRINSIC_FUNCTION(ThreadId, thread_id) + GENERATE_HLSL_INTRINSIC_FUNCTION(GroupThreadId, thread_id_in_group) GENERATE_HLSL_INTRINSIC_FUNCTION(FDot, fdot) GENERATE_HLSL_INTRINSIC_FUNCTION(SDot, sdot) GENERATE_HLSL_INTRINSIC_FUNCTION(UDot, udot) diff --git a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl index 3533331c6f091c..3d347b973f39c8 100644 --- a/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl +++ b/clang/test/CodeGenHLSL/semantics/SV_GroupThreadID.hlsl @@ -1,32 +1,36 @@ -// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.3-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-DXIL -DTARGET=dx +// RUN: %clang_cc1 -triple spirv-linux-vulkan-library -x hlsl -emit-llvm -finclude-default-header -disable-llvm-passes -o - %s | FileCheck %s --check-prefixes=CHECK,CHECK-SPIRV -DTARGET=spv -// Make sure SV_GroupThreadID translated into dx.thread.id.in.group. +// Make sure SV_GroupThreadID translated into dx.thread.id.in.group for directx target and spv.thread.id.in.group for spirv target. 
-// CHECK: define void @foo() -// CHECK: %[[#ID:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) -// CHECK: call void @{{.*}}foo{{.*}}(i32 %[[#ID]]) +// CHECK: define void @foo() +// CHECK: %[[#ID:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0) +// CHECK-DXIL: call void @{{.*}}foo{{.*}}(i32 %[[#ID]]) +// CHECK-SPIRV: call spir_func void @{{.*}}foo{{.*}}(i32 %[[#ID]]) [shader("compute")] [numthreads(8,8,1)] void foo(uint Idx : SV_GroupThreadID) {} -// CHECK: define void @bar() -// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) -// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0 -// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1) -// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 -// CHECK: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) +// CHECK: define void @bar() +// CHECK: %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0) +// CHECK: %[[#ID_X_:]] = insertelement <2 x i32> poison, i32 %[[#ID_X]], i64 0 +// CHECK: %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1) +// CHECK: %[[#ID_XY:]] = insertelement <2 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 +// CHECK-DXIL: call void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) +// CHECK-SPIRV: call spir_func void @{{.*}}bar{{.*}}(<2 x i32> %[[#ID_XY]]) [shader("compute")] [numthreads(8,8,1)] void bar(uint2 Idx : SV_GroupThreadID) {} // CHECK: define void @test() -// CHECK: %[[#ID_X:]] = call i32 @llvm.dx.thread.id.in.group(i32 0) +// CHECK: %[[#ID_X:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 0) // CHECK: %[[#ID_X_:]] = insertelement <3 x i32> poison, i32 %[[#ID_X]], i64 0 -// CHECK: %[[#ID_Y:]] = call i32 @llvm.dx.thread.id.in.group(i32 1) +// CHECK: %[[#ID_Y:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 1) // CHECK: %[[#ID_XY:]] = insertelement <3 x i32> %[[#ID_X_]], i32 %[[#ID_Y]], i64 1 -// CHECK: %[[#ID_Z:]] = call i32 @llvm.dx.thread.id.in.group(i32 2) +// CHECK: 
%[[#ID_Z:]] = call i32 @llvm.[[TARGET]].thread.id.in.group(i32 2) // CHECK: %[[#ID_XYZ:]] = insertelement <3 x i32> %[[#ID_XY]], i32 %[[#ID_Z]], i64 2 -// CHECK: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]]) +// CHECK-DXIL: call void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]]) +// CHECK-SPIRV: call spir_func void @{{.*}}test{{.*}}(<3 x i32> %[[#ID_XYZ]]) [shader("compute")] [numthreads(8,8,1)] void test(uint3 Idx : SV_GroupThreadID) {} diff --git a/llvm/include/llvm/IR/IntrinsicsSPIRV.td b/llvm/include/llvm/IR/IntrinsicsSPIRV.td index 1ae3129774e507..fd0c3b2a59e1db 100644 --- a/llvm/include/llvm/IR/IntrinsicsSPIRV.td +++ b/llvm/include/llvm/IR/IntrinsicsSPIRV.td @@ -59,6 +59,7 @@ let TargetPrefix = "spv" in { // The following intrinsic(s) are mirrored from IntrinsicsDirectX.td for HLSL support. def int_spv_thread_id : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>; + def int_spv_thread_id_in_group : Intrinsic<[llvm_i32_ty], [llvm_i32_ty], [IntrNoMem, IntrWillReturn]>; def int_spv_all : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>; def int_spv_any : DefaultAttrsIntrinsic<[llvm_i1_ty], [llvm_any_ty], [IntrNoMem]>; def int_spv_cross : DefaultAttrsIntrinsic<[llvm_anyfloat_ty], [LLVMMatchType<0>, LLVMMatchType<0>], [IntrNoMem]>; diff --git a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp index 3a98b74b3d6757..9c831028523fbf 100644 --- a/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp +++ b/llvm/lib/Target/SPIRV/SPIRVInstructionSelector.cpp @@ -265,6 +265,9 @@ class SPIRVInstructionSelector : public InstructionSelector { bool selectSpvThreadId(Register ResVReg, const SPIRVType *ResType, MachineInstr &I) const; + bool selectSpvGroupThreadId(Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; + bool selectWaveOpInst(Register ResVReg, const SPIRVType *ResType, MachineInstr &I, unsigned Opcode) const; @@ -309,6 +312,9 @@ class SPIRVInstructionSelector : 
public InstructionSelector { SPIRVType *widenTypeToVec4(const SPIRVType *Type, MachineInstr &I) const; void extractSubvector(Register &ResVReg, const SPIRVType *ResType, Register &ReadReg, MachineInstr &InsertionPoint) const; + bool loadVec3BuiltinInputID(SPIRV::BuiltIn::BuiltIn BuiltInValue, + Register ResVReg, const SPIRVType *ResType, + MachineInstr &I) const; }; } // end anonymous namespace @@ -2852,6 +2858,8 @@ bool SPIRVInstructionSelector::selectIntrinsic(Register ResVReg, break; case Intrinsic::spv_thread_id: return selectSpvThreadId(ResVReg, ResType, I); + case Intrinsic::spv_thread_id_in_group: + return selectSpvGroupThreadId(ResVReg, ResType, I); case Intrinsic::spv_fdot: return selectFloatDot(ResVReg, ResType, I); case Intrinsic::spv_udot: @@ -3551,13 +3559,12 @@ bool SPIRVInstructionSelector::selectLog10(Register ResVReg, .constrainAllUses(TII, TRI, RBI); } -bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg, - const SPIRVType *ResType, - MachineInstr &I) const { - // DX intrinsic: @llvm.dx.thread.id(i32) - // ID Name Description - // 93 ThreadId reads the thread ID - +// Generate the instructions to load 3-element vector builtin input +// IDs/Indices. +// Like: SV_DispatchThreadID, SV_GroupThreadID, etc.... +bool SPIRVInstructionSelector::loadVec3BuiltinInputID( + SPIRV::BuiltIn::BuiltIn BuiltInValue, Register ResVReg, + const SPIRVType *ResType, MachineInstr &I) const { MachineIRBuilder MIRBuilder(I); const SPIRVType *U32Type = GR.getOrCreateSPIRVIntegerType(32, MIRBuilder); const SPIRVType *Vec3Ty = @@ -3565,16 +3572,16 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg, const SPIRVType *PtrType = GR.getOrCreateSPIRVPointerType( Vec3Ty, MIRBuilder, SPIRV::StorageClass::Input); - // Create new register for GlobalInvocationID builtin variable. + // Create new register for the input ID builtin variable. 
Register NewRegister = MIRBuilder.getMRI()->createVirtualRegister(&SPIRV::iIDRegClass); MIRBuilder.getMRI()->setType(NewRegister, LLT::pointer(0, 64)); GR.assignSPIRVTypeToVReg(PtrType, NewRegister, MIRBuilder.getMF()); - // Build GlobalInvocationID global variable with the necessary decorations. + // Build global variable with the necessary decorations for the input ID + // builtin variable. Register Variable = GR.buildGlobalVariable( - NewRegister, PtrType, - getLinkStringForBuiltIn(SPIRV::BuiltIn::GlobalInvocationId), nullptr, + NewRegister, PtrType, getLinkStringForBuiltIn(BuiltInValue), nullptr, SPIRV::StorageClass::Input, nullptr, true, true, SPIRV::LinkageType::Import, MIRBuilder, false); @@ -3591,12 +3598,12 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg, .addUse(GR.getSPIRVTypeID(Vec3Ty)) .addUse(Variable); - // Get Thread ID index. Expecting operand is a constant immediate value, + // Get the input ID index. Expecting operand is a constant immediate value, // wrapped in a type assignment. assert(I.getOperand(2).isReg()); const uint32_t ThreadId = foldImm(I.getOperand(2), MRI); - // Extract the thread ID from the loaded vector value. + // Extract the input ID from the loaded vector value. 
MachineBasicBlock &BB = *I.getParent(); auto MIB = BuildMI(BB, I, I.getDebugLoc(), TII.get(SPIRV::OpCompositeExtract)) .addDef(ResVReg) @@ -3606,6 +3613,32 @@ bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg, return Result && MIB.constrainAllUses(TII, TRI, RBI); } +bool SPIRVInstructionSelector::selectSpvThreadId(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + // DX intrinsic: @llvm.dx.thread.id(i32) + // ID Name Description + // 93 ThreadId reads the thread ID + // + // In SPIR-V, llvm.dx.thread.id maps to a `GlobalInvocationId` builtin + // variable + return loadVec3BuiltinInputID(SPIRV::BuiltIn::GlobalInvocationId, ResVReg, + ResType, I); +} + +bool SPIRVInstructionSelector::selectSpvGroupThreadId(Register ResVReg, + const SPIRVType *ResType, + MachineInstr &I) const { + // DX intrinsic: @llvm.dx.thread.id.in.group(i32) + // ID Name Description + // 95 GroupThreadId Reads the thread ID within the group + // + // In SPIR-V, llvm.dx.thread.id.in.group maps to a `LocalInvocationId` builtin + // variable + return loadVec3BuiltinInputID(SPIRV::BuiltIn::LocalInvocationId, ResVReg, + ResType, I); +} + SPIRVType *SPIRVInstructionSelector::widenTypeToVec4(const SPIRVType *Type, MachineInstr &I) const { MachineIRBuilder MIRBuilder(I); diff --git a/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll new file mode 100644 index 00000000000000..a88debf97fa7bb --- /dev/null +++ b/llvm/test/CodeGen/SPIRV/hlsl-intrinsics/SV_GroupThreadID.ll @@ -0,0 +1,76 @@ +; RUN: llc -O0 -verify-machineinstrs -mtriple=spirv-vulkan-unknown %s -o - | FileCheck %s +; RUN: %if spirv-tools %{ llc -O0 -mtriple=spirv-vulkan-unknown %s -o - -filetype=obj | spirv-val %} + +; This file generated from the following command: +; clang -cc1 -triple spirv-vulkan-library -x hlsl -emit-llvm -disable-llvm-passes -finclude-default-header - -o - <<EOF +; [shader("compute")] +; [numthreads(1,1,1)] +; 
void main(uint3 ID : SV_GroupThreadID) {} +; EOF + +; CHECK-DAG: %[[#int:]] = OpTypeInt 32 0 +; CHECK-DAG: %[[#v3int:]] = OpTypeVector %[[#int]] 3 +; CHECK-DAG: %[[#ptr_Input_v3int:]] = OpTypePointer Input %[[#v3int]] +; CHECK-DAG: %[[#tempvar:]] = OpUndef %[[#v3int]] +; CHECK-DAG: %[[#LocalInvocationId:]] = OpVariable %[[#ptr_Input_v3int]] Input + +; CHECK-DAG: OpEntryPoint GLCompute {{.*}} %[[#LocalInvocationId]] +; CHECK-DAG: OpName %[[#LocalInvocationId]] "__spirv_BuiltInLocalInvocationId" +; CHECK-DAG: OpDecorate %[[#LocalInvocationId]] LinkageAttributes "__spirv_BuiltInLocalInvocationId" Import +; CHECK-DAG: OpDecorate %[[#LocalInvocationId]] BuiltIn LocalInvocationId + +; ModuleID = '-' +source_filename = "-" +target datalayout = "e-i64:64-v16:16-v24:32-v32:32-v48:64-v96:128-v192:256-v256:256-v512:512-v1024:1024" +target triple = "spirv-unknown-vulkan-library" + +; Function Attrs: noinline norecurse nounwind optnone +define internal spir_func void @main(<3 x i32> noundef %ID) #0 { +entry: + %ID.addr = alloca <3 x i32>, align 16 + store <3 x i32> %ID, ptr %ID.addr, align 16 + ret void +} + +; Function Attrs: norecurse +define void @main.1() #1 { +entry: + +; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]] +; CHECK: %[[#load0:]] = OpCompositeExtract %[[#int]] %[[#load]] 0 + %0 = call i32 @llvm.spv.thread.id.in.group(i32 0) + +; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load0]] %[[#tempvar]] 0 + %1 = insertelement <3 x i32> poison, i32 %0, i64 0 + +; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]] +; CHECK: %[[#load1:]] = OpCompositeExtract %[[#int]] %[[#load]] 1 + %2 = call i32 @llvm.spv.thread.id.in.group(i32 1) + +; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load1]] %[[#tempvar]] 1 + %3 = insertelement <3 x i32> %1, i32 %2, i64 1 + +; CHECK: %[[#load:]] = OpLoad %[[#v3int]] %[[#LocalInvocationId]] +; CHECK: %[[#load2:]] = OpCompositeExtract %[[#int]] %[[#load]] 2 + %4 = call i32 
@llvm.spv.thread.id.in.group(i32 2) + +; CHECK: %[[#tempvar:]] = OpCompositeInsert %[[#v3int]] %[[#load2]] %[[#tempvar]] 2 + %5 = insertelement <3 x i32> %3, i32 %4, i64 2 + + call void @main(<3 x i32> %5) + ret void +} + +; Function Attrs: nounwind willreturn memory(none) +declare i32 @llvm.spv.thread.id.in.group(i32) #2 + +attributes #0 = { noinline norecurse nounwind optnone "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #1 = { norecurse "hlsl.numthreads"="1,1,1" "hlsl.shader"="compute" "no-trapping-math"="true" "stack-protector-buffer-size"="8" } +attributes #2 = { nounwind willreturn memory(none) } + +!llvm.module.flags = !{!0, !1} +!llvm.ident = !{!2} + +!0 = !{i32 1, !"wchar_size", i32 4} +!1 = !{i32 4, !"dx.disable_optimizations", i32 1} +!2 = !{!"clang version 19.0.0git (g...@github.com:llvm/llvm-project.git 91600507765679e92434ec7c5edb883bf01f847f)"} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits