================ @@ -985,88 +1026,92 @@ SemaHLSL::TakeLocForHLSLAttribute(const HLSLAttributedResourceType *RT) { return LocInfo; } -// get the record decl from a var decl that we expect -// represents a resource -static CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) { - const Type *Ty = VD->getType()->getPointeeOrArrayElementType(); - assert(Ty && "Resource must have an element type."); - - if (Ty->isBuiltinType()) - return nullptr; - - CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl(); - assert(TheRecordDecl && "Resource should have a resource type declaration."); - return TheRecordDecl; -} - +// Returns handle type of a resource, if the type is a resource static const HLSLAttributedResourceType * -findAttributedResourceTypeOnField(VarDecl *VD) { - assert(VD != nullptr && "expected VarDecl"); - if (RecordDecl *RD = getRecordDeclFromVarDecl(VD)) { - for (auto *FD : RD->fields()) { - if (const HLSLAttributedResourceType *AttrResType = - dyn_cast<HLSLAttributedResourceType>(FD->getType().getTypePtr())) - return AttrResType; +FindHandleTypeOnResource(const Type *Ty) { + // If Ty is a resource class, its first field must be the resource handle + // of type HLSLAttributedResourceType; otherwise nullptr is returned + if (const RecordDecl *RD = Ty->getAsCXXRecordDecl()) { + if (!RD->fields().empty()) { + const FieldDecl *FirstFD = *RD->field_begin(); + return dyn_cast<HLSLAttributedResourceType>( + FirstFD->getType().getTypePtr()); } } return nullptr; } -// Iterate over RecordType fields and return true if any of them matched the -// register type -static bool ContainsResourceForRegisterType(Sema &S, const RecordType *RT, - RegisterType RegType) { - llvm::SmallVector<const Type *> TypesToScan; - TypesToScan.emplace_back(RT); - - while (!TypesToScan.empty()) { - const Type *T = TypesToScan.pop_back_val(); - while (T->isArrayType()) - T = T->getArrayElementTypeNoTypeQual(); - if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { - if (RegType == RegisterType::C) - return true; +// Returns 
handle type of a resource, if the VarDecl is a resource +static const HLSLAttributedResourceType * +FindHandleTypeOnResource(const VarDecl *VD) { + assert(VD != nullptr && "expected VarDecl"); + return FindHandleTypeOnResource(VD->getType().getTypePtr()); +} + +// Walks through the global variable declaration, collects all resource binding +// requirements and adds them to Bindings +void SemaHLSL::FindResourcesOnUserRecordDecl(const VarDecl *VD, + const RecordType *RT, int Size) { + const RecordDecl *RD = RT->getDecl(); + for (FieldDecl *FD : RD->fields()) { + const Type *Ty = FD->getType()->getUnqualifiedDesugaredType(); + + // Calculate array size and unwrap + int ArraySize = 1; + assert(!Ty->isIncompleteArrayType() && + "incomplete arrays inside user defined types are not supported"); + while (Ty->isConstantArrayType()) { + const ConstantArrayType *CAT = cast<ConstantArrayType>(Ty); + ArraySize *= CAT->getSize().getSExtValue(); + Ty = CAT->getElementType()->getUnqualifiedDesugaredType(); } - const RecordType *RT = T->getAs<RecordType>(); - if (!RT) + + if (!Ty->isRecordType()) continue; - const RecordDecl *RD = RT->getDecl(); - for (FieldDecl *FD : RD->fields()) { - const Type *FieldTy = FD->getType().getTypePtr(); - if (const HLSLAttributedResourceType *AttrResType = - dyn_cast<HLSLAttributedResourceType>(FieldTy)) { - ResourceClass RC = AttrResType->getAttrs().ResourceClass; - if (getRegisterType(RC) == RegType) - return true; - } else { - TypesToScan.emplace_back(FD->getType().getTypePtr()); - } + // Field is a resource or array of resources + if (const HLSLAttributedResourceType *AttrResType = + FindHandleTypeOnResource(Ty)) { + ResourceClass RC = AttrResType->getAttrs().ResourceClass; + + // Add a new DeclBindingInfo to Bindings.
Update the binding size if + // a binding info already exists (there are multiple resources of same + // resource class in this user decl) + if (auto *DBI = Bindings.getDeclBindingInfo(VD, RC)) + DBI->Size += Size * ArraySize; + else + Bindings.addDeclBindingInfo(VD, RC, Size * ArraySize); + } else if (const RecordType *FieldRT = dyn_cast<RecordType>(Ty)) { + // Recursively scan embedded struct or class; it would be nice to do this + // without recursion, but tricky to correctly calculate the size. + // Hopefully nesting of structs in structs too many levels is unlikely. + FindResourcesOnUserRecordDecl(VD, FieldRT, Size * ArraySize); } } - return false; } -static void CheckContainsResourceForRegisterType(Sema &S, - SourceLocation &ArgLoc, - Decl *D, RegisterType RegType, - bool SpecifiedSpace) { +// return false if the register binding is not valid +static bool DiagnoseLocalRegisterBinding(Sema &S, SourceLocation &ArgLoc, + Decl *D, RegisterType RegType, + bool SpecifiedSpace) { ---------------- bogner wrote:
If we're going to add a comment here it should really say what the function does, not just what the return value means. https://github.com/llvm/llvm-project/pull/111203 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits