================
@@ -437,7 +460,206 @@ void SemaHLSL::handleShaderAttr(Decl *D, const ParsedAttr &AL) {
     D->addAttr(NewAttr);
 }
 
+struct register_binding_flags { // Flag set returned by HLSLFillRegisterBindingFlags; all default to false.
+  bool resource = false;
+  bool udt = false; // NOTE(review): presumably "user-defined type" - confirm.
+  bool other = false;
+  bool basic = false;
+
+  bool srv = false; // traverseType sets this for ResourceClass::SRV.
+  bool uav = false; // traverseType sets this for ResourceClass::UAV.
+  bool cbv = false; // traverseType sets this for ResourceClass::CBuffer.
+  bool sampler = false; // traverseType sets this for ResourceClass::Sampler.
+
+  bool contains_numeric = false; // traverseType sets this on an integral, enum or floating type.
+  bool default_globals = false;
+};
+
+/// Reports whether \p decl is nested, at any depth, inside an HLSLBufferDecl.
+///
+/// Starts at the declaration's own DeclContext and follows getParent() until
+/// an HLSLBufferDecl is found or the chain ends. A null \p decl yields false.
+bool isDeclaredWithinCOrTBuffer(const Decl *decl) {
+  if (!decl)
+    return false;
+
+  for (const DeclContext *Ctx = decl->getDeclContext(); Ctx;
+       Ctx = Ctx->getParent()) {
+    if (isa<HLSLBufferDecl>(Ctx))
+      return true;
+  }
+
+  return false;
+}
+
+/// Looks up the HLSLResourceAttr on the record type behind \p SamplerUAVOrSRV.
+///
+/// The variable's type is first reduced with getPointeeOrArrayElementType().
+/// When the resulting record is a class template specialization, the
+/// attribute is read from the record of the specialized template instead of
+/// from the specialization itself; in both cases the canonical declaration is
+/// the one queried.
+///
+/// \returns the attribute, or nullptr when the reduced type is a builtin type
+/// or the record carries no HLSLResourceAttr.
+const HLSLResourceAttr *
+getHLSLResourceAttrFromVarDecl(VarDecl *SamplerUAVOrSRV) {
+  const Type *Ty = SamplerUAVOrSRV->getType()->getPointeeOrArrayElementType();
+  if (!Ty)
+    llvm_unreachable("Resource class must have an element type.");
+
+  // A builtin type has no record declaration to inspect. No diagnostic is
+  // emitted here; the null result is the only signal the caller gets.
+  if (isa<BuiltinType>(Ty))
+    return nullptr;
+
+  const CXXRecordDecl *TheRecordDecl = Ty->getAsCXXRecordDecl();
+  if (!TheRecordDecl)
+    llvm_unreachable("Resource class should have a resource type declaration.");
+
+  // For a specialization, switch to the record of the template it came from.
+  if (const auto *TDecl =
+          dyn_cast<ClassTemplateSpecializationDecl>(TheRecordDecl))
+    TheRecordDecl = TDecl->getSpecializedTemplate()->getTemplatedDecl();
+
+  return TheRecordDecl->getCanonicalDecl()->getAttr<HLSLResourceAttr>();
+}
+
+/// Records in \p r which kinds of types are reachable from \p T.
+///
+/// - An integral, enumeration or floating type sets r.contains_numeric.
+/// - A class template specialization whose template record carries an
+///   HLSLResourceAttr sets the flag matching its resource class
+///   (srv / uav / cbv / sampler) and is not descended into.
+/// - Any other record with a complete definition is walked field by field,
+///   recursively.
+///
+/// Types matching none of the above leave \p r unchanged.
+/// NOTE(review): array types are not unwrapped here, so an array of numerics
+/// or of records appears to set nothing - confirm whether that is intended.
+void traverseType(QualType T, register_binding_flags &r) {
+  if (T->isIntegralOrEnumerationType() || T->isFloatingType()) {
+    r.contains_numeric = true;
+    return;
+  }
+
+  const RecordType *RT = T->getAs<RecordType>();
+  if (!RT)
+    return;
+
+  const RecordDecl *SubRD = RT->getDecl();
+  if (const auto *TDecl = dyn_cast<ClassTemplateSpecializationDecl>(SubRD)) {
+    const auto *TheRecordDecl =
+        TDecl->getSpecializedTemplate()->getTemplatedDecl()->getCanonicalDecl();
+    // getAttr returns null when the template is not a resource type (for
+    // example a user-written struct template). Only classify when the
+    // attribute is present; otherwise fall through and treat the
+    // specialization like any other record.
+    if (const auto *Attr = TheRecordDecl->getAttr<HLSLResourceAttr>()) {
+      switch (Attr->getResourceClass()) {
+      case llvm::hlsl::ResourceClass::SRV:
+        r.srv = true;
+        break;
+      case llvm::hlsl::ResourceClass::UAV:
+        r.uav = true;
+        break;
+      case llvm::hlsl::ResourceClass::CBuffer:
+        r.cbv = true;
+        break;
+      case llvm::hlsl::ResourceClass::Sampler:
+        r.sampler = true;
+        break;
+      }
+      return;
+    }
+  }
+
+  if (SubRD->isCompleteDefinition()) {
+    for (const auto *Field : SubRD->fields())
+      traverseType(Field->getType(), r);
+  }
+}
+
+/// Feeds the type of every field of \p RD through traverseType, accumulating
+/// the results in \p r.
+///
+/// Does nothing when \p RD is null or is not a complete definition.
+void setResourceClassFlagsFromRecordDecl(register_binding_flags &r,
+                                         const RecordDecl *RD) {
+  if (!RD || !RD->isCompleteDefinition())
+    return;
+
+  for (const auto *Member : RD->fields())
+    traverseType(Member->getType(), r);
+}
+
+register_binding_flags HLSLFillRegisterBindingFlags(Sema &S, Decl *D) {
----------------
damyanp wrote:

How is this tested?

https://github.com/llvm/llvm-project/pull/97103
_______________________________________________
cfe-commits mailing list
cfe-commits@lists.llvm.org
https://lists.llvm.org/cgi-bin/mailman/listinfo/cfe-commits

Reply via email to