================ @@ -437,7 +460,406 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) { D->addAttr(NewAttr); } +struct register_binding_flags { + bool resource = false; + bool udt = false; + bool other = false; + bool basic = false; + + bool srv = false; + bool uav = false; + bool cbv = false; + bool sampler = false; + + bool contains_numeric = false; + bool default_globals = false; +}; + +bool isDeclaredWithinCOrTBuffer(const Decl *decl) { + if (!decl) + return false; + + // Traverse up the parent contexts + const DeclContext *context = decl->getDeclContext(); + while (context) { + if (isa<HLSLBufferDecl>(context)) { + return true; + } + context = context->getParent(); + } + + return false; +} + +const CXXRecordDecl *getRecordDeclFromVarDecl(VarDecl *SamplerUAVOrSRV) { + const Type *Ty = SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType(); + if (!Ty) + llvm_unreachable("Resource class must have an element type."); + + if (const BuiltinType *BTy = dyn_cast<BuiltinType>(Ty)) { + return nullptr; + } + + const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl(); + if (!TheRecordDecl) + llvm_unreachable("Resource class should have a resource type declaration."); + + if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl)) + TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl(); + TheRecordDecl = TheRecordDecl->getCanonicalDecl(); + return TheRecordDecl; +} + +const HLSLResourceAttr * +getHLSLResourceAttrFromEitherDecl(VarDecl *SamplerUAVOrSRV, + HLSLBufferDecl *CBufferOrTBuffer) { + + if (SamplerUAVOrSRV) { + const CXXRecordDecl *TheRecordDecl = + getRecordDeclFromVarDecl(SamplerUAVOrSRV); + if (!TheRecordDecl) + return nullptr; + const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>(); + return Attr; + } else if (CBufferOrTBuffer) { + const auto *Attr = CBufferOrTBuffer->getAttr<HLSLResourceAttr>(); + return Attr; + } + llvm_unreachable("one of the two conditions should be true."); + return nullptr; +} + 
+void traverseType(QualType T, register_binding_flags &r) { + if (T->isIntegralOrEnumerationType() || T->isFloatingType()) { + r.contains_numeric = true; + return; + } else if (const RecordType *RT = T->getAs<RecordType>()) { + RecordDecl *SubRD = RT->getDecl(); + if (auto TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) { + auto TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl(); + TheRecordDecl = TheRecordDecl->getCanonicalDecl(); + const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>(); + llvm::hlsl::ResourceClass DeclResourceClass = Attr->getResourceClass(); + switch (DeclResourceClass) { + case llvm::hlsl::ResourceClass::SRV: { + r.srv = true; + break; + } + case llvm::hlsl::ResourceClass::UAV: { + r.uav = true; + break; + } + case llvm::hlsl::ResourceClass::CBuffer: { + r.cbv = true; + break; + } + case llvm::hlsl::ResourceClass::Sampler: { + r.sampler = true; + break; + } + } + } + + else if (SubRD->isCompleteDefinition()) { + for (auto Field : SubRD->fields()) { + QualType T = Field->getType(); + traverseType(T, r); + } + } + } +} + +void setResourceClassFlagsFromRecordDecl(register_binding_flags &r, + const RecordDecl *RD) { + if (!RD) + return; + + if (RD->isCompleteDefinition()) { + for (auto Field : RD->fields()) { + QualType T = Field->getType(); + traverseType(T, r); + } + } +} + +register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) { + register_binding_flags r; + if (!isDeclaredWithinCOrTBuffer(D)) { + // make sure the type is a basic / numeric type + if (VarDecl *v = dyn_cast<VarDecl>(D)) { + QualType t = v->getType(); + // a numeric variable will inevitably end up in $Globals buffer + if (t->isIntegralType(S.getASTContext()) || t->isFloatingType()) + r.default_globals = true; + } + } + // Cbuffers and Tbuffers are HLSLBufferDecl types + HLSLBufferDecl *CBufferOrTBuffer = dyn_cast<HLSLBufferDecl>(D); + // Samplers, UAVs, and SRVs are VarDecl types + VarDecl *SamplerUAVOrSRV = dyn_cast<VarDecl>(D); + 
+ if (CBufferOrTBuffer) { + r.resource = true; + if (CBufferOrTBuffer->isCBuffer()) + r.cbv = true; + else + r.srv = true; + } else if (SamplerUAVOrSRV) { + const HLSLResourceAttr *res_attr = + getHLSLResourceAttrFromEitherDecl(SamplerUAVOrSRV, CBufferOrTBuffer); + if (res_attr) { + llvm::hlsl::ResourceClass DeclResourceClass = + res_attr->getResourceClass(); + r.resource = true; + switch (DeclResourceClass) { + case llvm::hlsl::ResourceClass::SRV: { + r.srv = true; + break; + } + case llvm::hlsl::ResourceClass::UAV: { + r.uav = true; + break; + } + case llvm::hlsl::ResourceClass::CBuffer: { + r.cbv = true; + break; + } + case llvm::hlsl::ResourceClass::Sampler: { + r.sampler = true; + break; + } + } + } else { + if (SamplerUAVOrSRV->getType()->isBuiltinType()) + r.basic = true; + else if (SamplerUAVOrSRV->getType()->isAggregateType()) { ---------------- bob80905 wrote:
Renamed to VD for extra generality, though this is Hungarian notation. I'm not sure of a better name.

https://github.com/llvm/llvm-project/pull/97103
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits