================
@@ -437,7 +444,419 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
   D->addAttr(NewAttr);
 }
 
+namespace {
+/// Summary of how a declaration may be bound to registers, as computed by
+/// HLSLFillRegisterBindingFlags().
+struct RegisterBindingFlags {
+  // Declaration category; exactly one of these is expected to be set.
+  bool Resource = false; // Resource object, cbuffer or tbuffer.
+  bool UDT = false;      // Aggregate type (struct/class or array).
+  bool Other = false;    // Neither a builtin nor an aggregate type.
+  bool Basic = false;    // Builtin type.
+
+  // Resource classes found on the declaration itself or, for UDTs, on its
+  // (possibly nested) fields.
+  bool SRV = false;
+  bool UAV = false;
+  bool CBV = false;
+  bool Sampler = false;
+
+  // Some (possibly nested) UDT field holds integral, enumeration, floating or
+  // vector data, possibly as an array.
+  bool ContainsNumeric = false;
+  // A numeric variable declared outside any cbuffer/tbuffer; it will end up
+  // in the $Globals buffer.
+  bool DefaultGlobals = false;
+};
+} // namespace
+
+/// Returns true if \p D is declared directly inside a cbuffer or tbuffer.
+/// Only the immediate DeclContext is checked; a null \p D yields false.
+static bool isDeclaredWithinCOrTBuffer(const Decl *D) {
+  return D && isa<HLSLBufferDecl>(D->getDeclContext());
+}
+
+/// Returns the canonical record declaration describing the type of \p VD,
+/// looking through a pointer or array type. For a class template
+/// specialization, the templated declaration of the specialized template is
+/// returned, because that is where HLSLResourceAttr is looked up. Returns
+/// null when the type is not a C++ record (builtins, vectors, enums, ...).
+static const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *VD) {
+  const Type *Ty = VD->getType()->getPointeeOrArrayElementType();
+  assert(Ty && "Resource class must have an element type.");
+
+  // Non-record variables are classified by the caller (Basic/UDT/Other), so
+  // they must not trip an assertion or a null dereference here. This also
+  // covers builtin types, including ones hidden behind a typedef.
+  const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+  if (!TheRecordDecl)
+    return nullptr;
+
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+  return TheRecordDecl->getCanonicalDecl();
+}
+
+/// Returns the HLSLResourceAttr describing the type of \p VD or, if \p VD is
+/// null, the one attached to \p CBufferOrTBuffer. At least one argument must
+/// be non-null (\p VD takes precedence). Returns null if there is no such
+/// attribute.
+static const HLSLResourceAttr *
+getHLSLResourceAttrFromEitherDecl(VarDecl *VD,
+                                  HLSLBufferDecl *CBufferOrTBuffer) {
+  if (VD) {
+    const CXXRecordDecl *TheRecordDecl = getRecordDeclFromVarDecl(VD);
+    return TheRecordDecl ? TheRecordDecl->getAttr<HLSLResourceAttr>()
+                         : nullptr;
+  }
+  if (CBufferOrTBuffer)
+    return CBufferOrTBuffer->getAttr<HLSLResourceAttr>();
+  llvm_unreachable("one of the two conditions should be true.");
+}
+
+/// Sets the flag in \p Flags that corresponds to resource class \p RC.
+static void setResourceClassFlag(RegisterBindingFlags &Flags,
+                                 llvm::hlsl::ResourceClass RC) {
+  switch (RC) {
+  case llvm::hlsl::ResourceClass::SRV:
+    Flags.SRV = true;
+    break;
+  case llvm::hlsl::ResourceClass::UAV:
+    Flags.UAV = true;
+    break;
+  case llvm::hlsl::ResourceClass::CBuffer:
+    Flags.CBV = true;
+    break;
+  case llvm::hlsl::ResourceClass::Sampler:
+    Flags.Sampler = true;
+    break;
+  }
+}
+
+/// Recursively inspects \p T, the type of a UDT field. Records in \p Flags
+/// whether it holds numeric data and which resource classes it contains.
+static void traverseType(QualType T, RegisterBindingFlags &Flags) {
+  // Look through arrays so that members such as `float F[4]` or
+  // `RWBuffer<float> B[2]` are classified by their element type.
+  const Type *Ty = T->getBaseElementTypeUnsafe();
+
+  if (Ty->isIntegralOrEnumerationType() || Ty->isFloatingType() ||
+      Ty->isVectorType()) {
+    Flags.ContainsNumeric = true;
+    return;
+  }
+
+  const RecordType *RT = Ty->getAs<RecordType>();
+  if (!RT)
+    return;
+
+  const RecordDecl *SubRD = RT->getDecl();
+  if (const auto *TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
+    const CXXRecordDecl *TheRecordDecl =
+        TDecl->getSpecializedTemplate()->getTemplatedDecl()->getCanonicalDecl();
+    // A specialization whose template has no HLSLResourceAttr (e.g. a
+    // user-defined template struct) is not a resource. Fall through and
+    // inspect its fields instead of dereferencing a null attribute.
+    if (const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>()) {
+      setResourceClassFlag(Flags, Attr->getResourceClass());
+      return;
+    }
+  }
+
+  if (SubRD->isCompleteDefinition())
+    for (const FieldDecl *Field : SubRD->fields())
+      traverseType(Field->getType(), Flags);
+}
+
+/// Collects into \p Flags the resource classes and numeric data found,
+/// recursively, in the fields of \p RD. Does nothing for a null or
+/// incomplete record.
+static void setResourceClassFlagsFromRecordDecl(RegisterBindingFlags &Flags,
+                                                const RecordDecl *RD) {
+  if (!RD || !RD->isCompleteDefinition())
+    return;
+  for (const FieldDecl *Field : RD->fields())
+    traverseType(Field->getType(), Flags);
+}
+
+/// Classifies \p D, which must be a VarDecl or an HLSLBufferDecl, for the
+/// register-binding diagnostics. Exactly one of Resource/UDT/Other/Basic is
+/// set in the result.
+static RegisterBindingFlags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
+  RegisterBindingFlags Flags;
+
+  // Samplers, UAVs, and SRVs are VarDecls; cbuffers and tbuffers are
+  // HLSLBufferDecls.
+  VarDecl *VD = dyn_cast<VarDecl>(D);
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+  assert(((VD && !CBufferOrTBuffer) || (!VD && CBufferOrTBuffer)) &&
+         "either VD or CBufferOrTBuffer should be set");
+
+  // A numeric variable declared outside any cbuffer/tbuffer will inevitably
+  // end up in the $Globals buffer.
+  if (VD && !isDeclaredWithinCOrTBuffer(D)) {
+    QualType VarType = VD->getType();
+    if (VarType->isIntegralType(S.getASTContext()) ||
+        VarType->isFloatingType())
+      Flags.DefaultGlobals = true;
+  }
+
+  if (CBufferOrTBuffer) {
+    Flags.Resource = true;
+    if (CBufferOrTBuffer->isCBuffer())
+      Flags.CBV = true;
+    else
+      Flags.SRV = true;
+    return Flags;
+  }
+
+  if (!VD)
+    return Flags;
+
+  if (const HLSLResourceAttr *ResAttr =
+          getHLSLResourceAttrFromEitherDecl(VD, CBufferOrTBuffer)) {
+    Flags.Resource = true;
+    setResourceClassFlag(Flags, ResAttr->getResourceClass());
+    return Flags;
+  }
+
+  QualType VarType = VD->getType();
+  if (VarType->isBuiltinType()) {
+    Flags.Basic = true;
+  } else if (VarType->isAggregateType()) {
+    Flags.UDT = true;
+    // Recurse through the members to set the appropriate resource flags.
+    if (const RecordType *RT = VarType->getAs<RecordType>())
+      setResourceClassFlagsFromRecordDecl(Flags, RT->getDecl());
+  } else {
+    Flags.Other = true;
+  }
+  return Flags;
+}
+
+/// Maps the register type letter at the start of \p Slot (case-insensitive)
+/// to an index: t=0, u=1, b=2, s=3, c=4, i=5. \p Slot must be non-empty and
+/// start with one of these letters.
+static int getRegisterTypeIndex(StringRef Slot) {
+  switch (Slot[0]) {
+  case 't':
+  case 'T':
+    return 0;
+  case 'u':
+  case 'U':
+    return 1;
+  case 'b':
+  case 'B':
+    return 2;
+  case 's':
+  case 'S':
+    return 3;
+  case 'c':
+  case 'C':
+    return 4;
+  case 'i':
+  case 'I':
+    return 5;
+  default:
+    llvm_unreachable("invalid register type");
+  }
+}
+
+/// Diagnoses register annotations on \p D that reuse a register type, either
+/// the type of \p Slot or one seen earlier among \p D's existing
+/// HLSLResourceBindingAttrs.
+static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *D,
+                                                StringRef Slot) {
+  // Make sure that no two register annotations applied to the decl use the
+  // same register type. Indexed by getRegisterTypeIndex().
+  bool RegisterTypesDetected[6] = {false};
+  RegisterTypesDetected[getRegisterTypeIndex(Slot)] = true;
+
+  for (const auto *BindingAttr : D->specific_attrs<HLSLResourceBindingAttr>()) {
+    int RegisterTypeIndex = getRegisterTypeIndex(BindingAttr->getSlot());
+    if (RegisterTypesDetected[RegisterTypeIndex])
+      S.Diag(D->getLocation(), diag::err_hlsl_conflicting_register_annotations)
+          << BindingAttr->getSlot().substr(0, 1);
+    else
+      RegisterTypesDetected[RegisterTypeIndex] = true;
+  }
+}
+
+/// Returns a printable name for the type of \p D: the printed type of a
+/// VarDecl, or "cbuffer"/"tbuffer" for an HLSLBufferDecl.
+static std::string getHLSLResourceTypeStr(Sema &S, Decl *D) {
+  if (const auto *VD = dyn_cast<VarDecl>(D))
+    return QualType::getAsString(VD->getType().split(), S.getPrintingPolicy());
+  const auto *CBufferOrTBuffer = cast<HLSLBufferDecl>(D);
+  return CBufferOrTBuffer->isCBuffer() ? "cbuffer" : "tbuffer";
+}
+
+static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc,
+                                        Decl *D, StringRef &Slot) {
+
+  // Samplers, UAVs, and SRVs are VarDecl types
+  VarDecl *VD = dyn_cast<VarDecl>(D);
+  // Cbuffers and Tbuffers are HLSLBufferDecl types
+  HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D);
+
+  // exactly one of these two types should be set
+  assert(((VD && !CBufferOrTBuffer) || (!VD && CBufferOrTBuffer)) &&
+         "either VD or CBufferOrTBuffer should be set");
+
+  RegisterBindingFlags f = HLSLFillRegisterBindingFlags(S, D);
+  assert((int)f.Other + (int)f.Resource + (int)f.Basic + (int)f.UDT == 1 &&
+         "only one resource analysis result should be expected");
+
+  std::string registerType(Slot.substr(0, 1));
----------------
damyanp wrote:
How sure are you that constructing this string doesn't allocate memory? It'd be better to only construct the `std::string` if you need it.

https://github.com/llvm/llvm-project/pull/97103
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits