https://github.com/python3kgae created https://github.com/llvm/llvm-project/pull/101240
First step toward supporting the WaveSize attribute described in https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html A new attribute, HLSLWaveSizeAttr, is added to the AST. This implements both the single wave size and the wave size range together, since implementing them separately might require more work. For #70118 From c7a476a4d8b06e399e9c076cc15208871e1b5a25 Mon Sep 17 00:00:00 2001 From: Xiang Li <python3k...@outlook.com> Date: Tue, 30 Jul 2024 16:34:40 -0400 Subject: [PATCH] [HLSL] AST support for WaveSize attribute. First step toward supporting the WaveSize attribute described in https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html A new attribute, HLSLWaveSizeAttr, is added to the AST. This implements both the single wave size and the wave size range together, since implementing them separately might require more work. For #70118 --- clang/include/clang/Basic/Attr.td | 16 +++ clang/include/clang/Basic/AttrDocs.td | 37 ++++++ clang/include/clang/Basic/DiagnosticGroups.td | 3 + .../clang/Basic/DiagnosticSemaKinds.td | 15 +++ clang/include/clang/Sema/SemaHLSL.h | 4 + clang/lib/Sema/SemaDecl.cpp | 4 + clang/lib/Sema/SemaDeclAttr.cpp | 3 + clang/lib/Sema/SemaHLSL.cpp | 118 +++++++++++++++++- clang/test/AST/HLSL/WaveSize.hlsl | 25 ++++ .../test/SemaHLSL/WaveSize-invalid-param.hlsl | 101 +++++++++++++++ .../SemaHLSL/WaveSize-invalid-profiles.hlsl | 20 +++ clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl | 24 ++++ .../include/llvm/Frontend/HLSL/HLSLWaveSize.h | 94 ++++++++++++++ llvm/include/llvm/Support/DXILABI.h | 3 + 14 files changed, 465 insertions(+), 2 deletions(-) create mode 100644 clang/test/AST/HLSL/WaveSize.hlsl create mode 100644 clang/test/SemaHLSL/WaveSize-invalid-param.hlsl create mode 100644 clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl create mode 100644 clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl create mode 100644 
llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 46d0a66d59c37..8b2f8358aec28 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4625,6 +4625,22 @@ def HLSLParamModifier : TypeAttr { let Args = [DefaultBoolArgument<"MergedSpelling", /*default*/0, /*fake*/1>]; } +def HLSLWaveSize: InheritableAttr { + let Spellings = [Microsoft<"WaveSize">]; + let Args = [IntArgument<"Min">, DefaultIntArgument<"Max", 0>, DefaultIntArgument<"Preferred", 0>]; + let Subjects = SubjectList<[HLSLEntry]>; + let LangOpts = [HLSL]; + let AdditionalMembers = [{ + private: + int SpelledArgsCount = 0; + + public: + void setSpelledArgsCount(int C) { SpelledArgsCount = C; } + int getSpelledArgsCount() const { return SpelledArgsCount; } + }]; + let Documentation = [WaveSizeDocs]; +} + def RandomizeLayout : InheritableAttr { let Spellings = [GCC<"randomize_layout">]; let Subjects = SubjectList<[Record]>; diff --git a/clang/include/clang/Basic/AttrDocs.td b/clang/include/clang/Basic/AttrDocs.td index 4b8d520d73893..e3c98912c81f4 100644 --- a/clang/include/clang/Basic/AttrDocs.td +++ b/clang/include/clang/Basic/AttrDocs.td @@ -7322,6 +7322,43 @@ flag. }]; } +def WaveSizeDocs : Documentation { + let Category = DocCatFunction; + let Content = [{ +The ``WaveSize`` attribute specify a wave size on a shader entry point in order +to indicate either that a shader depends on or strongly prefers a specific wave +size. +There're 2 versions of the attribute: ``WaveSize`` and ``RangedWaveSize``. +The syntax for ``WaveSize`` is: + +.. code-block:: text + + ``[WaveSize(<numLanes>)]`` + +The allowed wave sizes that an HLSL shader may specify are the powers of 2 +between 4 and 128, inclusive. +In other words, the set: [4, 8, 16, 32, 64, 128]. + +The syntax for ``RangedWaveSize`` is: + +.. 
code-block:: text + + ``[WaveSize(<minWaveSize>, <maxWaveSize>, [prefWaveSize])]`` + +Where minWaveSize is the minimum wave size supported by the shader representing +the beginning of the allowed range, maxWaveSize is the maximum wave size +supported by the shader representing the end of the allowed range, and +prefWaveSize is the optional preferred wave size representing the size expected +to be the most optimal for this shader. + +``WaveSize`` is available for HLSL shader model 6.6 and later. +``RangedWaveSize`` available for HLSL shader model 6.8 and later. + +The full documentation is available here: https://microsoft.github.io/DirectX-Specs/d3d/HLSL_SM_6_6_WaveSize.html +and https://microsoft.github.io/hlsl-specs/proposals/0013-wave-size-range.html + }]; +} + def NumThreadsDocs : Documentation { let Category = DocCatFunction; let Content = [{ diff --git a/clang/include/clang/Basic/DiagnosticGroups.td b/clang/include/clang/Basic/DiagnosticGroups.td index 19c3f1e043349..122b95e9f9a2e 100644 --- a/clang/include/clang/Basic/DiagnosticGroups.td +++ b/clang/include/clang/Basic/DiagnosticGroups.td @@ -1547,6 +1547,9 @@ def DXILValidation : DiagGroup<"dxil-validation">; // Warning for HLSL API availability def HLSLAvailability : DiagGroup<"hlsl-availability">; +// Warning for HLSL Attributes on Statement. 
+def HLSLAttributeStatement : DiagGroup<"attribute-statement">; + // Warnings and notes related to const_var_decl_type attribute checks def ReadOnlyPlacementChecks : DiagGroup<"read-only-types">; diff --git a/clang/include/clang/Basic/DiagnosticSemaKinds.td b/clang/include/clang/Basic/DiagnosticSemaKinds.td index 581434d33c5c9..9010812837d42 100644 --- a/clang/include/clang/Basic/DiagnosticSemaKinds.td +++ b/clang/include/clang/Basic/DiagnosticSemaKinds.td @@ -12361,6 +12361,21 @@ def warn_hlsl_availability_unavailable : def err_hlsl_export_not_on_function : Error< "export declaration can only be used on functions">; +def err_hlsl_attribute_in_wrong_shader_model: Error< + "attribute %0 requires shader model %1 or greater">; + +def err_hlsl_wavesize_size: Error< + "wavesize arguments must be between 4 and 128 and a power of 2">; +def err_hlsl_wavesize_min_geq_max: Error< + "minimum wavesize value %0 must be less than maximum wavesize value %1">; +def warn_hlsl_wavesize_min_eq_max: Warning< + "wave size range minimum and maximum are equal">, + InGroup<HLSLAttributeStatement>, DefaultError; +def err_hlsl_wavesize_pref_size_out_of_range: Error< + "preferred wavesize value %0 must be between %1 and %2">; +def err_hlsl_wavesize_insufficient_shader_model: Error< + "wavesize only takes multiple arguments in shader model 6.8 or higher">; + // Layout randomization diagnostics. 
def err_non_designated_init_used : Error< "a randomized struct can only be initialized with a designated initializer">; diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 2ddbee67c414b..a4d76818d29d2 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -38,6 +38,9 @@ class SemaHLSL : public SemaBase { HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D, const AttributeCommonInfo &AL, int X, int Y, int Z); + HLSLWaveSizeAttr *mergeWaveSizeAttr(Decl *D, const AttributeCommonInfo &AL, + int Min, int Max, int Preferred, + int SpelledArgsCount); HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, llvm::Triple::EnvironmentType ShaderType); HLSLParamModifierAttr * @@ -53,6 +56,7 @@ class SemaHLSL : public SemaBase { void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU); void handleNumThreadsAttr(Decl *D, const ParsedAttr &AL); + void handleWaveSizeAttr(Decl *D, const ParsedAttr &AL); void handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL); void handlePackOffsetAttr(Decl *D, const ParsedAttr &AL); void handleShaderAttr(Decl *D, const ParsedAttr &AL); diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 694a754646f27..c9a7c9e54d13c 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -2862,6 +2862,10 @@ static bool mergeDeclAttribute(Sema &S, NamedDecl *D, else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr)) NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ()); + else if (const auto *NT = dyn_cast<HLSLWaveSizeAttr>(Attr)) + NewAttr = + S.HLSL().mergeWaveSizeAttr(D, *NT, NT->getMin(), NT->getMax(), + NT->getPreferred(), NT->getSpelledArgsCount()); else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr)) NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType()); else if (isa<SuppressAttr>(Attr)) diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp 
index 98e3df9083516..57ae83be12881 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -6887,6 +6887,9 @@ ProcessDeclAttribute(Sema &S, Scope *scope, Decl *D, const ParsedAttr &AL, case ParsedAttr::AT_HLSLNumThreads: S.HLSL().handleNumThreadsAttr(D, AL); break; + case ParsedAttr::AT_HLSLWaveSize: + S.HLSL().handleWaveSizeAttr(D, AL); + break; case ParsedAttr::AT_HLSLSV_GroupIndex: handleSimpleAttribute<HLSLSV_GroupIndexAttr>(S, D, AL); break; diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 9940bc5b4a606..d386897d8251e 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -20,7 +20,9 @@ #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" #include "llvm/ADT/StringRef.h" +#include "llvm/Frontend/HLSL/HLSLWaveSize.h" #include "llvm/Support/Casting.h" +#include "llvm/Support/DXILABI.h" #include "llvm/Support/ErrorHandling.h" #include "llvm/TargetParser/Triple.h" #include <iterator> @@ -144,6 +146,25 @@ HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z); } +HLSLWaveSizeAttr *SemaHLSL::mergeWaveSizeAttr(Decl *D, + const AttributeCommonInfo &AL, + int Min, int Max, int Preferred, + int SpelledArgsCount) { + if (HLSLWaveSizeAttr *NT = D->getAttr<HLSLWaveSizeAttr>()) { + if (NT->getMin() != Min || NT->getMax() != Max || + NT->getPreferred() != Preferred || + NT->getSpelledArgsCount() != SpelledArgsCount) { + Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; + Diag(AL.getLoc(), diag::note_conflicting_attribute); + } + return nullptr; + } + HLSLWaveSizeAttr *Result = ::new (getASTContext()) + HLSLWaveSizeAttr(getASTContext(), AL, Min, Max, Preferred); + Result->setSpelledArgsCount(SpelledArgsCount); + return Result; +} + HLSLShaderAttr * SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, llvm::Triple::EnvironmentType ShaderType) { @@ -215,7 +236,8 @@ void 
SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); assert(ShaderAttr && "Entry point has no shader attribute"); llvm::Triple::EnvironmentType ST = ShaderAttr->getType(); - + auto &TargetInfo = getASTContext().getTargetInfo(); + VersionTuple Ver = TargetInfo.getTriple().getOSVersion(); switch (ST) { case llvm::Triple::Pixel: case llvm::Triple::Vertex: @@ -235,6 +257,13 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { llvm::Triple::Mesh}); FD->setInvalidDecl(); } + if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) { + DiagnoseAttrStageMismatch(NT, ST, + {llvm::Triple::Compute, + llvm::Triple::Amplification, + llvm::Triple::Mesh}); + FD->setInvalidDecl(); + } break; case llvm::Triple::Compute: @@ -245,6 +274,20 @@ void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { << llvm::Triple::getEnvironmentTypeName(ST); FD->setInvalidDecl(); } + if (const auto *NT = FD->getAttr<HLSLWaveSizeAttr>()) { + if (Ver.getMajor() < 6u || + (Ver.getMajor() == 6u && Ver.getMinor() < 6u)) { + Diag(NT->getLocation(), diag::err_hlsl_attribute_in_wrong_shader_model) + << "wavesize" + << "6.6"; + FD->setInvalidDecl(); + } else if (NT->getSpelledArgsCount() > 1 && + (Ver.getMajor() == 6u && Ver.getMinor() < 8u)) { + Diag(NT->getLocation(), + diag::err_hlsl_wavesize_insufficient_shader_model); + FD->setInvalidDecl(); + } + } break; default: llvm_unreachable("Unhandled environment in triple"); @@ -348,6 +391,77 @@ void SemaHLSL::handleNumThreadsAttr(Decl *D, const ParsedAttr &AL) { D->addAttr(NewAttr); } +void SemaHLSL::handleWaveSizeAttr(Decl *D, const ParsedAttr &AL) { + // validate that the wavesize argument is a power of 2 between 4 and 128 + // inclusive + unsigned SpelledArgsCount = AL.getNumArgs(); + if (SpelledArgsCount == 0 || SpelledArgsCount > 3) + return; + + uint32_t Min; + if (!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(0), Min)) + return; + + uint32_t Max = 0; + if (SpelledArgsCount > 1 && + 
!SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(1), Max)) + return; + + uint32_t Preferred = 0; + if (SpelledArgsCount > 2 && + !SemaRef.checkUInt32Argument(AL, AL.getArgAsExpr(2), Preferred)) + return; + llvm::hlsl::WaveSize WaveSize(Min, Max, Preferred); + llvm::hlsl::WaveSize::ValidationResult ValidationResult = WaveSize.validate(); + // WaveSize validation succeeds when not defined, but since we have an + // attribute, this means min was zero, which is invalid for min. + if (ValidationResult == llvm::hlsl::WaveSize::ValidationResult::Success && + !WaveSize.isDefined()) + ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMin; + + // It is invalid to explicitly specify degenerate cases. + if (SpelledArgsCount > 1 && WaveSize.Max == 0) + ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidMax; + else if (SpelledArgsCount > 2 && WaveSize.Preferred == 0) + ValidationResult = llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred; + + switch (ValidationResult) { + case llvm::hlsl::WaveSize::ValidationResult::Success: + break; + case llvm::hlsl::WaveSize::ValidationResult::InvalidMin: + case llvm::hlsl::WaveSize::ValidationResult::InvalidMax: + case llvm::hlsl::WaveSize::ValidationResult::InvalidPreferred: + case llvm::hlsl::WaveSize::ValidationResult::NoRangeOrMin: + Diag(AL.getLoc(), diag::err_hlsl_wavesize_size) + << llvm::dxil::MinWaveSize << llvm::dxil::MaxWaveSize; + break; + case llvm::hlsl::WaveSize::ValidationResult::MaxEqualsMin: + Diag(AL.getLoc(), diag::warn_hlsl_wavesize_min_eq_max) + << WaveSize.Min << WaveSize.Max; + break; + case llvm::hlsl::WaveSize::ValidationResult::MaxLessThanMin: + Diag(AL.getLoc(), diag::err_hlsl_wavesize_min_geq_max) + << WaveSize.Min << WaveSize.Max; + break; + case llvm::hlsl::WaveSize::ValidationResult::PreferredOutOfRange: + Diag(AL.getLoc(), diag::err_hlsl_wavesize_pref_size_out_of_range) + << WaveSize.Preferred << WaveSize.Min << WaveSize.Max; + break; + case 
llvm::hlsl::WaveSize::ValidationResult::MaxOrPreferredWhenUndefined: + case llvm::hlsl::WaveSize::ValidationResult::PreferredWhenNoRange: + llvm_unreachable("Should have hit InvalidMax or InvalidPreferred instead."); + break; + } + + if (ValidationResult != llvm::hlsl::WaveSize::ValidationResult::Success) + return; + + HLSLWaveSizeAttr *NewAttr = + mergeWaveSizeAttr(D, AL, Min, Max, Preferred, SpelledArgsCount); + if (NewAttr) + D->addAttr(NewAttr); +} + static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) { if (!T->hasUnsignedIntegerRepresentation()) return false; @@ -356,7 +470,7 @@ static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) { return true; } -void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { +void SemaHLSL::handleSV_DispatchThreadIDAttr(Decl *D, const ParsedAttr &AL) { auto *VD = cast<ValueDecl>(D); if (!isLegalTypeForHLSLSV_DispatchThreadID(VD->getType())) { Diag(AL.getLoc(), diag::err_hlsl_attr_invalid_type) diff --git a/clang/test/AST/HLSL/WaveSize.hlsl b/clang/test/AST/HLSL/WaveSize.hlsl new file mode 100644 index 0000000000000..fd6dc7c94d6d0 --- /dev/null +++ b/clang/test/AST/HLSL/WaveSize.hlsl @@ -0,0 +1,25 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl -ast-dump -o - %s | FileCheck %s + +// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w0 'void ()' +// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 128 0 0 + [numthreads(8,8,1)] + [WaveSize(128)] + void w0() { + } + + + +// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w1 'void ()' +// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 64 0 + [numthreads(8,8,1)] + [WaveSize(8, 64)] + void w1() { + } + + +// CHECK-LABLE:FunctionDecl 0x{{[0-9a-f]+}} <{{.*}}> w2 'void ()' +// CHECK:HLSLWaveSizeAttr 0x{{[0-9a-f]+}} <{{.*}}> 8 128 64 + [numthreads(8,8,1)] + [WaveSize(8, 128, 64)] + void w2() { + } diff --git a/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl new file mode 100644 
index 0000000000000..10c562839eef6 --- /dev/null +++ b/clang/test/SemaHLSL/WaveSize-invalid-param.hlsl @@ -0,0 +1,101 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-library -x hlsl %s -verify + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}} +[WaveSize(1)] +void e0() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}} +[WaveSize(4, 2)] +void e1() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}} +[WaveSize(4, 8, 7)] +void e2() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{minimum wavesize value 16 must be less than maximum wavesize value 8}} +[WaveSize(16, 8)] +void e3() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{preferred wavesize value 8 must be between 16 and 128}} +[WaveSize(16, 128, 8)] +void e4() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{preferred wavesize value 32 must be between 8 and 16}} +[WaveSize(8, 16, 32)] +void e5() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}} +[WaveSize(4, 0)] +void e6() { +} + + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}} +[WaveSize(4, 4, 0)] +void e7() { +} + + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{wave size range minimum and maximum are equal}} +[WaveSize(16, 16)] +void e8() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}} +[WaveSize(0)] +void e9() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{wavesize arguments must be between 4 and 128 and a power of 2}} +[WaveSize(-4)] 
+void e10() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{'WaveSize' attribute takes no more than 3 arguments}} +[WaveSize(16, 128, 64, 64)] +void e11() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{'WaveSize' attribute takes at least 1 argument}} +[WaveSize()] +void e12() { +} + +[shader("compute")] +[numthreads(1,1,1)] +// expected-error@+1 {{'WaveSize' attribute takes at least 1 argument}} +[WaveSize] +void e13() { +} diff --git a/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl new file mode 100644 index 0000000000000..13e27a5c4b685 --- /dev/null +++ b/clang/test/SemaHLSL/WaveSize-invalid-profiles.hlsl @@ -0,0 +1,20 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-pixel -x hlsl %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-vertex -x hlsl %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-geometry -x hlsl %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-hull -x hlsl %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.8-domain -x hlsl %s -verify + +#if __SHADER_TARGET_STAGE == __SHADER_STAGE_PIXEL +// expected-error@+10 {{attribute 'WaveSize' is unsupported in 'pixel' shaders, requires one of the following: compute, amplification, mesh}} +#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_VERTEX +// expected-error@+8 {{attribute 'WaveSize' is unsupported in 'vertex' shaders, requires one of the following: compute, amplification, mesh}} +#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_GEOMETRY +// expected-error@+6 {{attribute 'WaveSize' is unsupported in 'geometry' shaders, requires one of the following: compute, amplification, mesh}} +#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_HULL +// expected-error@+4 {{attribute 'WaveSize' is unsupported in 'hull' shaders, requires one of the following: compute, amplification, mesh}} +#elif __SHADER_TARGET_STAGE == __SHADER_STAGE_DOMAIN +// expected-error@+2 {{attribute 
'WaveSize' is unsupported in 'domain' shaders, requires one of the following: compute, amplification, mesh}} +#endif +[WaveSize(16)] +void main() { +} diff --git a/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl new file mode 100644 index 0000000000000..fb9978c6ce3ce --- /dev/null +++ b/clang/test/SemaHLSL/WaveSize-sm6.6-6.5.hlsl @@ -0,0 +1,24 @@ +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.6-library -x hlsl %s -verify +// RUN: %clang_cc1 -triple dxil-pc-shadermodel6.5-library -x hlsl %s -verify + +[shader("compute")] +[numthreads(1,1,1)] +#if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6 +// expected-error@+4 {{wavesize only takes multiple arguments in shader model 6.8 or higher}} +#elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5 +// expected-error@+2 {{attribute wavesize requires shader model 6.6 or greater}} +#endif +[WaveSize(4, 16)] +void e0() { +} + +[shader("compute")] +[numthreads(1,1,1)] +#if __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 6 +// expected-error@+4 {{wavesize only takes multiple arguments in shader model 6.8 or higher}} +#elif __SHADER_TARGET_MAJOR == 6 && __SHADER_TARGET_MINOR == 5 +// expected-error@+2 {{attribute wavesize requires shader model 6.6 or greater}} +#endif +[WaveSize(4, 16)] +void e1() { +} diff --git a/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h b/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h new file mode 100644 index 0000000000000..ec8f22f58e1ad --- /dev/null +++ b/llvm/include/llvm/Frontend/HLSL/HLSLWaveSize.h @@ -0,0 +1,94 @@ +//===- HLSLResource.h - HLSL Resource helper objects ----------------------===// +// +// Part of the LLVM Project, under the Apache License v2.0 with LLVM Exceptions. +// See https://llvm.org/LICENSE.txt for license information. 
+// SPDX-License-Identifier: Apache-2.0 WITH LLVM-exception +// +//===----------------------------------------------------------------------===// +/// +/// \file This file contains helper objects for working with HLSL WaveSize. +/// +//===----------------------------------------------------------------------===// + +#ifndef LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H +#define LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H + +namespace llvm { +namespace hlsl { + +// SM 6.6 allows WaveSize specification for only a single required size. +// SM 6.8+ allows specification of WaveSize as a min, max and preferred value. +struct WaveSize { + unsigned Min = 0; + unsigned Max = 0; + unsigned Preferred = 0; + + WaveSize() = default; + WaveSize(unsigned Min, unsigned Max = 0, unsigned Preferred = 0) + : Min(Min), Max(Max), Preferred(Preferred) {} + WaveSize(const WaveSize &) = default; + WaveSize &operator=(const WaveSize &) = default; + bool operator==(const WaveSize &Other) const { + return Min == Other.Min && Max == Other.Max && Preferred == Other.Preferred; + }; + + // Valid non-zero values are powers of 2 between 4 and 128, inclusive. 
+ static bool isValidValue(unsigned Value) { + return (Value >= 4 && Value <= 128 && ((Value & (Value - 1)) == 0)); + } + // Valid representations: + // (not to be confused with encodings in metadata, PSV0, or RDAT) + // 0, 0, 0: Not defined + // Min, 0, 0: single WaveSize (SM 6.6/6.7) + // (single WaveSize is represented in metadata with the single Min value) + // Min, Max (> Min), 0 or Preferred (>= Min and <= Max): Range (SM 6.8+) + // (WaveSizeRange represenation in metadata is the same) + enum class ValidationResult { + Success, + InvalidMin, + InvalidMax, + InvalidPreferred, + MaxOrPreferredWhenUndefined, + PreferredWhenNoRange, + MaxEqualsMin, + MaxLessThanMin, + PreferredOutOfRange, + NoRangeOrMin, + }; + ValidationResult validate() const { + if (Min == 0) { // Not defined + if (Max != 0 || Preferred != 0) + return ValidationResult::MaxOrPreferredWhenUndefined; + else + // all 3 parameters are 0 + return ValidationResult::NoRangeOrMin; + } else if (!isValidValue(Min)) { + return ValidationResult::InvalidMin; + } else if (Max == 0) { // single WaveSize (SM 6.6/6.7) + if (Preferred != 0) + return ValidationResult::PreferredWhenNoRange; + } else if (!isValidValue(Max)) { + return ValidationResult::InvalidMax; + } else if (Min == Max) { + return ValidationResult::MaxEqualsMin; + } else if (Max < Min) { + return ValidationResult::MaxLessThanMin; + } else if (Preferred != 0) { + if (!isValidValue(Preferred)) + return ValidationResult::InvalidPreferred; + if (Preferred < Min || Preferred > Max) + return ValidationResult::PreferredOutOfRange; + } + return ValidationResult::Success; + } + bool isValid() const { return validate() == ValidationResult::Success; } + + bool isDefined() const { return Min != 0; } + bool isRange() const { return Max != 0; } + bool hasPreferred() const { return Preferred != 0; } +}; + +} // namespace hlsl +} // namespace llvm + +#endif // LLVM_FRONTEND_HLSL_HLSLWAVESIZE_H diff --git a/llvm/include/llvm/Support/DXILABI.h 
b/llvm/include/llvm/Support/DXILABI.h index a2222eec09ba8..8e86ba119345e 100644 --- a/llvm/include/llvm/Support/DXILABI.h +++ b/llvm/include/llvm/Support/DXILABI.h @@ -113,6 +113,9 @@ enum class SamplerFeedbackType : uint32_t { MipRegionUsed = 1, }; +const unsigned MinWaveSize = 4; +const unsigned MaxWaveSize = 128; + } // namespace dxil } // namespace llvm _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits