================ @@ -437,7 +460,406 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) { D->addAttr(NewAttr); } +struct register_binding_flags { + bool resource = false; + bool udt = false; + bool other = false; + bool basic = false; + + bool srv = false; + bool uav = false; + bool cbv = false; + bool sampler = false; + + bool contains_numeric = false; + bool default_globals = false; +}; + +bool isDeclaredWithinCOrTBuffer(const Decl *decl) { + if (!decl) + return false; + + // Traverse up the parent contexts + const DeclContext *context = decl->getDeclContext(); + while (context) { + if (isa<HLSLBufferDecl>(context)) { + return true; + } + context = context->getParent(); + } + + return false; +} + +const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *SamplerUAVOrSRV) { + const Type *Ty = SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType(); + if (!Ty) + llvm_unreachable("Resource class must have an element type."); + + if (const BuiltinType *BTy = dyn_cast<BuiltinType>(Ty)) { + return nullptr; + } + + const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl(); + if (!TheRecordDecl) + llvm_unreachable("Resource class should have a resource type declaration."); + + if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl)) + TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl(); + TheRecordDecl = TheRecordDecl->getCanonicalDecl(); + return TheRecordDecl; +} + +const HLSLResourceAttr * +getHLSLResourceAttrFromEitherDecl(VarDecl *SamplerUAVOrSRV, + HLSLBufferDecl *CBufferOrTBuffer) { + + if (SamplerUAVOrSRV) { + const CXXRecordDecl *TheRecordDecl = + getRecordDeclFromVarDecl(SamplerUAVOrSRV); + if (!TheRecordDecl) + return nullptr; + const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>(); + return Attr; + } else if (CBufferOrTBuffer) { + const auto *Attr = CBufferOrTBuffer->getAttr<HLSLResourceAttr>(); + return Attr; + } + llvm_unreachable("one of the two conditions should be true."); + return nullptr; +} + 
+void traverseType(QualType T, register_binding_flags &r) { + if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { + r.contains_numeric = true; + return; + } else if (const RecordType *RT = T->getAs<RecordType>()) { + RecordDecl *SubRD = RT->getDecl(); + if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) { + auto TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl(); + TheRecordDecl = TheRecordDecl->getCanonicalDecl(); + const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>(); + llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass(); + switch (DeclResourceClass) { + case llvm::hlsl::ResourceClass::SRV: { + r.srv = true; + break; + } + case llvm::hlsl::ResourceClass::UAV: { + r.uav = true; + break; + } + case llvm::hlsl::ResourceClass::CBuffer: { + r.cbv = true; + break; + } + case llvm::hlsl::ResourceClass::Sampler: { + r.sampler = true; + break; + } + } + } + + else if (SubRD->isCompleteDefinition()) { + for (auto Field : SubRD->fields()) { + QualType T = Field->getType(); + traverseType(T, r); + } + } + } +} + +void setResourceClassFlagsFromRecordDecl(register_binding_flags &r, + const RecordDecl *RD) { + if (!RD) + return; + + if (RD->isCompleteDefinition()) { + for (auto Field : RD->fields()) { + QualType T = Field->getType(); + traverseType(T, r); + } + } +} + +register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) { + register_binding_flags r; + if (!isDeclaredWithinCOrTBuffer(D)) { + // make sure the type is a basic / numeric type + if (VarDecl *v = dyn_cast<VarDecl>(D)) { + QualType t = v->getType(); + // a numeric variable will inevitably end up in $Globals buffer + if (t->isIntegralType(S.getASTContext()) || t->isFloatingType()) + r.default_globals = true; + } + } + // Cbuffers and Tbuffers are HLSLBufferDecl types + HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D); + // Samplers, UAVs, and SRVs are VarDecl types + VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D); + 
+ if (CBufferOrTBuffer) { + r.resource = true; + if (CBufferOrTBuffer->isCBuffer()) + r.cbv = true; + else + r.srv = true; + } else if (SamplerUAVOrSRV) { + const HLSLResourceAttr *res_attr = + getHLSLResourceAttrFromEitherDecl(SamplerUAVOrSRV, CBufferOrTBuffer); + if (res_attr) { + llvm::hlsl::ResourceClass DeclResourceClass = + res_attr->getResourceClass(); + r.resource = true; + switch (DeclResourceClass) { + case llvm::hlsl::ResourceClass::SRV: { + r.srv = true; + break; + } + case llvm::hlsl::ResourceClass::UAV: { + r.uav = true; + break; + } + case llvm::hlsl::ResourceClass::CBuffer: { + r.cbv = true; + break; + } + case llvm::hlsl::ResourceClass::Sampler: { + r.sampler = true; + break; + } + } + } else { + if (SamplerUAVOrSRV->getType()->isBuiltinType()) + r.basic = true; + else if (SamplerUAVOrSRV->getType()->isAggregateType()) { + r.udt = true; + QualType VarType = SamplerUAVOrSRV->getType(); + if (const RecordType *RT = VarType->getAs<RecordType>()) { + const RecordDecl *RD = RT->getDecl(); + // recurse through members, set appropriate resource class flags. 
+ setResourceClassFlagsFromRecordDecl(r, RD); + } + } else + r.other = true; + } + } else { + llvm_unreachable("unknown decl type"); + } + return r; +} + +static void ValidateMultipleRegisterAnnotations(Sema &S, Decl *D, + StringRef &Slot) { + // make sure that there are no register annotations applied to the decl + // with the same register type but different numbers + std::unordered_map<char, std::set<char>> + s; // store unique register type + numbers + std::set<char> starting_set = {Slot[1]}; + s.insert(std::make_pair(Slot[0], starting_set)); + for (auto it = D->attr_begin(); it != D->attr_end(); ++it) { + if (HLSLResourceBindingAttr *attr = + dyn_cast<HLSLResourceBindingAttr>(*it)) { + std::string otherSlot(attr->getSlot().data()); + + // insert into hash map + if (s.find(otherSlot[0]) != s.end()) { + // if the register type is already in the map, insert the number + // into the set (if it's not already there + s[otherSlot[0]].insert(otherSlot[1]); + } else { + // if the register type is not in the map, insert it with the number + std::set<char> otherSet; + otherSet.insert(otherSlot[1]); + s.insert(std::make_pair(otherSlot[0], otherSet)); + } + } + } + + for (auto regType : s) { + if (regType.second.size() > 1) { + std::string regTypeStr(1, regType.first); + S.Diag(D->getLocation(), diag::err_hlsl_conflicting_register_annotations) + << regTypeStr; + } + } +} + +static void DiagnoseHLSLResourceRegType(Sema &S, SourceLocation &ArgLoc, + Decl *D, StringRef &Slot) { + + // Samplers, UAVs, and SRVs are VarDecl types + VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D); + // Cbuffers and Tbuffers are HLSLBufferDecl types + HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D); + + // exactly one of these two types should be set + if (!SamplerUAVOrSRV && !CBufferOrTBuffer) + return; + if (SamplerUAVOrSRV && CBufferOrTBuffer) + return; + + register_binding_flags f = HLSLFillRegisterBindingFlags(S, D); + assert((int)f.other + (int)f.resource + (int)f.basic + 
(int)f.udt == 1 && + "only one resource analysis result should be expected"); + + // get the variable type + std::string typestr; + if (SamplerUAVOrSRV) { + QualType QT = SamplerUAVOrSRV->getType(); + PrintingPolicy PP = S.getPrintingPolicy(); + typestr = QualType::getAsString(QT.split(), PP); + } else + typestr = CBufferOrTBuffer->isCBuffer() ? "cbuffer" : "tbuffer"; + + std::string registerType(Slot.substr(0, 1)); + + // first, if "other" is set, emit an error + if (f.other) { + S.Diag(ArgLoc, diag::err_hlsl_unsupported_register_type_and_variable_type) + << Slot << typestr; + return; + } + + // next, if multiple register annotations exist, check that none conflict. + ValidateMultipleRegisterAnnotations(S, D, Slot); + + // next, if resource is set, make sure the register type in the register + // annotation is compatible with the variable's resource type. + if (f.resource) { + const HLSLResourceAttr *res_attr = + getHLSLResourceAttrFromEitherDecl(SamplerUAVOrSRV, CBufferOrTBuffer); + assert(res_attr && "any decl that set the resource flag on analysis should " + "have a resource attribute attached."); ---------------- damyanp wrote:
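This patch checks invariants in two different ways: `assert(res_attr && "...")` here, but `if (!Ty) llvm_unreachable(...)` in `getRecordDeclFromVarDecl` (and a trailing `llvm_unreachable` after the `if`/`else if` chain in `getHLSLResourceAttrFromEitherDecl`). How does one decide whether to use `llvm_unreachable` or `assert`? https://github.com/llvm/llvm-project/pull/97103 _______________________________________________ cfe-commits mailing list cfe-commits@lists.llvm.org https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits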