llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Helena Kotas (hekota) <details> <summary>Changes</summary> Adds fields to `HLSLResourceBindingAttr` to store processed binding information. This will be used by CodeGen or Sema for resource initialization or for diagnosing overlapping bindings. Moves binding checks for user-defined types (UDTs) to `ProcessResourceBindingOnDecl` (called from ActOnVariableDeclarator), which updates the information in the attribute and where additional processing of the explicit resource binding will be added in the future. Changed `handleResourceBindingAttr` to not create the resource binding attribute if the local binding diagnostic detects errors. Part 1 of #<!-- -->110719 --- Full diff: https://github.com/llvm/llvm-project/pull/111203.diff 5 Files Affected: - (modified) clang/include/clang/Basic/Attr.td (+29) - (modified) clang/include/clang/Sema/SemaHLSL.h (+2) - (modified) clang/lib/Sema/SemaDecl.cpp (+3) - (modified) clang/lib/Sema/SemaHLSL.cpp (+150-77) - (modified) clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl (+4-4) ``````````diff diff --git a/clang/include/clang/Basic/Attr.td b/clang/include/clang/Basic/Attr.td index fbcbf0ed416416..668c599da81390 100644 --- a/clang/include/clang/Basic/Attr.td +++ b/clang/include/clang/Basic/Attr.td @@ -4588,6 +4588,35 @@ def HLSLResourceBinding: InheritableAttr { let LangOpts = [HLSL]; let Args = [StringArgument<"Slot">, StringArgument<"Space", 1>]; let Documentation = [HLSLResourceBindingDocs]; + let AdditionalMembers = [{ + enum class RegisterType : unsigned { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; + + const FieldDecl *ResourceField = nullptr; + RegisterType RegType; + unsigned SlotNumber; + unsigned SpaceNumber; + + void setBinding(RegisterType RT, unsigned SlotNum, unsigned SpaceNum) { + RegType = RT; + SlotNumber = SlotNum; + SpaceNumber = SpaceNum; + } + void setResourceField(const FieldDecl *FD) { + ResourceField = FD; + } + const FieldDecl 
*getResourceField() { + return ResourceField; + } + RegisterType getRegisterType() { + return RegType; + } + unsigned getSlotNumber() { + return SlotNumber; + } + unsigned getSpaceNumber() { + return SpaceNumber; + } + }]; } def HLSLPackOffset: HLSLAnnotationAttr { diff --git a/clang/include/clang/Sema/SemaHLSL.h b/clang/include/clang/Sema/SemaHLSL.h index fa957abc9791af..018e7ea5901a2b 100644 --- a/clang/include/clang/Sema/SemaHLSL.h +++ b/clang/include/clang/Sema/SemaHLSL.h @@ -28,6 +28,7 @@ class AttributeCommonInfo; class IdentifierInfo; class ParsedAttr; class Scope; +class VarDecl; // FIXME: This can be hidden (as static function in SemaHLSL.cpp) once we no // longer need to create builtin buffer types in HLSLExternalSemaSource. @@ -62,6 +63,7 @@ class SemaHLSL : public SemaBase { const Attr *A, llvm::Triple::EnvironmentType Stage, std::initializer_list<llvm::Triple::EnvironmentType> AllowedStages); void DiagnoseAvailabilityViolations(TranslationUnitDecl *TU); + void ProcessResourceBindingOnDecl(VarDecl *D); QualType handleVectorBinOpConversion(ExprResult &LHS, ExprResult &RHS, QualType LHSType, QualType RHSType, diff --git a/clang/lib/Sema/SemaDecl.cpp b/clang/lib/Sema/SemaDecl.cpp index 2bf610746bc317..8e27a5e068e702 100644 --- a/clang/lib/Sema/SemaDecl.cpp +++ b/clang/lib/Sema/SemaDecl.cpp @@ -7876,6 +7876,9 @@ NamedDecl *Sema::ActOnVariableDeclarator( // Handle attributes prior to checking for duplicates in MergeVarDecl ProcessDeclAttributes(S, NewVD, D); + if (getLangOpts().HLSL) + HLSL().ProcessResourceBindingOnDecl(NewVD); + // FIXME: This is probably the wrong location to be doing this and we should // probably be doing this for more attributes (especially for function // pointer attributes such as format, warn_unused_result, etc.). 
Ideally diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index fbcba201a351a6..568a8de30c1fc5 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -41,9 +41,7 @@ using namespace clang; using llvm::dxil::ResourceClass; - -enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; - +using RegisterType = HLSLResourceBindingAttr::RegisterType; static RegisterType getRegisterType(ResourceClass RC) { switch (RC) { case ResourceClass::SRV: @@ -985,44 +983,43 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { return LocInfo; } -// get the record decl from a var decl that we expect -// represents a resource -static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) { - const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); - assert(Ty && "Resource must have an element type."); - - if (Ty->isBuiltinType()) - return nullptr; - - CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl(); - assert(TheRecordDecl && "Resource should have a resource type declaration."); - return TheRecordDecl; -} - +// Returns handle type of a resource, if the VarDecl is a resource +// or an array of resources static const HLSLAttributedResourceType * -findAttributedResourceTypeOnField(VarDecl *VD) { +FindHandleTypeOnResource(const VarDecl *VD) { + // If VarDecl is a resource class, the first field must + // be the resource handle of type HLSLAttributedResourceType assert(VD != nullptr && "expected VarDecl"); - if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) { - for (auto *FD : RD->fields()) { - if (const HLSLAttributedResourceType *AttrResType = - dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr())) - return AttrResType; + const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); + if (RecordDecl *RD = Ty->getAsCXXRecordDecl()) { + if (!RD->fields().empty()) { + const auto &FirstFD = RD->fields().begin(); + return dyn_cast<HLSLAttributedResourceType>( + 
FirstFD->getType().getTypePtr()); } } return nullptr; } -// Iterate over RecordType fields and return true if any of them matched the -// register type -static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, - RegisterType RegType) { +// Walks though the user defined record type, finds resource class +// that matches the RegisterBinding.Type and assigns it to +// RegisterBinding::Decl. +static bool +ProcessResourceBindingOnUserRecordDecl(const RecordType *RT, + HLSLResourceBindingAttr *RBA) { + llvm::SmallVector<const Type *> TypesToScan; TypesToScan.emplace_back(RT); + RegisterType RegType = RBA->getRegisterType(); while (!TypesToScan.empty()) { const Type *T = TypesToScan.pop_back_val(); - while (T->isArrayType()) + + while (T->isArrayType()) { + // FIXME: calculate the binding size from the array dimensions (or + // unbounded for unsized array) size *= (size_of_array); T = T->getArrayElementTypeNoTypeQual(); + } if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { if (RegType == RegisterType::C) return true; @@ -1037,8 +1034,12 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, if (const HLSLAttributedResourceType *AttrResType = dyn_cast<HLSLAttributedResourceType>(FieldTy)) { ResourceClass RC = AttrResType->getAttrs().ResourceClass; - if (getRegisterType(RC) == RegType) + if (getRegisterType(RC) == RegType) { + assert(RBA->getResourceField() == nullptr && + "multiple register bindings of the same type are not allowed"); + RBA->setResourceField(FD); return true; + } } else { TypesToScan.emplace_back(FD->getType().getTypePtr()); } @@ -1047,26 +1048,28 @@ static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, return false; } -static void CheckContainsResourceForRegisterType(Sema &S, - SourceLocation &ArgLoc, - Decl *D, RegisterType RegType, - bool SpecifiedSpace) { +// return false if the register binding is not valid +static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation 
&ArgLoc, + Decl *D, RegisterType RegType, + bool SpecifiedSpace) { int RegTypeNum = static_cast<int>(RegType); // check if the decl type is groupshared if (D->hasAttr<HLSLGroupSharedAddressSpaceAttr>()) { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; - return; + return false; } // Cbuffers and Tbuffers are HLSLBufferDecl types if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { ResourceClass RC = CBufferOrTBuffer->isCBuffer() ? ResourceClass::CBuffer : ResourceClass::SRV; - if (RegType != getRegisterType(RC)) - S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) - << RegTypeNum; - return; + if (RegType == getRegisterType(RC)) + return true; + + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return false; } // Samplers, UAVs, and SRVs are VarDecl types @@ -1075,11 +1078,13 @@ static void CheckContainsResourceForRegisterType(Sema &S, // Resource if (const HLSLAttributedResourceType *AttrResType = - findAttributedResourceTypeOnField(VD)) { - if (RegType != getRegisterType(AttrResType->getAttrs().ResourceClass)) - S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) - << RegTypeNum; - return; + FindHandleTypeOnResource(VD)) { + if (RegType == getRegisterType(AttrResType->getAttrs().ResourceClass)) + return true; + + S.Diag(D->getLocation(), diag::err_hlsl_binding_type_mismatch) + << RegTypeNum; + return false; } const clang::Type *Ty = VD->getType().getTypePtr(); @@ -1088,36 +1093,43 @@ static void CheckContainsResourceForRegisterType(Sema &S, // Basic types if (Ty->isArithmeticType()) { + bool IsValid = true; bool DeclaredInCOrTBuffer = isa<HLSLBufferDecl>(D->getDeclContext()); - if (SpecifiedSpace && !DeclaredInCOrTBuffer) + if (SpecifiedSpace && !DeclaredInCOrTBuffer) { S.Diag(ArgLoc, diag::err_hlsl_space_on_global_constant); + IsValid = false; + } if (!DeclaredInCOrTBuffer && (Ty->isIntegralType(S.getASTContext()) || Ty->isFloatingType())) { // Default Globals if 
(RegType == RegisterType::CBuffer) S.Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_b); - else if (RegType != RegisterType::C) + else if (RegType != RegisterType::C) { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + IsValid = false; + } } else { if (RegType == RegisterType::C) S.Diag(ArgLoc, diag::warn_hlsl_register_type_c_packoffset); - else + else { S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + IsValid = false; + } } - } else if (Ty->isRecordType()) { - // Class/struct types - walk the declaration and check each field and - // subclass - if (!ContainsResourceForRegisterType(S, Ty->getAs<RecordType>(), RegType)) - S.Diag(D->getLocation(), diag::warn_hlsl_user_defined_type_missing_member) - << RegTypeNum; - } else { - // Anything else is an error - S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + return IsValid; } + if (Ty->isRecordType()) + // RecordTypes will be diagnosed in ProcessResourceBindingOnDecl + // that is called from ActOnVariableDeclarator + return true; + + // Anything else is an error + S.Diag(ArgLoc, diag::err_hlsl_binding_type_mismatch) << RegTypeNum; + return false; } -static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, +static bool ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType regType) { // make sure that there are no two register annotations // applied to the decl with the same register type @@ -1135,21 +1147,19 @@ static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *TheDecl, RegisterType otherRegType = getRegisterType(attr->getSlot()); if (RegisterTypesDetected[static_cast<int>(otherRegType)]) { - if (PreviousConflicts[TheDecl].count(otherRegType)) - continue; int otherRegTypeNum = static_cast<int>(otherRegType); S.Diag(TheDecl->getLocation(), diag::err_hlsl_duplicate_register_annotation) << otherRegTypeNum; - PreviousConflicts[TheDecl].insert(otherRegType); - } else { - 
RegisterTypesDetected[static_cast<int>(otherRegType)] = true; + return false; } + RegisterTypesDetected[static_cast<int>(otherRegType)] = true; } } + return true; } -static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, +static bool DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, Decl *D, RegisterType RegType, bool SpecifiedSpace) { @@ -1159,10 +1169,11 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, "expecting VarDecl or HLSLBufferDecl"); // check if the declaration contains resource matching the register type - CheckContainsResourceForRegisterType(S, ArgLoc, D, RegType, SpecifiedSpace); + if (!DiagnoseLocalRegisterBinding(S, ArgLoc, D, RegType, SpecifiedSpace)) + return false; // next, if multiple register annotations exist, check that none conflict. - ValidateMultipleRegisterAnnotations(S, D, RegType); + return ValidateMultipleRegisterAnnotations(S, D, RegType); } void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { @@ -1203,23 +1214,24 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { Slot = Str; } - RegisterType regType; + RegisterType RegType; + unsigned SlotNum = 0; + unsigned SpaceNum = 0; // Validate. 
if (!Slot.empty()) { - regType = getRegisterType(Slot); - if (regType == RegisterType::I) { + RegType = getRegisterType(Slot); + if (RegType == RegisterType::I) { Diag(ArgLoc, diag::warn_hlsl_deprecated_register_type_i); return; } - if (regType == RegisterType::Invalid) { + if (RegType == RegisterType::Invalid) { Diag(ArgLoc, diag::err_hlsl_binding_type_invalid) << Slot.substr(0, 1); return; } - StringRef SlotNum = Slot.substr(1); - unsigned Num = 0; - if (SlotNum.getAsInteger(10, Num)) { + StringRef SlotNumStr = Slot.substr(1); + if (SlotNumStr.getAsInteger(10, SlotNum)) { Diag(ArgLoc, diag::err_hlsl_unsupported_register_number); return; } @@ -1229,20 +1241,22 @@ void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space; return; } - StringRef SpaceNum = Space.substr(5); - unsigned Num = 0; - if (SpaceNum.getAsInteger(10, Num)) { + StringRef SpaceNumStr = Space.substr(5); + if (SpaceNumStr.getAsInteger(10, SpaceNum)) { Diag(SpaceArgLoc, diag::err_hlsl_expected_space) << Space; return; } - DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, regType, - SpecifiedSpace); + if (!DiagnoseHLSLRegisterAttribute(SemaRef, ArgLoc, TheDecl, RegType, + SpecifiedSpace)) + return; HLSLResourceBindingAttr *NewAttr = HLSLResourceBindingAttr::Create(getASTContext(), Slot, Space, AL); - if (NewAttr) + if (NewAttr) { + NewAttr->setBinding(RegType, SlotNum, SpaceNum); TheDecl->addAttr(NewAttr); + } } void SemaHLSL::handleParamModifierAttr(Decl *D, const ParsedAttr &AL) { @@ -2228,3 +2242,62 @@ QualType SemaHLSL::getInoutParameterType(QualType Ty) { Ty.addRestrict(); return Ty; } + +// Walks though existing explicit bindings, finds the actual resource class +// decl the binding applies to and sets it to attr->ResourceField. +// Additional processing of resource binding can be added here later on, +// such as preparation for overapping resource detection or implicit binding. 
+void SemaHLSL::ProcessResourceBindingOnDecl(VarDecl *D) { + if (!D->hasGlobalStorage()) + return; + + for (Attr *A : D->attrs()) { + HLSLResourceBindingAttr *RBA = dyn_cast<HLSLResourceBindingAttr>(A); + if (!RBA) + continue; + + // // Cbuffers and Tbuffers are HLSLBufferDecl types + if (const HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D)) { + assert(RBA->getRegisterType() == + getRegisterType(CBufferOrTBuffer->isCBuffer() + ? ResourceClass::CBuffer + : ResourceClass::SRV) && + "this should have been handled in DiagnoseLocalRegisterBinding"); + // should we handle HLSLBufferDecl here? + continue; + } + + // Samplers, UAVs, and SRVs are VarDecl types + assert(isa<VarDecl>(D) && "D is expected to be VarDecl or HLSLBufferDecl"); + const VarDecl *VD = cast<VarDecl>(D); + + // Register binding directly on global resource class variable + if (const HLSLAttributedResourceType *AttrResType = + FindHandleTypeOnResource(VD)) { + // FIXME: if array, calculate the binding size from the array dimensions + // (or unbounded for unsized array) + assert(RBA->getResourceField() == nullptr); + continue; + } + + // Global array + const clang::Type *Ty = VD->getType().getTypePtr(); + while (Ty->isArrayType()) { + Ty = Ty->getArrayElementTypeNoTypeQual(); + } + + // Basic types + if (Ty->isArithmeticType()) { + continue; + } + + if (Ty->isRecordType()) { + if (!ProcessResourceBindingOnUserRecordDecl(Ty->getAs<RecordType>(), + RBA)) { + SemaRef.Diag(D->getLocation(), + diag::warn_hlsl_user_defined_type_missing_member) + << static_cast<int>(RBA->getRegisterType()); + } + } + } +} diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl index ea2d576e4cca55..40517f393e1284 100644 --- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl +++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl @@ -106,7 +106,6 @@ struct Eg12{ MySRV s1; MySRV s2; }; -// expected-warning@+3{{binding type 'u' 
only applies to types containing UAV resources}} // expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}} // expected-error@+1{{binding type 'u' cannot be applied more than once}} Eg12 e12 : register(u9) : register(u10); @@ -115,12 +114,14 @@ struct Eg13{ MySRV s1; MySRV s2; }; -// expected-warning@+4{{binding type 'u' only applies to types containing UAV resources}} // expected-warning@+3{{binding type 'u' only applies to types containing UAV resources}} -// expected-warning@+2{{binding type 'u' only applies to types containing UAV resources}} +// expected-error@+2{{binding type 'u' cannot be applied more than once}} // expected-error@+1{{binding type 'u' cannot be applied more than once}} Eg13 e13 : register(u9) : register(u10) : register(u11); +// expected-error@+1{{binding type 't' cannot be applied more than once}} +Eg13 e13_2 : register(t11) : register(t12); + struct Eg14{ MyTemplatedUAV<int> r1; }; @@ -132,4 +133,3 @@ struct Eg15 { }; // expected no error Eg15 e15 : register(c0); - `````````` </details> https://github.com/llvm/llvm-project/pull/111203 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits