Author: Helena Kotas Date: 2024-09-03T14:11:06-07:00 New Revision: 334d1238aafa8ca017d433caaf8f6e00f2622111
URL: https://github.com/llvm/llvm-project/commit/334d1238aafa8ca017d433caaf8f6e00f2622111 DIFF: https://github.com/llvm/llvm-project/commit/334d1238aafa8ca017d433caaf8f6e00f2622111.diff LOG: [HLSL] Adjust resource binding diagnostic flags code (#106657) Adjust register binding diagnostic flags code in a couple of ways: - Store the resource class in the Flags struct to avoid duplicated scanning for HLSLResourceClassAttribute - Avoid unnecessary indirection when converting resource class to register type - Remove recursion and reduce duplicated code Also fixes a case where a struct containing an array was incorrectly diagnosed as unfit for `c` register binding. This will also simplify work that needs to be done in this area for llvm/llvm-project#104861. Added: Modified: clang/lib/Sema/SemaHLSL.cpp clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl Removed: ################################################################################ diff --git a/clang/lib/Sema/SemaHLSL.cpp b/clang/lib/Sema/SemaHLSL.cpp index fabc6f32906b10..05d2bdf8a57a27 100644 --- a/clang/lib/Sema/SemaHLSL.cpp +++ b/clang/lib/Sema/SemaHLSL.cpp @@ -612,6 +612,9 @@ struct RegisterBindingFlags { bool ContainsNumeric = false; bool DefaultGlobals = false; + + // used only when Resource == true + std::optional<llvm::dxil::ResourceClass> ResourceClass; }; static bool isDeclaredWithinCOrTBuffer(const Decl *TheDecl) { @@ -677,65 +680,38 @@ static const T *getSpecifiedHLSLAttrFromVarDecl(VarDecl *VD) { return getSpecifiedHLSLAttrFromRecordDecl<T>(TheRecordDecl); } -static void updateFlagsFromType(QualType TheQualTy, - RegisterBindingFlags &Flags); - -static void updateResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags, - const RecordDecl *RD) { - if (!RD) - return; - - if (RD->isCompleteDefinition()) { - for (auto Field : RD->fields()) { - QualType T = Field->getType(); - updateFlagsFromType(T, Flags); +static void updateResourceClassFlagsFromRecordType(RegisterBindingFlags &Flags, + const
RecordType *RT) { + llvm::SmallVector<const Type *> TypesToScan; + TypesToScan.emplace_back(RT); + + while (!TypesToScan.empty()) { + const Type *T = TypesToScan.pop_back_val(); + while (T->isArrayType()) + T = T->getArrayElementTypeNoTypeQual(); + if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { + Flags.ContainsNumeric = true; + continue; } - } -} - -static void updateFlagsFromType(QualType TheQualTy, - RegisterBindingFlags &Flags) { - // if the member's type is a numeric type, set the ContainsNumeric flag - if (TheQualTy->isIntegralOrEnumerationType() || TheQualTy->isFloatingType()) { - Flags.ContainsNumeric = true; - return; - } - - const clang::Type *TheBaseType = TheQualTy.getTypePtr(); - while (TheBaseType->isArrayType()) - TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - // otherwise, if the member's base type is not a record type, return - const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>(); - if (!TheRecordTy) - return; - - RecordDecl *SubRecordDecl = TheRecordTy->getDecl(); - const HLSLResourceClassAttr *Attr = - getSpecifiedHLSLAttrFromRecordDecl<HLSLResourceClassAttr>(SubRecordDecl); - // find the attr if it's on the member, or on any of the member's fields - if (Attr) { - llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass(); - updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass); - } + const RecordType *RT = T->getAs<RecordType>(); + if (!RT) + continue; - // otherwise, dig deeper and recurse into the member - else { - updateResourceClassFlagsFromRecordDecl(Flags, SubRecordDecl); + const RecordDecl *RD = RT->getDecl(); + for (FieldDecl *FD : RD->fields()) { + if (HLSLResourceClassAttr *RCAttr = + FD->getAttr<HLSLResourceClassAttr>()) { + updateResourceClassFlagsFromDeclResourceClass( + Flags, RCAttr->getResourceClass()); + continue; + } + TypesToScan.emplace_back(FD->getType().getTypePtr()); + } } } static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, Decl 
*TheDecl) { - - // Cbuffers and Tbuffers are HLSLBufferDecl types - HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl); - // Samplers, UAVs, and SRVs are VarDecl types - VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl); - - assert(((TheVarDecl && !CBufferOrTBuffer) || - (!TheVarDecl && CBufferOrTBuffer)) && - "either TheVarDecl or CBufferOrTBuffer should be set"); - RegisterBindingFlags Flags; // check if the decl type is groupshared @@ -744,58 +720,60 @@ static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, return Flags; } - if (!isDeclaredWithinCOrTBuffer(TheDecl)) { - // make sure the type is a basic / numeric type - if (TheVarDecl) { - QualType TheQualTy = TheVarDecl->getType(); - // a numeric variable or an array of numeric variables - // will inevitably end up in $Globals buffer - const clang::Type *TheBaseType = TheQualTy.getTypePtr(); - while (TheBaseType->isArrayType()) - TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - if (TheBaseType->isIntegralType(S.getASTContext()) || - TheBaseType->isFloatingType()) - Flags.DefaultGlobals = true; - } - } - - if (CBufferOrTBuffer) { + // Cbuffers and Tbuffers are HLSLBufferDecl types + if (HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(TheDecl)) { Flags.Resource = true; - if (CBufferOrTBuffer->isCBuffer()) - Flags.CBV = true; - else - Flags.SRV = true; - } else if (TheVarDecl) { + Flags.ResourceClass = CBufferOrTBuffer->isCBuffer() + ? 
llvm::dxil::ResourceClass::CBuffer + : llvm::dxil::ResourceClass::SRV; + } + // Samplers, UAVs, and SRVs are VarDecl types + else if (VarDecl *TheVarDecl = dyn_cast<VarDecl>(TheDecl)) { const HLSLResourceClassAttr *resClassAttr = getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl); - if (resClassAttr) { - llvm::hlsl::ResourceClass DeclResourceClass = - resClassAttr->getResourceClass(); Flags.Resource = true; - updateResourceClassFlagsFromDeclResourceClass(Flags, DeclResourceClass); + Flags.ResourceClass = resClassAttr->getResourceClass(); } else { const clang::Type *TheBaseType = TheVarDecl->getType().getTypePtr(); while (TheBaseType->isArrayType()) TheBaseType = TheBaseType->getArrayElementTypeNoTypeQual(); - if (TheBaseType->isArithmeticType()) + + if (TheBaseType->isArithmeticType()) { Flags.Basic = true; - else if (TheBaseType->isRecordType()) { + if (!isDeclaredWithinCOrTBuffer(TheDecl) && + (TheBaseType->isIntegralType(S.getASTContext()) || + TheBaseType->isFloatingType())) + Flags.DefaultGlobals = true; + } else if (TheBaseType->isRecordType()) { Flags.UDT = true; const RecordType *TheRecordTy = TheBaseType->getAs<RecordType>(); - assert(TheRecordTy && "The Qual Type should be Record Type"); - const RecordDecl *TheRecordDecl = TheRecordTy->getDecl(); - // recurse through members, set appropriate resource class flags. 
- updateResourceClassFlagsFromRecordDecl(Flags, TheRecordDecl); + updateResourceClassFlagsFromRecordType(Flags, TheRecordTy); } else Flags.Other = true; } + } else { + llvm_unreachable("expected be VarDecl or HLSLBufferDecl"); } return Flags; } enum class RegisterType { SRV, UAV, CBuffer, Sampler, C, I, Invalid }; +static RegisterType getRegisterType(llvm::dxil::ResourceClass RC) { + switch (RC) { + case llvm::dxil::ResourceClass::SRV: + return RegisterType::SRV; + case llvm::dxil::ResourceClass::UAV: + return RegisterType::UAV; + case llvm::dxil::ResourceClass::CBuffer: + return RegisterType::CBuffer; + case llvm::dxil::ResourceClass::Sampler: + return RegisterType::Sampler; + } + llvm_unreachable("unexpected ResourceClass value"); +} + static RegisterType getRegisterType(StringRef Slot) { switch (Slot[0]) { case 't': @@ -886,34 +864,8 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, // next, if resource is set, make sure the register type in the register // annotation is compatible with the variable's resource type. 
if (Flags.Resource) { - const HLSLResourceClassAttr *resClassAttr = nullptr; - if (CBufferOrTBuffer) { - resClassAttr = CBufferOrTBuffer->getAttr<HLSLResourceClassAttr>(); - } else if (TheVarDecl) { - resClassAttr = - getSpecifiedHLSLAttrFromVarDecl<HLSLResourceClassAttr>(TheVarDecl); - } - - assert(resClassAttr && - "any decl that set the resource flag on analysis should " - "have a resource class attribute attached."); - const llvm::hlsl::ResourceClass DeclResourceClass = - resClassAttr->getResourceClass(); - - // confirm that the register type is bound to its expected resource class - static RegisterType ExpectedRegisterTypesForResourceClass[] = { - RegisterType::SRV, - RegisterType::UAV, - RegisterType::CBuffer, - RegisterType::Sampler, - }; - assert((size_t)DeclResourceClass < - std::size(ExpectedRegisterTypesForResourceClass) && - "DeclResourceClass has unexpected value"); - - RegisterType ExpectedRegisterType = - ExpectedRegisterTypesForResourceClass[(int)DeclResourceClass]; - if (regType != ExpectedRegisterType) { + RegisterType expRegType = getRegisterType(Flags.ResourceClass.value()); + if (regType != expRegType) { S.Diag(TheDecl->getLocation(), diag::err_hlsl_binding_type_mismatch) << regTypeNum; } @@ -955,7 +907,7 @@ static void DiagnoseHLSLRegisterAttribute(Sema &S, SourceLocation &ArgLoc, } void SemaHLSL::handleResourceBindingAttr(Decl *TheDecl, const ParsedAttr &AL) { - if (dyn_cast<VarDecl>(TheDecl)) { + if (isa<VarDecl>(TheDecl)) { if (SemaRef.RequireCompleteType(TheDecl->getBeginLoc(), cast<ValueDecl>(TheDecl)->getType(), diag::err_incomplete_type)) diff --git a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl index f8e38b6d2851d9..edb3f30739cdfd 100644 --- a/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl +++ b/clang/test/SemaHLSL/resource_binding_attr_error_udt.hlsl @@ -126,3 +126,10 @@ struct Eg14{ }; // expected-warning@+1{{binding type 't' only applies to types 
containing SRV resources}} Eg14 e14 : register(t9); + +struct Eg15 { + float f[4]; +}; +// expected no error +Eg15 e15 : register(c0); + _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits