https://github.com/hekota updated https://github.com/llvm/llvm-project/pull/111203
>From f545a14e11556c91d10b14617e3588fe5eae6d42 Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Fri, 4 Oct 2024 12:21:51 -0700 Subject: [PATCH 1/4] [HLSL] Collect explicit resource binding information (part 1) - Do not create resource binding attribute if it is not valid - Store basic resource binding information on HLSLResourceBindingAttr - Move UDT type checking to to ActOnVariableDeclarator Part 1 of #110719 --- clang/include/clang/Basic/Attr.td | 29 +++ clang/include/clang/Sema/SemaHLSL.h | 2 + clang/lib/Sema/SemaDecl.cpp | 3 + clang/lib/Sema/SemaHLSL.cpp | 227 ++++++++++++------ .../resource_binding_attr_error_udt.hlsl | 8 +- 5 files changed, 188 insertions(+), 81 deletions(-) diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index fbcbf0ed416416..668c599da81390 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4588,6 +4588,35 @@ def HLSLResourceBinding: InheritableAttr { let LangOpts = [HLSL]; let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>]; let Documentation = [HLSLResourceBindingDocs]; + let AdditionalMembers = [{ + enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; + + const FieldDecl *ResourceField = nullptr; + RegisterType RegType; + unsigned SlotNumber; + unsigned SpaceNumber; + + void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) { + RegType = RT; + SlotNumber = SlotNum; + SpaceNumber = SpaceNum; + } + void setResourceField(const FieldDecl *FD) { + ResourceField = FD; + } + const FieldDecl *getResourceField() { + return ResourceField; + } + RegisterType getRegisterType() { + return RegType; + } + unsigned getSlotNumber() { + return SlotNumber; + } + unsigned getSpaceNumber() { + return SpaceNumber; + } + }]; } def HLSLPackOffset: HLSLAnnotationAttr { diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index fa957abc9791af..018e7ea5901a2b 100644 --- 
a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -28,6 +28,7 @@ class AttributeCommonInfo; class IdentifierInfo; class ParsedAttr; class Scope; +class VarDecl; // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no // longer need to create builtin buffer types in HLSLExternalSemaSource. @@ -62,6 +63,7 @@ class SemaHLSL : public SemaBase { const Attr *A, llvm::Triple::EnvironmentType Stage, std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages); void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU); + void ProcessResourceBindingOnDecl(VarDecl *D); QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 2bf610746bc317..8e27a5e068e702 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -7876,6 +7876,9 @@ NamedDecl *Sema::ActOnVariableDeclarator( // Handle attributes prior to checking for duplicates in MergeVarDecl ProcessDeclAttributes(S, NewVD, D); + if (getLangOpts().HLSL) + HLSL().ProcessResourceBindingOnDecl(NewVD); + // FIXME: This is probably the wrong location to be doing this and we should // probably be doing this for more attributes (especially for function // pointer attributes such as format, warn_unused_result, etc.). 
Ideally diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index fbcba201a351a6..568a8de30c1fc5 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -41,9 +41,7 @@ using namespace clang; using llvm::dxil::ResourceClass; - -enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; - +using RegisterType = HLSLResourceBindingAttr::RegisterType; static RegisterType getRegisterType(ResourceClass RC) { switch (RC) { case ResourceClass::SRV: @@ -985,44 +983,43 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { return LocInfo; } -// get the record decl from a var decl that we expect -// represents a resource -static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) { - const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); - assert(Ty && "Resource must have an element type."); - - if (Ty->isBuiltinType()) - return nullptr; - - CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl(); - assert(TheRecordDecl && "Resource should have a resource type declaration."); - return TheRecordDecl; -} - +// Returns handle type of a resource, if the VarDecl is a resource +// or an array of resources static const HLSLAttributedResourceType * -findAttributedResourceTypeOnField(VarDecl *VD) { +FindHandleTypeOnResource(const VarDecl *VD) { + // If VarDecl is a resource class, the first field must + // be the resource handle of type HLSLAttributedResourceType assert(VD != nullptr && "expected VarDecl"); - if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) { - for (auto *FD : RD->fields()) { - if (const HLSLAttributedResourceType *AttrResType = - dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr())) - return AttrResType; + const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); + if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) { + if (!RD->fields().empty()) { + const auto &FirstFD = RD->fields().begin(); + return dyn_cast<HLSLAttributedResourceType>( + 
FirstFD->getType().getTypePtr()); } } return nullptr; } -// Iterate over RecordType fields and return true if any of them matched the -// register type -static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, - RegisterType RegType) { +// Walks though the user defined record type, finds resource class +// that matches the RegisterBinding.Type and assigns it to +// RegisterBinding::Decl. +static bool +ProcessResourceBindingOnUserRecordDecl(const RecordType *RT, + HLSLResourceBindingAttr *RBA) { + llvm::SmallVector<const Type *> TypesToScan; TypesToScan.emplace_back(RT); + RegisterType RegType = RBA->getRegisterType(); while (!TypesToScan.empty()) { const Type *T = TypesToScan.pop_back_val(); - while (T->isArrayType()) + + while (T->isArrayType()) { + // FIXME: calculate the binding size from the array dimensions (or + // unbounded for unsized array) size *= (size_of_array); T = T->getArrayElementTypeNoTypeQual(); + } if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { if (RegType == RegisterType::C) return true; @@ -1037,8 +1034,12 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, if (const HLSLAttributedResourceType *AttrResType = dyn_cast<HLSLAttributedResourceType>(FieldTy)) { ResourceClass RC = AttrResType->getAttrs().ResourceClass; - if (getRegisterType(RC) == RegType) + if (getRegisterType(RC) == RegType) { + assert(RBA->getResourceField() == nullptr && + "multiple register bindings of the same type are not allowed"); + RBA->setResourceField(FD); return true; + } } else { TypesToScan.emplace_back(FD->getType().getTypePtr()); } @@ -1047,26 +1048,28 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, return false; } -static void CheckContainsResourceForRegisterType(Sema &S, - SourceLocation &ArgLoc, - Decl *D, RegisterType RegType, - bool SpecifiedSpace) { +// return false if the register binding is not valid +static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation 
&ArgLoc, + Decl *D, RegisterType RegType, + bool SpecifiedSpace) { int RegTypeNum = static_cast<int>(RegType); // check if the decl type is groupshared if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - return; + return false; } // Cbuffers and Tbuffers are HLSLBufferDecl types if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer : ResourceClass::SRV; - if (RegType != getRegisterType(RC)) - S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) - << RegTypeNum; - return; + if (RegType == getRegisterType(RC)) + return true; + + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return false; } // Samplers, UAVs, and SRVs are VarDecl types @@ -1075,11 +1078,13 @@ static void CheckContainsResourceForRegisterType(Sema &S, // Resource if (const HLSLAttributedResourceType *AttrResType = - findAttributedResourceTypeOnField(VD)) { - if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass)) - S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) - << RegTypeNum; - return; + FindHandleTypeOnResource(VD)) { + if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass)) + return true; + + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return false; } const clang::Type *Ty = VD->getType().getTypePtr(); @@ -1088,36 +1093,43 @@ static void CheckContainsResourceForRegisterType(Sema &S, // Basic types if (Ty->isArithmeticType()) { + bool IsValid = true; bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); - if (SpecifiedSpace && !DeclaredInCOrTBuffer) + if (SpecifiedSpace && !DeclaredInCOrTBuffer) { S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); + IsValid = false; + } if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) { // Default Globals if 
(RegType == RegisterType::CBuffer) S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); - else if (RegType != RegisterType::C) + else if (RegType != RegisterType::C) { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + IsValid = false; + } } else { if (RegType == RegisterType::C) S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); - else + else { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + IsValid = false; + } } - } else if (Ty->isRecordType()) { - // Class/struct types - walk the declaration and check each field and - // subclass - if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType)) - S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member) - << RegTypeNum; - } else { - // Anything else is an error - S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + return IsValid; } + if (Ty->isRecordType()) + // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl + // that is called from ActOnVariableDeclarator + return true; + + // Anything else is an error + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + return false; } -static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, +static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType regType) { // make sure that there are no two register annotations // applied to the decl with the same register type @@ -1135,21 +1147,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType otherRegType = getRegisterType(attr->getSlot()); if (RegisterTypesDetected[static_cast<int>(otherRegType)]) { - if (PreviousConflicts[TheDecl].count(otherRegType)) - continue; int otherRegTypeNum = static_cast<int>(otherRegType); S.Diag(TheDecl->getLocation(), diag::err_hlsl_duplicate_register_annotation) << otherRegTypeNum; - PreviousConflicts[TheDecl].insert(otherRegType); - } else { - 
RegisterTypesDetected[static_cast<int>(otherRegType)] = true; + return false; } + RegisterTypesDetected[static_cast<int>(otherRegType)] = true; } } + return true; } -static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, +static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace) { @@ -1159,10 +1169,11 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, "expecting VarDecl or HLSLBufferDecl"); // check if the declaration contains resource matching the register type - CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace); + if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace)) + return false; // next, if multiple register annotations exist, check that none conflict. - ValidateMultipleRegisterAnnotations(S, D, RegType); + return ValidateMultipleRegisterAnnotations(S, D, RegType); } void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { @@ -1203,23 +1214,24 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { Slot = Str; } - RegisterType regType; + RegisterType RegType; + unsigned SlotNum = 0; + unsigned SpaceNum = 0; // Validate. 
if (!Slot.empty()) { - regType = getRegisterType(Slot); - if (regType == RegisterType::I) { + RegType = getRegisterType(Slot); + if (RegType == RegisterType::I) { Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i); return; } - if (regType == RegisterType::Invalid) { + if (RegType == RegisterType::Invalid) { Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1); return; } - StringRef SlotNum = Slot.substr(1); - unsigned Num = 0; - if (SlotNum.getAsInteger(10, Num)) { + StringRef SlotNumStr = Slot.substr(1); + if (SlotNumStr.getAsInteger(10, SlotNum)) { Diag(ArgLoc, diag::err_hlsl_unsupported_register_number); return; } @@ -1229,20 +1241,22 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space; return; } - StringRef SpaceNum = Space.substr(5); - unsigned Num = 0; - if (SpaceNum.getAsInteger(10, Num)) { + StringRef SpaceNumStr = Space.substr(5); + if (SpaceNumStr.getAsInteger(10, SpaceNum)) { Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space; return; } - DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, regType, - SpecifiedSpace); + if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType, + SpecifiedSpace)) + return; HLSLResourceBindingAttr *NewAttr = HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL); - if (NewAttr) + if (NewAttr) { + NewAttr->setBinding(RegType, SlotNum, SpaceNum); TheDecl->addAttr(NewAttr); + } } void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) { @@ -2228,3 +2242,62 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) { Ty.addRestrict(); return Ty; } + +// Walks though existing explicit bindings, finds the actual resource class +// decl the binding applies to and sets it to attr->ResourceField. +// Additional processing of resource binding can be added here later on, +// such as preparation for overapping resource detection or implicit binding. 
+void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) { + if (!D->hasGlobalStorage()) + return; + + for (Attr *A : D->attrs()) { + HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); + if (!RBA) + continue; + + // // Cbuffers and Tbuffers are HLSLBufferDecl types + if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { + assert(RBA->getRegisterType() == + getRegisterType(CBufferOrTBuffer->isCBuffer() + ? ResourceClass::CBuffer + : ResourceClass::SRV) && + "this should have been handled in DiagnoseLocalRegisterBinding"); + // should we handle HLSLBufferDecl here? + continue; + } + + // Samplers, UAVs, and SRVs are VarDecl types + assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl"); + const VarDecl *VD = cast<VarDecl>(D); + + // Register binding directly on global resource class variable + if (const HLSLAttributedResourceType *AttrResType = + FindHandleTypeOnResource(VD)) { + // FIXME: if array, calculate the binding size from the array dimensions + // (or unbounded for unsized array) + assert(RBA->getResourceField() == nullptr); + continue; + } + + // Global array + const clang::Type *Ty = VD->getType().getTypePtr(); + while (Ty->isArrayType()) { + Ty = Ty->getArrayElementTypeNoTypeQual(); + } + + // Basic types + if (Ty->isArithmeticType()) { + continue; + } + + if (Ty->isRecordType()) { + if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(), + RBA)) { + SemaRef.Diag(D->getLocation(), + diag::warn_hlsl_user_defined_type_missing_member) + << static_cast<int>(RBA->getRegisterType()); + } + } + } +} diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl index ea2d576e4cca55..40517f393e1284 100644 --- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl +++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl @@ -106,7 +106,6 @@ struct Eg12{ MySRV s1; MySRV s2; }; -// expected-warning@+3{{binding type 'u' 
only applies to types containing UAV resources}} // expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}} // expected-error@+1{{binding type 'u' cannot be applied more than once}} Eg12 e12 : register(u9) : register(u10); @@ -115,12 +114,14 @@ struct Eg13{ MySRV s1; MySRV s2; }; -// expected-warning@+4{{binding type 'u' only applies to types containing UAV resources}} // expected-warning@+3{{binding type 'u' only applies to types containing UAV resources}} -// expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}} +// expected-error@+2{{binding type 'u' cannot be applied more than once}} // expected-error@+1{{binding type 'u' cannot be applied more than once}} Eg13 e13 : register(u9) : register(u10) : register(u11); +// expected-error@+1{{binding type 't' cannot be applied more than once}} +Eg13 e13_2 : register(t11) : register(t12); + struct Eg14{ MyTemplatedUAV<int> r1; }; @@ -132,4 +133,3 @@ struct Eg15 { }; // expected no error Eg15 e15 : register(c0); - >From a6c06943ce5df79e6765e12874c96c907b20d030 Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Fri, 4 Oct 2024 13:52:47 -0700 Subject: [PATCH 2/4] clang-format --- clang/lib/Sema/SemaHLSL.cpp | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 568a8de30c1fc5..5c27a74a853bba 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -2250,7 +2250,7 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) { void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) { if (!D->hasGlobalStorage()) return; - + for (Attr *A : D->attrs()) { HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); if (!RBA) >From a6a52327bef4325a00a2b8a1715b8b5b1315994f Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Wed, 9 Oct 2024 16:34:06 -0700 Subject: [PATCH 3/4] Collect all resource binding requirements and 
analyze explicit bindings based on that Also adds bindings size calculation and removed ResourceDecl field from HLSLResourceBindingAttr. --- clang/include/clang/Basic/Attr.td | 25 ++- clang/include/clang/Sema/SemaHLSL.h | 59 +++++- clang/lib/Sema/SemaDecl.cpp | 2 +- clang/lib/Sema/SemaHLSL.cpp | 276 ++++++++++++++++++---------- 4 files changed, 256 insertions(+), 106 deletions(-) diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index 668c599da81390..3997ffe78fbf96 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4591,22 +4591,20 @@ def HLSLResourceBinding: InheritableAttr { let AdditionalMembers = [{ enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; - const FieldDecl *ResourceField = nullptr; RegisterType RegType; unsigned SlotNumber; unsigned SpaceNumber; + + // Size of the binding + // 0 == not set + //-1 == unbounded + int Size; - void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) { + void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum, int Size = 0) { RegType = RT; SlotNumber = SlotNum; SpaceNumber = SpaceNum; } - void setResourceField(const FieldDecl *FD) { - ResourceField = FD; - } - const FieldDecl *getResourceField() { - return ResourceField; - } RegisterType getRegisterType() { return RegType; } @@ -4616,6 +4614,17 @@ def HLSLResourceBinding: InheritableAttr { unsigned getSpaceNumber() { return SpaceNumber; } + unsigned getSize() { + assert(Size == -1 || Size > 0 && "size not set"); + return Size; + } + void setSize(int N) { + assert(N == -1 || N > 0 && "unexpected size value"); + Size = N; + } + bool isSizeUnbounded() { + return Size == -1; + } }]; } diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index 018e7ea5901a2b..ce262fd41dff37 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -30,12 +30,60 @@ class ParsedAttr; class Scope; 
class VarDecl; +using llvm::dxil::ResourceClass; + // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no // longer need to create builtin buffer types in HLSLExternalSemaSource. bool CreateHLSLAttributedResourceType( Sema &S, QualType Wrapped, ArrayRef<const Attr *> AttrList, QualType &ResType, HLSLAttributedResourceLocInfo *LocInfo = nullptr); +enum class BindingType : uint8_t { NotAssigned, Explicit, Implicit }; + +// DeclBindingInfo struct stores information about required/assigned resource +// binding onon a declaration for specific resource class. +struct DeclBindingInfo { + const VarDecl *Decl; + ResourceClass ResClass; + int Size; // -1 == unbounded array + const HLSLResourceBindingAttr *Attr; + BindingType BindType; + + DeclBindingInfo(const VarDecl *Decl, ResourceClass ResClass, int Size = 0, + BindingType BindType = BindingType::NotAssigned, + const HLSLResourceBindingAttr *Attr = nullptr) + : Decl(Decl), ResClass(ResClass), Size(Size), Attr(Attr), + BindType(BindType) {} + + void setBindingAttribute(HLSLResourceBindingAttr *A, BindingType BT) { + assert(Attr == nullptr && BindType == BindingType::NotAssigned && + "binding attribute already assigned"); + Attr = A; + BindType = BT; + } +}; + +// ResourceBindings class stores information about all resource bindings +// in a shader. It is used for binding diagnostics and implicit binding +// assigments. +class ResourceBindings { +public: + DeclBindingInfo *addDeclBindingInfo(const VarDecl *VD, ResourceClass ResClass, + int Size); + DeclBindingInfo *getDeclBindingInfo(const VarDecl *VD, + ResourceClass ResClass); + bool hasBindingInfoForDecl(const VarDecl *VD); + +private: + // List of all resource bindings required by the shader. + // A global declaration can have multiple bindings for different + // resource classes. They are all stored sequentially in this list. + // The DeclToBindingListIndex hashtable maps a declaration to the + // index of the first binding info in the list. 
+ llvm::SmallVector<DeclBindingInfo> BindingsList; + llvm::DenseMap<const VarDecl *, unsigned> DeclToBindingListIndex; +}; + class SemaHLSL : public SemaBase { public: SemaHLSL(Sema &S); @@ -56,6 +104,7 @@ class SemaHLSL : public SemaBase { mergeParamModifierAttr(Decl *D, const AttributeCommonInfo &AL, HLSLParamModifierAttr::Spelling Spelling); void ActOnTopLevelFunction(FunctionDecl *FD); + void ActOnVariableDeclarator(VarDecl *VD); void CheckEntryPoint(FunctionDecl *FD); void CheckSemanticAnnotation(FunctionDecl *EntryPoint, const Decl *Param, const HLSLAnnotationAttr *AnnotationAttr); @@ -63,7 +112,6 @@ class SemaHLSL : public SemaBase { const Attr *A, llvm::Triple::EnvironmentType Stage, std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages); void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU); - void ProcessResourceBindingOnDecl(VarDecl *D); QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, @@ -104,6 +152,15 @@ class SemaHLSL : public SemaBase { llvm::DenseMap<const HLSLAttributedResourceType *, HLSLAttributedResourceLocInfo> LocsForHLSLAttributedResources; + + // List of all resource bindings + ResourceBindings Bindings; + +private: + void FindResourcesOnVarDecl(VarDecl *D); + void FindResourcesOnUserRecordDecl(const VarDecl *VD, const RecordType *RT, + int Size); + void ProcessExplicitBindingsOnDecl(VarDecl *D); }; } // namespace clang diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 8e27a5e068e702..770d00710a6816 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -7877,7 +7877,7 @@ NamedDecl *Sema::ActOnVariableDeclarator( ProcessDeclAttributes(S, NewVD, D); if (getLangOpts().HLSL) - HLSL().ProcessResourceBindingOnDecl(NewVD); + HLSL().ActOnVariableDeclarator(NewVD); // FIXME: This is probably the wrong location to be doing this and we should // probably be doing this for more attributes (especially for function diff --git 
a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 5c27a74a853bba..197ee63c07deeb 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -40,8 +40,8 @@ #include <utility> using namespace clang; -using llvm::dxil::ResourceClass; using RegisterType = HLSLResourceBindingAttr::RegisterType; + static RegisterType getRegisterType(ResourceClass RC) { switch (RC) { case ResourceClass::SRV: @@ -81,6 +81,49 @@ static RegisterType getRegisterType(StringRef Slot) { } } +static ResourceClass getResourceClass(RegisterType RT) { + switch (RT) { + case RegisterType::SRV: + return ResourceClass::SRV; + case RegisterType::UAV: + return ResourceClass::UAV; + case RegisterType::CBuffer: + return ResourceClass::CBuffer; + case RegisterType::Sampler: + return ResourceClass::Sampler; + default: + llvm_unreachable("unexpected RegisterType value"); + } +} + +DeclBindingInfo *ResourceBindings::addDeclBindingInfo(const VarDecl *VD, + ResourceClass ResClass, + int Size) { + assert(getDeclBindingInfo(VD, ResClass) == nullptr && + "DeclBindingInfo already added"); + if (DeclToBindingListIndex.find(VD) == DeclToBindingListIndex.end()) + DeclToBindingListIndex[VD] = BindingsList.size(); + return &BindingsList.emplace_back(DeclBindingInfo(VD, ResClass, Size)); +} + +DeclBindingInfo *ResourceBindings::getDeclBindingInfo(const VarDecl *VD, + ResourceClass ResClass) { + auto Entry = DeclToBindingListIndex.find(VD); + if (Entry != DeclToBindingListIndex.end()) { + unsigned Index = Entry->getSecond(); + while (Index < BindingsList.size() && BindingsList[Index].Decl == VD) { + if (BindingsList[Index].ResClass == ResClass) + return &BindingsList[Index]; + Index++; + } + } + return nullptr; +} + +bool ResourceBindings::hasBindingInfoForDecl(const VarDecl *VD) { + return DeclToBindingListIndex.contains(VD); +} + SemaHLSL::SemaHLSL(Sema &S) : SemaBase(S) {} Decl *SemaHLSL::ActOnStartBuffer(Scope *BufferScope, bool CBuffer, @@ -983,14 +1026,11 @@ 
SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { return LocInfo; } -// Returns handle type of a resource, if the VarDecl is a resource -// or an array of resources +// Returns handle type of a resource, if the type is a resource static const HLSLAttributedResourceType * -FindHandleTypeOnResource(const VarDecl *VD) { - // If VarDecl is a resource class, the first field must +FindHandleTypeOnResource(const Type *Ty) { + // If Ty is a resource class, the first field must // be the resource handle of type HLSLAttributedResourceType - assert(VD != nullptr && "expected VarDecl"); - const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) { if (!RD->fields().empty()) { const auto &FirstFD = RD->fields().begin(); @@ -1001,51 +1041,53 @@ FindHandleTypeOnResource(const VarDecl *VD) { return nullptr; } -// Walks though the user defined record type, finds resource class -// that matches the RegisterBinding.Type and assigns it to -// RegisterBinding::Decl. 
-static bool -ProcessResourceBindingOnUserRecordDecl(const RecordType *RT, - HLSLResourceBindingAttr *RBA) { - - llvm::SmallVector<const Type *> TypesToScan; - TypesToScan.emplace_back(RT); - RegisterType RegType = RBA->getRegisterType(); - - while (!TypesToScan.empty()) { - const Type *T = TypesToScan.pop_back_val(); - - while (T->isArrayType()) { - // FIXME: calculate the binding size from the array dimensions (or - // unbounded for unsized array) size *= (size_of_array); - T = T->getArrayElementTypeNoTypeQual(); - } - if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { - if (RegType == RegisterType::C) - return true; +// Returns handle type of a resource, if the VarDecl is a resource +static const HLSLAttributedResourceType * +FindHandleTypeOnResource(const VarDecl *VD) { + assert(VD != nullptr && "expected VarDecl"); + return FindHandleTypeOnResource(VD->getType().getTypePtr()); +} + +// Walks though the global variable declaration, collects all resource binding +// requirements and adds them to Bindings +void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD, + const RecordType *RT, int Size) { + const RecordDecl *RD = RT->getDecl(); + for (FieldDecl *FD : RD->fields()) { + const Type *Ty = FD->getType()->getUnqualifiedDesugaredType(); + + // Calculate array size and unwrap + int ArraySize = 1; + assert(!Ty->isIncompleteArrayType() && + "incomplete arrays inside user defined types are not supported"); + while (Ty->isConstantArrayType()) { + const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); + ArraySize *= CAT->getSize().getSExtValue(); + Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); } - const RecordType *RT = T->getAs<RecordType>(); - if (!RT) + + if (!Ty->isRecordType()) continue; - const RecordDecl *RD = RT->getDecl(); - for (FieldDecl *FD : RD->fields()) { - const Type *FieldTy = FD->getType().getTypePtr(); - if (const HLSLAttributedResourceType *AttrResType = - dyn_cast<HLSLAttributedResourceType>(FieldTy)) { - 
ResourceClass RC = AttrResType->getAttrs().ResourceClass; - if (getRegisterType(RC) == RegType) { - assert(RBA->getResourceField() == nullptr && - "multiple register bindings of the same type are not allowed"); - RBA->setResourceField(FD); - return true; - } - } else { - TypesToScan.emplace_back(FD->getType().getTypePtr()); - } + // Field is a resource or array of resources + if (const HLSLAttributedResourceType *AttrResType = + FindHandleTypeOnResource(Ty)) { + ResourceClass RC = AttrResType->getAttrs().ResourceClass; + + // Add a new DeclBindingInfo to Bindings. Update the binding size if + // a binding info already exists (there are multiple resources of same + // resource class in this user decl) + if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC)) + DBI->Size += Size * ArraySize; + else + Bindings.addDeclBindingInfo(VD, RC, Size); + } else if (const RecordType *RT = dyn_cast<RecordType>(Ty)) { + // Recursively scan embedded struct or class; it would be nice to do this + // without recursion, but tricky to corrently calculate the size. + // Hopefully nesting of structs in structs too many levels is unlikely. 
+ FindResourcesOnUserRecordDecl(VD, RT, Size); } } - return false; } // return false if the register binding is not valid @@ -1093,11 +1135,9 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, // Basic types if (Ty->isArithmeticType()) { - bool IsValid = true; bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); if (SpecifiedSpace && !DeclaredInCOrTBuffer) { S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); - IsValid = false; } if (!DeclaredInCOrTBuffer && @@ -1107,17 +1147,15 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); else if (RegType != RegisterType::C) { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - IsValid = false; } } else { if (RegType == RegisterType::C) S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); else { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - IsValid = false; } } - return IsValid; + return false; } if (Ty->isRecordType()) // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl @@ -2057,6 +2095,7 @@ bool SemaHLSL::IsIntangibleType(clang::QualType QT) { CXXRecordDecl *RD = RT->getAsCXXRecordDecl(); assert(RD != nullptr && "all HLSL struct and classes should be CXXRecordDecl"); + assert(RD->isCompleteDefinition() && "expecting complete type"); return RD->isHLSLIntangible(); } @@ -2243,61 +2282,106 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) { return Ty; } -// Walks though existing explicit bindings, finds the actual resource class -// decl the binding applies to and sets it to attr->ResourceField. -// Additional processing of resource binding can be added here later on, -// such as preparation for overapping resource detection or implicit binding. 
-void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) { - if (!D->hasGlobalStorage()) +void SemaHLSL::ActOnVariableDeclarator(VarDecl *VD) { + if (VD->hasGlobalStorage()) { + // make sure the declaration has a complete type + if (SemaRef.RequireCompleteType( + VD->getLocation(), + SemaRef.getASTContext().getBaseElementType(VD->getType()), + diag::err_typecheck_decl_incomplete_type)) { + VD->setInvalidDecl(); + return; + } + + // find all resources on decl + if (IsIntangibleType(VD->getType())) + FindResourcesOnVarDecl(VD); + + // process explicit bindings + ProcessExplicitBindingsOnDecl(VD); + } +} + +// Walks though the global variable declaration, collects all resource binding +// requirements and adds them to Bindings +void SemaHLSL::FindResourcesOnVarDecl(VarDecl *VD) { + assert(VD->hasGlobalStorage() && IsIntangibleType(VD->getType()) && + "expected global variable that contains HLSL resource"); + + // Cbuffers and Tbuffers are HLSLBufferDecl types + if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(VD)) { + Bindings.addDeclBindingInfo(VD, + CBufferOrTBuffer->isCBuffer() + ? ResourceClass::CBuffer + : ResourceClass::SRV, + 1); return; + } - for (Attr *A : D->attrs()) { - HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); - if (!RBA) - continue; + // Calculate size of array and unwrap + int Size = 1; + const Type *Ty = VD->getType()->getUnqualifiedDesugaredType(); + if (Ty->isIncompleteArrayType()) + Size = -1; + while (Ty->isConstantArrayType()) { + const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); + Size *= CAT->getSize().getSExtValue(); + Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); + } - // // Cbuffers and Tbuffers are HLSLBufferDecl types - if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { - assert(RBA->getRegisterType() == - getRegisterType(CBufferOrTBuffer->isCBuffer() - ? 
ResourceClass::CBuffer - : ResourceClass::SRV) && - "this should have been handled in DiagnoseLocalRegisterBinding"); - // should we handle HLSLBufferDecl here? - continue; - } + // Resource (or array of resources) + if (const HLSLAttributedResourceType *AttrResType = + FindHandleTypeOnResource(Ty)) { + Bindings.addDeclBindingInfo(VD, AttrResType->getAttrs().ResourceClass, + Size); + return; + } - // Samplers, UAVs, and SRVs are VarDecl types - assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl"); - const VarDecl *VD = cast<VarDecl>(D); + assert(Size != -1 && + "unbounded arrays of user defined types are not supported"); - // Register binding directly on global resource class variable - if (const HLSLAttributedResourceType *AttrResType = - FindHandleTypeOnResource(VD)) { - // FIXME: if array, calculate the binding size from the array dimensions - // (or unbounded for unsized array) - assert(RBA->getResourceField() == nullptr); + // User defined record type + if (const RecordType *RT = dyn_cast<RecordType>(Ty)) + FindResourcesOnUserRecordDecl(VD, RT, Size); +} + +// Walks through the explicit resource binding attributes on the declaration, +// and makes sure there is a resource that matches the binding and updates +// DeclBindingInfoLists +void SemaHLSL::ProcessExplicitBindingsOnDecl(VarDecl *VD) { + assert(VD->hasGlobalStorage() && "expected global variable"); + + for (Attr *A : VD->attrs()) { + HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); + if (!RBA) continue; - } - // Global array - const clang::Type *Ty = VD->getType().getTypePtr(); - while (Ty->isArrayType()) { - Ty = Ty->getArrayElementTypeNoTypeQual(); - } + RegisterType RT = RBA->getRegisterType(); + assert(RT != RegisterType::I && RT != RegisterType::Invalid && + "invalid or obsolete register type should never have an attribute " + "created"); - // Basic types - if (Ty->isArithmeticType()) { + // These were already diagnosed earlier + if (RT == RegisterType::C) {
+ if (Bindings.hasBindingInfoForDecl(VD)) + SemaRef.Diag(VD->getLocation(), + diag::warn_hlsl_user_defined_type_missing_member) + << static_cast<int>(RT); continue; } - if (Ty->isRecordType()) { - if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(), - RBA)) { - SemaRef.Diag(D->getLocation(), - diag::warn_hlsl_user_defined_type_missing_member) - << static_cast<int>(RBA->getRegisterType()); - } + // Find DeclBindingInfo for this binding and update it, or report error + // if it does not exist (user type does not contain resources with the + // expected resource class). + ResourceClass RC = getResourceClass(RT); + if (DeclBindingInfo *BI = Bindings.getDeclBindingInfo(VD, RC)) { + // update binding info + RBA->setSize(BI->Size); + BI->setBindingAttribute(RBA, BindingType::Explicit); + } else { + SemaRef.Diag(VD->getLocation(), + diag::warn_hlsl_user_defined_type_missing_member) + << static_cast<int>(RT); } } } >From aa6247f414b2bd3d39f349646f3a97ec72d5d517 Mon Sep 17 00:00:00 2001 From: Helena Kotas <heko...@microsoft.com> Date: Wed, 9 Oct 2024 17:08:25 -0700 Subject: [PATCH 4/4] removed unused variable, cleanup --- clang/lib/Sema/SemaHLSL.cpp | 14 +++----------- 1 file changed, 3 insertions(+), 11 deletions(-) diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 197ee63c07deeb..0423340ee5fc4f 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -1136,24 +1136,21 @@ static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, // Basic types if (Ty->isArithmeticType()) { bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); - if (SpecifiedSpace && !DeclaredInCOrTBuffer) { + if (SpecifiedSpace && !DeclaredInCOrTBuffer) S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); - } if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) { // Default Globals if (RegType == RegisterType::CBuffer) S.Diag(ArgLoc,
diag::warn_hlsl_deprecated_register_type_b); - else if (RegType != RegisterType::C) { + else if (RegType != RegisterType::C) S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - } } else { if (RegType == RegisterType::C) S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); - else { + else S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - } } return false; } @@ -1172,13 +1169,8 @@ static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, // make sure that there are no two register annotations // applied to the decl with the same register type bool RegisterTypesDetected[5] = {false}; - RegisterTypesDetected[static_cast<int>(regType)] = true; - // we need a static map to keep track of previous conflicts - // so that we don't emit the same error multiple times - static std::map<Decl *, std::set<RegisterType>> PreviousConflicts; - for (auto it = TheDecl->attr_begin(); it != TheDecl->attr_end(); ++it) { if (HLSLResourceBindingAttr *attr = dyn_cast<HLSLResourceBindingAttr>(*it)) { _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits