https://github.com/Endilll created https://github.com/llvm/llvm-project/pull/88354
A follow-up to #87912. I'm moving more HLSL-related functions from `Sema` to `SemaHLSL`. I'm also dropping `HLSL` from their names in the process. From ecff8db824552872ba055fdc0bca42b1a0386c39 Mon Sep 17 00:00:00 2001 From: Vlad Serebrennikov <serebrennikov.vladis...@gmail.com> Date: Thu, 11 Apr 2024 07:56:46 +0300 Subject: [PATCH] [clang][NFC] Move more functions to `SemaHLSL` --- clang/include/clang/Sema/Sema.h | 15 --- clang/include/clang/Sema/SemaHLSL.h | 27 +++- clang/lib/Parse/ParseHLSL.cpp | 10 +- clang/lib/Sema/SemaDecl.cpp | 130 +------------------ clang/lib/Sema/SemaDeclAttr.cpp | 54 +------- clang/lib/Sema/SemaHLSL.cpp | 186 +++++++++++++++++++++++++++- 6 files changed, 218 insertions(+), 204 deletions(-) diff --git a/clang/include/clang/Sema/Sema.h b/clang/include/clang/Sema/Sema.h index e3e255a0dd76f8..e904cd3ad13fb7 100644 --- a/clang/include/clang/Sema/Sema.h +++ b/clang/include/clang/Sema/Sema.h @@ -2940,13 +2940,6 @@ class Sema final : public SemaBase { QualType NewT, QualType OldT); void CheckMain(FunctionDecl *FD, const DeclSpec &D); void CheckMSVCRTEntryPoint(FunctionDecl *FD); - void ActOnHLSLTopLevelFunction(FunctionDecl *FD); - void CheckHLSLEntryPoint(FunctionDecl *FD); - void CheckHLSLSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param, - const HLSLAnnotationAttr *AnnotationAttr); - void DiagnoseHLSLAttrStageMismatch( - const Attr *A, HLSLShaderAttr::ShaderType Stage, - std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages); Attr *getImplicitCodeSegOrSectionAttrForFunction(const FunctionDecl *FD, bool IsDefinition); void CheckFunctionOrTemplateParamDeclarator(Scope *S, Declarator &D); @@ -3707,14 +3700,6 @@ class Sema final : public SemaBase { StringRef UuidAsWritten, MSGuidDecl *GuidDecl); BTFDeclTagAttr *mergeBTFDeclTagAttr(Decl *D, const BTFDeclTagAttr &AL); - HLSLNumThreadsAttr *mergeHLSLNumThreadsAttr(Decl *D, - const AttributeCommonInfo &AL, - int X, int Y, int Z); - HLSLShaderAttr *mergeHLSLShaderAttr(Decl *D, 
const AttributeCommonInfo &AL, - HLSLShaderAttr::ShaderType ShaderType); - HLSLParamModifierAttr * - mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, - HLSLParamModifierAttr::Spelling Spelling); WebAssemblyImportNameAttr * mergeImportNameAttr(Decl *D, const WebAssemblyImportNameAttr &AL); diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index acc675963c23a5..34acaf19517f2a 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -13,12 +13,16 @@ #ifndef LLVM_CLANG_SEMA_SEMAHLSL_H #define LLVM_CLANG_SEMA_SEMAHLSL_H +#include "clang/AST/Attr.h" +#include "clang/AST/Decl.h" #include "clang/AST/DeclBase.h" #include "clang/AST/Expr.h" +#include "clang/Basic/AttributeCommonInfo.h" #include "clang/Basic/IdentifierTable.h" #include "clang/Basic/SourceLocation.h" #include "clang/Sema/Scope.h" #include "clang/Sema/SemaBase.h" +#include <initializer_list> namespace clang { @@ -26,10 +30,25 @@ class SemaHLSL : public SemaBase { public: SemaHLSL(Sema &S); - Decl *ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer, - SourceLocation KwLoc, IdentifierInfo *Ident, - SourceLocation IdentLoc, SourceLocation LBrace); - void ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace); + Decl *ActOnStartBuffer(Scope *BufferScope, bool CBuffer, SourceLocation KwLoc, + IdentifierInfo *Ident, SourceLocation IdentLoc, + SourceLocation LBrace); + void ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace); + HLSLNumThreadsAttr *mergeNumThreadsAttr(Decl *D, + const AttributeCommonInfo &AL, int X, + int Y, int Z); + HLSLShaderAttr *mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, + HLSLShaderAttr::ShaderType ShaderType); + HLSLParamModifierAttr * + mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, + HLSLParamModifierAttr::Spelling Spelling); + void ActOnTopLevelFunction(FunctionDecl *FD); + void CheckEntryPoint(FunctionDecl *FD); + void CheckSemanticAnnotation(FunctionDecl *EntryPoint, 
const Decl *Param, + const HLSLAnnotationAttr *AnnotationAttr); + void DiagnoseAttrStageMismatch( + const Attr *A, HLSLShaderAttr::ShaderType Stage, + std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages); }; } // namespace clang diff --git a/clang/lib/Parse/ParseHLSL.cpp b/clang/lib/Parse/ParseHLSL.cpp index 5afc958600fa55..d97985d42369ad 100644 --- a/clang/lib/Parse/ParseHLSL.cpp +++ b/clang/lib/Parse/ParseHLSL.cpp @@ -72,9 +72,9 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) { return nullptr; } - Decl *D = Actions.HLSL().ActOnStartHLSLBuffer( - getCurScope(), IsCBuffer, BufferLoc, Identifier, IdentifierLoc, - T.getOpenLocation()); + Decl *D = Actions.HLSL().ActOnStartBuffer(getCurScope(), IsCBuffer, BufferLoc, + Identifier, IdentifierLoc, + T.getOpenLocation()); while (Tok.isNot(tok::r_brace) && Tok.isNot(tok::eof)) { // FIXME: support attribute on constants inside cbuffer/tbuffer. @@ -88,7 +88,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) { T.skipToEnd(); DeclEnd = T.getCloseLocation(); BufferScope.Exit(); - Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd); + Actions.HLSL().ActOnFinishBuffer(D, DeclEnd); return nullptr; } } @@ -96,7 +96,7 @@ Decl *Parser::ParseHLSLBuffer(SourceLocation &DeclEnd) { T.consumeClose(); DeclEnd = T.getCloseLocation(); BufferScope.Exit(); - Actions.HLSL().ActOnFinishHLSLBuffer(D, DeclEnd); + Actions.HLSL().ActOnFinishBuffer(D, DeclEnd); Actions.ProcessDeclAttributeList(Actions.CurScope, D, Attrs); return D; diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 8472aaeb6bad97..3beb4fb1f8c733 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -45,6 +45,7 @@ #include "clang/Sema/ParsedTemplate.h" #include "clang/Sema/Scope.h" #include "clang/Sema/ScopeInfo.h" +#include "clang/Sema/SemaHLSL.h" #include "clang/Sema/SemaInternal.h" #include "clang/Sema/Template.h" #include "llvm/ADT/SmallString.h" @@ -2972,10 +2973,10 @@ static bool 
mergeDeclAttribute(Sema &S, NamedDecl *D, else if (const auto *BTFA = dyn_cast<BTFDeclTagAttr>(Attr)) NewAttr = S.mergeBTFDeclTagAttr(D, *BTFA); else if (const auto *NT = dyn_cast<HLSLNumThreadsAttr>(Attr)) - NewAttr = - S.mergeHLSLNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), NT->getZ()); + NewAttr = S.HLSL().mergeNumThreadsAttr(D, *NT, NT->getX(), NT->getY(), + NT->getZ()); else if (const auto *SA = dyn_cast<HLSLShaderAttr>(Attr)) - NewAttr = S.mergeHLSLShaderAttr(D, *SA, SA->getType()); + NewAttr = S.HLSL().mergeShaderAttr(D, *SA, SA->getType()); else if (isa<SuppressAttr>(Attr)) // Do nothing. Each redeclaration should be suppressed separately. NewAttr = nullptr; @@ -10809,10 +10810,10 @@ Sema::ActOnFunctionDeclarator(Scope *S, Declarator &D, DeclContext *DC, if (getLangOpts().HLSL && D.isFunctionDefinition()) { // Any top level function could potentially be specified as an entry. if (!NewFD->isInvalidDecl() && S->getDepth() == 0 && Name.isIdentifier()) - ActOnHLSLTopLevelFunction(NewFD); + HLSL().ActOnTopLevelFunction(NewFD); if (NewFD->hasAttr<HLSLShaderAttr>()) - CheckHLSLEntryPoint(NewFD); + HLSL().CheckEntryPoint(NewFD); } // If this is the first declaration of a library builtin function, add @@ -12660,125 +12661,6 @@ void Sema::CheckMSVCRTEntryPoint(FunctionDecl *FD) { } } -void Sema::ActOnHLSLTopLevelFunction(FunctionDecl *FD) { - auto &TargetInfo = getASTContext().getTargetInfo(); - - if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) - return; - - StringRef Env = TargetInfo.getTriple().getEnvironmentName(); - HLSLShaderAttr::ShaderType ShaderType; - if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) { - if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { - // The entry point is already annotated - check that it matches the - // triple. 
- if (Shader->getType() != ShaderType) { - Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch) - << Shader; - FD->setInvalidDecl(); - } - } else { - // Implicitly add the shader attribute if the entry function isn't - // explicitly annotated. - FD->addAttr(HLSLShaderAttr::CreateImplicit(Context, ShaderType, - FD->getBeginLoc())); - } - } else { - switch (TargetInfo.getTriple().getEnvironment()) { - case llvm::Triple::UnknownEnvironment: - case llvm::Triple::Library: - break; - default: - llvm_unreachable("Unhandled environment in triple"); - } - } -} - -void Sema::CheckHLSLEntryPoint(FunctionDecl *FD) { - const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); - assert(ShaderAttr && "Entry point has no shader attribute"); - HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); - - switch (ST) { - case HLSLShaderAttr::Pixel: - case HLSLShaderAttr::Vertex: - case HLSLShaderAttr::Geometry: - case HLSLShaderAttr::Hull: - case HLSLShaderAttr::Domain: - case HLSLShaderAttr::RayGeneration: - case HLSLShaderAttr::Intersection: - case HLSLShaderAttr::AnyHit: - case HLSLShaderAttr::ClosestHit: - case HLSLShaderAttr::Miss: - case HLSLShaderAttr::Callable: - if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) { - DiagnoseHLSLAttrStageMismatch(NT, ST, - {HLSLShaderAttr::Compute, - HLSLShaderAttr::Amplification, - HLSLShaderAttr::Mesh}); - FD->setInvalidDecl(); - } - break; - - case HLSLShaderAttr::Compute: - case HLSLShaderAttr::Amplification: - case HLSLShaderAttr::Mesh: - if (!FD->hasAttr<HLSLNumThreadsAttr>()) { - Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) - << HLSLShaderAttr::ConvertShaderTypeToStr(ST); - FD->setInvalidDecl(); - } - break; - } - - for (ParmVarDecl *Param : FD->parameters()) { - if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) { - CheckHLSLSemanticAnnotation(FD, Param, AnnotationAttr); - } else { - // FIXME: Handle struct parameters where annotations are on struct fields. 
- // See: https://github.com/llvm/llvm-project/issues/57875 - Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation); - Diag(Param->getLocation(), diag::note_previous_decl) << Param; - FD->setInvalidDecl(); - } - } - // FIXME: Verify return type semantic annotation. -} - -void Sema::CheckHLSLSemanticAnnotation( - FunctionDecl *EntryPoint, const Decl *Param, - const HLSLAnnotationAttr *AnnotationAttr) { - auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>(); - assert(ShaderAttr && "Entry point has no shader attribute"); - HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); - - switch (AnnotationAttr->getKind()) { - case attr::HLSLSV_DispatchThreadID: - case attr::HLSLSV_GroupIndex: - if (ST == HLSLShaderAttr::Compute) - return; - DiagnoseHLSLAttrStageMismatch(AnnotationAttr, ST, - {HLSLShaderAttr::Compute}); - break; - default: - llvm_unreachable("Unknown HLSLAnnotationAttr"); - } -} - -void Sema::DiagnoseHLSLAttrStageMismatch( - const Attr *A, HLSLShaderAttr::ShaderType Stage, - std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) { - SmallVector<StringRef, 8> StageStrings; - llvm::transform(AllowedStages, std::back_inserter(StageStrings), - [](HLSLShaderAttr::ShaderType ST) { - return StringRef( - HLSLShaderAttr::ConvertShaderTypeToStr(ST)); - }); - Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage) - << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage) - << (AllowedStages.size() != 1) << join(StageStrings, ", "); -} - bool Sema::CheckForConstantInitializer(Expr *Init, QualType DclT) { // FIXME: Need strict checking. 
In C89, we need to check for // any assignment, increment, decrement, function-calls, or diff --git a/clang/lib/Sema/SemaDeclAttr.cpp b/clang/lib/Sema/SemaDeclAttr.cpp index 8bce04640e748e..b91064e28e4153 100644 --- a/clang/lib/Sema/SemaDeclAttr.cpp +++ b/clang/lib/Sema/SemaDeclAttr.cpp @@ -39,6 +39,7 @@ #include "clang/Sema/ParsedAttr.h" #include "clang/Sema/Scope.h" #include "clang/Sema/ScopeInfo.h" +#include "clang/Sema/SemaHLSL.h" #include "clang/Sema/SemaInternal.h" #include "llvm/ADT/STLExtras.h" #include "llvm/ADT/StringExtras.h" @@ -7238,24 +7239,11 @@ static void handleHLSLNumThreadsAttr(Sema &S, Decl *D, const ParsedAttr &AL) { return; } - HLSLNumThreadsAttr *NewAttr = S.mergeHLSLNumThreadsAttr(D, AL, X, Y, Z); + HLSLNumThreadsAttr *NewAttr = S.HLSL().mergeNumThreadsAttr(D, AL, X, Y, Z); if (NewAttr) D->addAttr(NewAttr); } -HLSLNumThreadsAttr *Sema::mergeHLSLNumThreadsAttr(Decl *D, - const AttributeCommonInfo &AL, - int X, int Y, int Z) { - if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) { - if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) { - Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; - Diag(AL.getLoc(), diag::note_conflicting_attribute); - } - return nullptr; - } - return ::new (Context) HLSLNumThreadsAttr(Context, AL, X, Y, Z); -} - static bool isLegalTypeForHLSLSV_DispatchThreadID(QualType T) { if (!T->hasUnsignedIntegerRepresentation()) return false; @@ -7299,24 +7287,11 @@ static void handleHLSLShaderAttr(Sema &S, Decl *D, const ParsedAttr &AL) { // FIXME: check function match the shader stage. 
- HLSLShaderAttr *NewAttr = S.mergeHLSLShaderAttr(D, AL, ShaderType); + HLSLShaderAttr *NewAttr = S.HLSL().mergeShaderAttr(D, AL, ShaderType); if (NewAttr) D->addAttr(NewAttr); } -HLSLShaderAttr * -Sema::mergeHLSLShaderAttr(Decl *D, const AttributeCommonInfo &AL, - HLSLShaderAttr::ShaderType ShaderType) { - if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) { - if (NT->getType() != ShaderType) { - Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; - Diag(AL.getLoc(), diag::note_conflicting_attribute); - } - return nullptr; - } - return HLSLShaderAttr::Create(Context, ShaderType, AL); -} - static void handleHLSLResourceBindingAttr(Sema &S, Decl *D, const ParsedAttr &AL) { StringRef Space = "space0"; @@ -7391,34 +7366,13 @@ static void handleHLSLResourceBindingAttr(Sema &S, Decl *D, static void handleHLSLParamModifierAttr(Sema &S, Decl *D, const ParsedAttr &AL) { - HLSLParamModifierAttr *NewAttr = S.mergeHLSLParamModifierAttr( + HLSLParamModifierAttr *NewAttr = S.HLSL().mergeParamModifierAttr( D, AL, static_cast<HLSLParamModifierAttr::Spelling>(AL.getSemanticSpelling())); if (NewAttr) D->addAttr(NewAttr); } -HLSLParamModifierAttr * -Sema::mergeHLSLParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, - HLSLParamModifierAttr::Spelling Spelling) { - // We can only merge an `in` attribute with an `out` attribute. All other - // combinations of duplicated attributes are ill-formed. 
- if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) { - if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) || - (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) { - D->dropAttr<HLSLParamModifierAttr>(); - SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()}; - return HLSLParamModifierAttr::Create( - Context, /*MergedSpelling=*/true, AdjustedRange, - HLSLParamModifierAttr::Keyword_inout); - } - Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL; - Diag(PA->getLocation(), diag::note_conflicting_attribute); - return nullptr; - } - return HLSLParamModifierAttr::Create(Context, AL); -} - static void handleMSInheritanceAttr(Sema &S, Decl *D, const ParsedAttr &AL) { if (!S.LangOpts.CPlusPlus) { S.Diag(AL.getLoc(), diag::err_attribute_not_supported_in_lang) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 681849d6e6c8a2..bb9e37f18d370c 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -9,17 +9,25 @@ //===----------------------------------------------------------------------===// #include "clang/Sema/SemaHLSL.h" +#include "clang/Basic/DiagnosticSema.h" +#include "clang/Basic/LLVM.h" +#include "clang/Basic/TargetInfo.h" #include "clang/Sema/Sema.h" +#include "llvm/ADT/STLExtras.h" +#include "llvm/ADT/StringExtras.h" +#include "llvm/ADT/StringRef.h" +#include "llvm/Support/ErrorHandling.h" +#include "llvm/TargetParser/Triple.h" +#include <iterator> using namespace clang; SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} -Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer, - SourceLocation KwLoc, - IdentifierInfo *Ident, - SourceLocation IdentLoc, - SourceLocation LBrace) { +Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, + SourceLocation KwLoc, IdentifierInfo *Ident, + SourceLocation IdentLoc, + SourceLocation LBrace) { // For anonymous namespace, take the location of the left brace. 
DeclContext *LexicalParent = SemaRef.getCurLexicalContext(); HLSLBufferDecl *Result = HLSLBufferDecl::Create( @@ -31,8 +39,174 @@ Decl *SemaHLSL::ActOnStartHLSLBuffer(Scope *BufferScope, bool CBuffer, return Result; } -void SemaHLSL::ActOnFinishHLSLBuffer(Decl *Dcl, SourceLocation RBrace) { +void SemaHLSL::ActOnFinishBuffer(Decl *Dcl, SourceLocation RBrace) { auto *BufDecl = cast<HLSLBufferDecl>(Dcl); BufDecl->setRBraceLoc(RBrace); SemaRef.PopDeclContext(); } + +HLSLNumThreadsAttr *SemaHLSL::mergeNumThreadsAttr(Decl *D, + const AttributeCommonInfo &AL, + int X, int Y, int Z) { + if (HLSLNumThreadsAttr *NT = D->getAttr<HLSLNumThreadsAttr>()) { + if (NT->getX() != X || NT->getY() != Y || NT->getZ() != Z) { + Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; + Diag(AL.getLoc(), diag::note_conflicting_attribute); + } + return nullptr; + } + return ::new (getASTContext()) + HLSLNumThreadsAttr(getASTContext(), AL, X, Y, Z); +} + +HLSLShaderAttr * +SemaHLSL::mergeShaderAttr(Decl *D, const AttributeCommonInfo &AL, + HLSLShaderAttr::ShaderType ShaderType) { + if (HLSLShaderAttr *NT = D->getAttr<HLSLShaderAttr>()) { + if (NT->getType() != ShaderType) { + Diag(NT->getLocation(), diag::err_hlsl_attribute_param_mismatch) << AL; + Diag(AL.getLoc(), diag::note_conflicting_attribute); + } + return nullptr; + } + return HLSLShaderAttr::Create(getASTContext(), ShaderType, AL); +} + +HLSLParamModifierAttr * +SemaHLSL::mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, + HLSLParamModifierAttr::Spelling Spelling) { + // We can only merge an `in` attribute with an `out` attribute. All other + // combinations of duplicated attributes are ill-formed. 
+ if (HLSLParamModifierAttr *PA = D->getAttr<HLSLParamModifierAttr>()) { + if ((PA->isIn() && Spelling == HLSLParamModifierAttr::Keyword_out) || + (PA->isOut() && Spelling == HLSLParamModifierAttr::Keyword_in)) { + D->dropAttr<HLSLParamModifierAttr>(); + SourceRange AdjustedRange = {PA->getLocation(), AL.getRange().getEnd()}; + return HLSLParamModifierAttr::Create( + getASTContext(), /*MergedSpelling=*/true, AdjustedRange, + HLSLParamModifierAttr::Keyword_inout); + } + Diag(AL.getLoc(), diag::err_hlsl_duplicate_parameter_modifier) << AL; + Diag(PA->getLocation(), diag::note_conflicting_attribute); + return nullptr; + } + return HLSLParamModifierAttr::Create(getASTContext(), AL); +} + +void SemaHLSL::ActOnTopLevelFunction(FunctionDecl *FD) { + auto &TargetInfo = getASTContext().getTargetInfo(); + + if (FD->getName() != TargetInfo.getTargetOpts().HLSLEntry) + return; + + StringRef Env = TargetInfo.getTriple().getEnvironmentName(); + HLSLShaderAttr::ShaderType ShaderType; + if (HLSLShaderAttr::ConvertStrToShaderType(Env, ShaderType)) { + if (const auto *Shader = FD->getAttr<HLSLShaderAttr>()) { + // The entry point is already annotated - check that it matches the + // triple. + if (Shader->getType() != ShaderType) { + Diag(Shader->getLocation(), diag::err_hlsl_entry_shader_attr_mismatch) + << Shader; + FD->setInvalidDecl(); + } + } else { + // Implicitly add the shader attribute if the entry function isn't + // explicitly annotated. 
+ FD->addAttr(HLSLShaderAttr::CreateImplicit(getASTContext(), ShaderType, + FD->getBeginLoc())); + } + } else { + switch (TargetInfo.getTriple().getEnvironment()) { + case llvm::Triple::UnknownEnvironment: + case llvm::Triple::Library: + break; + default: + llvm_unreachable("Unhandled environment in triple"); + } + } +} + +void SemaHLSL::CheckEntryPoint(FunctionDecl *FD) { + const auto *ShaderAttr = FD->getAttr<HLSLShaderAttr>(); + assert(ShaderAttr && "Entry point has no shader attribute"); + HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); + + switch (ST) { + case HLSLShaderAttr::Pixel: + case HLSLShaderAttr::Vertex: + case HLSLShaderAttr::Geometry: + case HLSLShaderAttr::Hull: + case HLSLShaderAttr::Domain: + case HLSLShaderAttr::RayGeneration: + case HLSLShaderAttr::Intersection: + case HLSLShaderAttr::AnyHit: + case HLSLShaderAttr::ClosestHit: + case HLSLShaderAttr::Miss: + case HLSLShaderAttr::Callable: + if (const auto *NT = FD->getAttr<HLSLNumThreadsAttr>()) { + DiagnoseAttrStageMismatch(NT, ST, + {HLSLShaderAttr::Compute, + HLSLShaderAttr::Amplification, + HLSLShaderAttr::Mesh}); + FD->setInvalidDecl(); + } + break; + + case HLSLShaderAttr::Compute: + case HLSLShaderAttr::Amplification: + case HLSLShaderAttr::Mesh: + if (!FD->hasAttr<HLSLNumThreadsAttr>()) { + Diag(FD->getLocation(), diag::err_hlsl_missing_numthreads) + << HLSLShaderAttr::ConvertShaderTypeToStr(ST); + FD->setInvalidDecl(); + } + break; + } + + for (ParmVarDecl *Param : FD->parameters()) { + if (const auto *AnnotationAttr = Param->getAttr<HLSLAnnotationAttr>()) { + CheckSemanticAnnotation(FD, Param, AnnotationAttr); + } else { + // FIXME: Handle struct parameters where annotations are on struct fields. + // See: https://github.com/llvm/llvm-project/issues/57875 + Diag(FD->getLocation(), diag::err_hlsl_missing_semantic_annotation); + Diag(Param->getLocation(), diag::note_previous_decl) << Param; + FD->setInvalidDecl(); + } + } + // FIXME: Verify return type semantic annotation. 
+} + +void SemaHLSL::CheckSemanticAnnotation( + FunctionDecl *EntryPoint, const Decl *Param, + const HLSLAnnotationAttr *AnnotationAttr) { + auto *ShaderAttr = EntryPoint->getAttr<HLSLShaderAttr>(); + assert(ShaderAttr && "Entry point has no shader attribute"); + HLSLShaderAttr::ShaderType ST = ShaderAttr->getType(); + + switch (AnnotationAttr->getKind()) { + case attr::HLSLSV_DispatchThreadID: + case attr::HLSLSV_GroupIndex: + if (ST == HLSLShaderAttr::Compute) + return; + DiagnoseAttrStageMismatch(AnnotationAttr, ST, {HLSLShaderAttr::Compute}); + break; + default: + llvm_unreachable("Unknown HLSLAnnotationAttr"); + } +} + +void SemaHLSL::DiagnoseAttrStageMismatch( + const Attr *A, HLSLShaderAttr::ShaderType Stage, + std::initializer_list<HLSLShaderAttr::ShaderType> AllowedStages) { + SmallVector<StringRef, 8> StageStrings; + llvm::transform(AllowedStages, std::back_inserter(StageStrings), + [](HLSLShaderAttr::ShaderType ST) { + return StringRef( + HLSLShaderAttr::ConvertShaderTypeToStr(ST)); + }); + Diag(A->getLoc(), diag::err_hlsl_attr_unsupported_in_stage) + << A << HLSLShaderAttr::ConvertShaderTypeToStr(Stage) + << (AllowedStages.size() != 1) << join(StageStrings, ", "); +} _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits