llvmbot wrote:
<!--LLVM PR SUMMARY COMMENT--> @llvm/pr-subscribers-clang Author: Helena Kotas (hekota) <details> <summary>Changes</summary> Adjust the register binding diagnostic flags code in a few ways: - Store the resource class in the Flags struct to avoid repeatedly scanning for `HLSLResourceClassAttr` - Avoid unnecessary indirection when converting a resource class to a register type - Remove recursion and reduce duplicated code Also fixes a case where a struct containing an array was incorrectly diagnosed as unfit for `c` register binding. This will also simplify work that needs to be done in this area for llvm/llvm-project#<!-- -->104861. --- Full diff: https://github.com/llvm/llvm-project/pull/106657.diff 2 Files Affected: - (modified) clang/lib/Sema/SemaHLSL.cpp (+68-113) - (modified) clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl (+7) ``````````diff diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index 714e8f5cfa9926..1e484f754b931d 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -480,6 +480,9 @@ struct RegisterBindingFlags { bool ContainsNumeric = false; bool DefaultGlobals = false; + + // used only when Resource == true + llvm::dxil::ResourceClass ResourceClass = llvm::dxil::ResourceClass::UAV; }; static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) { @@ -545,65 +548,38 @@ static const T *getSpecifiedHLSLAttrFromVarDecl(VarDecl *VD) { return getSpecifiedHLSLAttrFromRecordDecl<T>(TheRecordDecl); } -static void updateFlagsFromType(QualType TheQualTy, - RegisterBindingFlags &Flags); - -static void updateResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags, - const RecordDecl *RD) { - if (!RD) - return; - - if (RD->isCompleteDefinition()) { - for (auto Field : RD->fields()) { - QualType T = Field->getType(); - updateFlagsFromType(T, Flags); +static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags, + const RecordType *RT) { + 
llvm::SmallVector<const Type *> TypesToScan; +
TypesToScan.emplace_back(RT); + + while (!TypesToScan.empty()) { + const Type *T = TypesToScan.pop_back_val(); + while (T->isArrayType()) + T = T->getArrayElementTypeNoTypeQual(); + if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { + Flags.ContainsNumeric = true; + continue; } - } -} - -static void updateFlagsFromType(QualType TheQualTy, - RegisterBindingFlags &Flags) { - // if the member's type is a numeric type, set the ContainsNumeric flag - if (TheQualTy->isIntegralOrEnumerationType() || TheQualTy->isFloatingType()) { - Flags.ContainsNumeric = true; - return; - } - - const clang::Type *TheBaseType = TheQualTy.getTypePtr(); - while (TheBaseType->isArrayType()) - TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - // otherwise, if the member's base type is not a record type, return - const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>(); - if (!TheRecordTy) - return; - - RecordDecl *SubRecordDecl = TheRecordTy->getDecl(); - const HLSLResourceClassAttr *Attr = - getSpecifiedHLSLAttrFromRecordDecl<HLSLResourceClassAttr>(SubRecordDecl); - // find the attr if it's on the member, or on any of the member's fields - if (Attr) { - llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass(); - updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass); - } + const RecordType *RT = T->getAs<RecordType>(); + if (!RT) + continue; - // otherwise, dig deeper and recurse into the member - else { - updateResourceClassFlagsFromRecordDecl(Flags, SubRecordDecl); + const RecordDecl *RD = RT->getDecl(); + for (FieldDecl *FD : RD->fields()) { + if (HLSLResourceClassAttr *RCAttr = + FD->getAttr<HLSLResourceClassAttr>()) { + updateResourceClassFlagsFromDeclResourceClass( + Flags, RCAttr->getResourceClass()); + continue; + } + TypesToScan.emplace_back(FD->getType().getTypePtr()); + } } } static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, Decl *TheDecl) { - - // Cbuffers and Tbuffers are HLSLBufferDecl types - 
HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl); - // Samplers, UAVs, and SRVs are VarDecl types - VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl); - - assert(((TheVarDecl && !CBufferOrTBuffer) || - (!TheVarDecl && CBufferOrTBuffer)) && - "either TheVarDecl or CBufferOrTBuffer should be set"); - RegisterBindingFlags Flags; // check if the decl type is groupshared @@ -612,57 +588,61 @@ static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, return Flags; } - if (!isDeclaredWithinCOrTBuffer(TheDecl)) { - // make sure the type is a basic / numeric type - if (TheVarDecl) { - QualType TheQualTy = TheVarDecl->getType(); - // a numeric variable or an array of numeric variables - // will inevitably end up in $Globals buffer - const clang::Type *TheBaseType = TheQualTy.getTypePtr(); - while (TheBaseType->isArrayType()) - TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - if (TheBaseType->isIntegralType(S.getASTContext()) || - TheBaseType->isFloatingType()) - Flags.DefaultGlobals = true; - } - } - - if (CBufferOrTBuffer) { + // Cbuffers and Tbuffers are HLSLBufferDecl types + if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) { Flags.Resource = true; - if (CBufferOrTBuffer->isCBuffer()) - Flags.CBV = true; - else - Flags.SRV = true; - } else if (TheVarDecl) { + Flags.ResourceClass = CBufferOrTBuffer->isCBuffer() + ? 
llvm::dxil::ResourceClass::CBuffer + : llvm::dxil::ResourceClass::SRV; + } + // Samplers, UAVs, and SRVs are VarDecl types + else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) { const HLSLResourceClassAttr *resClassAttr = getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl); - if (resClassAttr) { - llvm::hlsl::ResourceClass DeclResourceClass = - resClassAttr->getResourceClass(); Flags.Resource = true; - updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass); + Flags.ResourceClass = resClassAttr->getResourceClass(); } else { const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr(); while (TheBaseType->isArrayType()) TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - if (TheBaseType->isArithmeticType()) + + if (TheBaseType->isArithmeticType()) { Flags.Basic = true; - else if (TheBaseType->isRecordType()) { + if (!isDeclaredWithinCOrTBuffer(TheDecl) && + (TheBaseType->isIntegralType(S.getASTContext()) || + TheBaseType->isFloatingType())) + Flags.DefaultGlobals = true; + } else if (TheBaseType->isRecordType()) { Flags.UDT = true; const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>(); - assert(TheRecordTy && "The Qual Type should be Record Type"); - const RecordDecl *TheRecordDecl = TheRecordTy->getDecl(); - // recurse through members, set appropriate resource class flags. 
- updateResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl); + updateResourceClassFlagsFromRecordType(Flags, TheRecordTy); } else Flags.Other = true; } + } else { + llvm_unreachable("expected be VarDecl or HLSLBufferDecl"); } return Flags; } -enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; +enum class RegisterType { + SRV = static_cast<int>(llvm::dxil::ResourceClass::SRV), + UAV = static_cast<int>(llvm::dxil::ResourceClass::UAV), + CBuffer = static_cast<int>(llvm::dxil::ResourceClass::CBuffer), + Sampler = static_cast<int>(llvm::dxil::ResourceClass::Sampler), + C, + I, + Invalid +}; + +static RegisterType +convertResourceClassToRegisterType(llvm::dxil::ResourceClass RC) { + assert(RC >= llvm::dxil::ResourceClass::SRV && + RC <= llvm::dxil::ResourceClass::Sampler && + "unexpected resource class value"); + return static_cast<RegisterType>(RC); +} static RegisterType getRegisterType(StringRef Slot) { switch (Slot[0]) { @@ -754,34 +734,9 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, // next, if resource is set, make sure the register type in the register // annotation is compatible with the variable's resource type. 
if (Flags.Resource) { - const HLSLResourceClassAttr *resClassAttr = nullptr; - if (CBufferOrTBuffer) { - resClassAttr = CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>(); - } else if (TheVarDecl) { - resClassAttr = - getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl); - } - - assert(resClassAttr && - "any decl that set the resource flag on analysis should " - "have a resource class attribute attached."); - const llvm::hlsl::ResourceClass DeclResourceClass = - resClassAttr->getResourceClass(); - - // confirm that the register type is bound to its expected resource class - static RegisterType ExpectedRegisterTypesForResourceClass[] = { - RegisterType::SRV, - RegisterType::UAV, - RegisterType::CBuffer, - RegisterType::Sampler, - }; - assert((size_t)DeclResourceClass < - std::size(ExpectedRegisterTypesForResourceClass) && - "DeclResourceClass has unexpected value"); - - RegisterType ExpectedRegisterType = - ExpectedRegisterTypesForResourceClass[(int)DeclResourceClass]; - if (regType != ExpectedRegisterType) { + RegisterType expRegType = + convertResourceClassToRegisterType(Flags.ResourceClass); + if (regType != expRegType) { S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch) << regTypeNum; } @@ -823,7 +778,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, } void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { - if (dyn_cast<VarDecl>(TheDecl)) { + if (isa<VarDecl>(TheDecl)) { if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), cast<ValueDecl>(TheDecl)->getType(), diag::err_incomplete_type)) diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl index f8e38b6d2851d9..edb3f30739cdfd 100644 --- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl +++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl @@ -126,3 +126,10 @@ struct Eg14{ }; // expected-warning@+1{{binding type 't' only applies to 
types containing SRV resources}} Eg14 e14 : register(t9); + +struct Eg15 { + float f[4]; +}; +// expected no error +Eg15 e15 : register(c0); + `````````` </details> https://github.com/llvm/llvm-project/pull/106657 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits